mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 09:36:32 +08:00
增加特征重要性计算
This commit is contained in:
parent
4b75b7406d
commit
3b38727053
9
Makefile
9
Makefile
@ -126,10 +126,11 @@ train: requirements
|
||||
# INTERPRETABILITY #
|
||||
#################################################################################
|
||||
# 参数:
|
||||
# TASK 目标任务 (delivery, size, pdi, ee, biodist, toxic; 默认: delivery)
|
||||
# METHOD 方法 (ig, ablation, attention, all; 默认: all)
|
||||
# DATA 数据路径 (默认: data/processed/train.parquet)
|
||||
# MODEL 模型路径 (默认: models/model.pt)
|
||||
# TASK 目标任务 (delivery, size, pdi, ee, biodist, toxic, all; 默认: delivery)
|
||||
# 如果指定 'all',将依次计算所有 6 个任务
|
||||
# METHOD 方法 (ig, ablation, attention, all; 默认: ig)
|
||||
# DATA 数据路径 (默认: data/interim/internal.csv,即最终模型的全量训练数据)
|
||||
# MODEL 模型路径 (默认: models/final/model.pt)
|
||||
|
||||
TASK_FLAG = $(if $(TASK),--task $(TASK),)
|
||||
METHOD_FLAG = $(if $(METHOD),--method $(METHOD),)
|
||||
|
||||
@ -29,8 +29,8 @@ from torch.utils.data import DataLoader
|
||||
from loguru import logger
|
||||
from tqdm import tqdm
|
||||
|
||||
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR, REPORTS_DIR
|
||||
from lnp_ml.dataset import LNPDataset, collate_fn
|
||||
from lnp_ml.config import MODELS_DIR, INTERIM_DATA_DIR, REPORTS_DIR
|
||||
from lnp_ml.dataset import LNPDataset, collate_fn, process_dataframe
|
||||
from lnp_ml.modeling.predict import load_model
|
||||
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
|
||||
|
||||
@ -332,12 +332,12 @@ def save_csv(
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Token-level Feature Importance")
|
||||
parser.add_argument("--model-path", type=str, default=str(MODELS_DIR / "model.pt"),
|
||||
parser.add_argument("--model-path", type=str, default=str(MODELS_DIR / "final" / "model.pt"),
|
||||
help="Path to trained model checkpoint")
|
||||
parser.add_argument("--data-path", type=str, default=str(PROCESSED_DATA_DIR / "train.parquet"),
|
||||
help="Path to data (parquet) for computing importance")
|
||||
parser.add_argument("--task", type=str, default="delivery", choices=TASKS,
|
||||
help="Target task for importance computation")
|
||||
parser.add_argument("--data-path", type=str, default=str(INTERIM_DATA_DIR / "internal.csv"),
|
||||
help="Path to data (.csv or .parquet) for computing importance")
|
||||
parser.add_argument("--task", type=str, default="all", choices=TASKS + ["all"],
|
||||
help="Target task for importance computation ('all' to run on all tasks)")
|
||||
parser.add_argument("--method", type=str, default="ig",
|
||||
choices=["ig", "ablation", "attention", "all"],
|
||||
help="Which method(s) to run")
|
||||
@ -363,7 +363,12 @@ def main() -> None:
|
||||
logger.info(f"Tokens ({len(token_names)}): {token_names}")
|
||||
|
||||
# ── Load data ──
|
||||
df = pd.read_parquet(args.data_path)
|
||||
data_path = Path(args.data_path)
|
||||
if data_path.suffix == ".csv":
|
||||
df = pd.read_csv(data_path)
|
||||
df = process_dataframe(df)
|
||||
else:
|
||||
df = pd.read_parquet(data_path)
|
||||
dataset = LNPDataset(df)
|
||||
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
|
||||
logger.info(f"Samples: {len(dataset)}")
|
||||
@ -382,24 +387,32 @@ def main() -> None:
|
||||
else [args.method]
|
||||
)
|
||||
|
||||
# Determine tasks to process
|
||||
tasks_to_run = TASKS if args.task == "all" else [args.task]
|
||||
|
||||
for task in tasks_to_run:
|
||||
logger.info(f"\n{'#'*60}")
|
||||
logger.info(f"# Processing task: {task}")
|
||||
logger.info(f"{'#'*60}")
|
||||
|
||||
results: Dict[str, np.ndarray] = {}
|
||||
|
||||
for method in methods:
|
||||
logger.info(f"\n{'='*60}")
|
||||
logger.info(f"Computing: {method}")
|
||||
logger.info(f"Computing: {method} (task={task})")
|
||||
logger.info(f"{'='*60}")
|
||||
|
||||
if method == "ig":
|
||||
imp = integrated_gradients_importance(
|
||||
model, all_tokens, device,
|
||||
task=args.task, batch_size=args.batch_size, n_steps=args.n_steps,
|
||||
task=task, batch_size=args.batch_size, n_steps=args.n_steps,
|
||||
)
|
||||
results["Integrated Gradients"] = imp
|
||||
|
||||
elif method == "ablation":
|
||||
imp = token_ablation_importance(
|
||||
model, all_tokens, device,
|
||||
task=args.task, batch_size=args.batch_size,
|
||||
task=task, batch_size=args.batch_size,
|
||||
)
|
||||
results["Token Ablation"] = imp
|
||||
|
||||
@ -413,7 +426,7 @@ def main() -> None:
|
||||
|
||||
# ── Print summary ──
|
||||
logger.info(f"\n{'='*60}")
|
||||
logger.info(f"Token Importance Summary (task={args.task})")
|
||||
logger.info(f"Token Importance Summary (task={task})")
|
||||
logger.info(f"{'='*60}")
|
||||
for method_name, importance in results.items():
|
||||
normed = normalize(importance)
|
||||
@ -428,8 +441,8 @@ def main() -> None:
|
||||
|
||||
# ── Save results ──
|
||||
if results:
|
||||
plot_token_importance(results, token_names, args.task, out_dir)
|
||||
save_csv(results, token_names, args.task, out_dir, gate_vals=gv)
|
||||
plot_token_importance(results, token_names, task, out_dir)
|
||||
save_csv(results, token_names, task, out_dir, gate_vals=gv)
|
||||
|
||||
logger.info("\nDone!")
|
||||
|
||||
|
||||
9
reports/feature_importance/token_importance_biodist.csv
Normal file
9
reports/feature_importance/token_importance_biodist.csv
Normal file
@ -0,0 +1,9 @@
|
||||
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
|
||||
desc,3.2013970842762367e-09,0.17906809140868585,0.5030443072319031
|
||||
mpnn,3.109777150387161e-09,0.17394338920379962,0.5024935007095337
|
||||
maccs,3.0657202874248063e-09,0.1714790968475434,0.5030479431152344
|
||||
morgan,3.020539287877718e-09,0.16895192663283606,0.5045571327209473
|
||||
help,1.937320535640997e-09,0.10836278088337024,0.49689680337905884
|
||||
comp,1.876732953403087e-09,0.10497385335304418,0.5007365345954895
|
||||
exp,1.6503372406931126e-09,0.09231055445232395,0.5002157688140869
|
||||
phys,1.6274562664095572e-11,0.0009103072183966515,0.49989768862724304
|
||||
|
BIN
reports/feature_importance/token_importance_biodist.png
Normal file
BIN
reports/feature_importance/token_importance_biodist.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 55 KiB |
9
reports/feature_importance/token_importance_delivery.csv
Normal file
9
reports/feature_importance/token_importance_delivery.csv
Normal file
@ -0,0 +1,9 @@
|
||||
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
|
||||
desc,0.48309860429978513,0.1830685579048154,0.5030443072319031
|
||||
mpnn,0.4800125818662255,0.1818991202961257,0.5024935007095337
|
||||
morgan,0.4681746999619525,0.17741319558101734,0.5045571327209473
|
||||
maccs,0.4642216987644718,0.1759152193455701,0.5030479431152344
|
||||
help,0.2636187040391749,0.09989740304701922,0.49689680337905884
|
||||
comp,0.2503036292270154,0.0948517011498054,0.5007365345954895
|
||||
exp,0.22730758172642437,0.08613742788142016,0.5002157688140869
|
||||
phys,0.0021569658208921167,0.0008173747942266034,0.49989768862724304
|
||||
|
BIN
reports/feature_importance/token_importance_delivery.png
Normal file
BIN
reports/feature_importance/token_importance_delivery.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 56 KiB |
9
reports/feature_importance/token_importance_ee.csv
Normal file
9
reports/feature_importance/token_importance_ee.csv
Normal file
@ -0,0 +1,9 @@
|
||||
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
|
||||
desc,0.28693426746338735,0.2025623449841757,0.5030443072319031
|
||||
mpnn,0.2846832493908164,0.2009732301551499,0.5024935007095337
|
||||
morgan,0.2646961202937205,0.18686323982460912,0.5045571327209473
|
||||
maccs,0.25845640337994125,0.1824582877731647,0.5030479431152344
|
||||
help,0.11619691669315178,0.082029658337337,0.49689680337905884
|
||||
comp,0.10812802919624785,0.07633339630758515,0.5007365345954895
|
||||
exp,0.09640860187977682,0.06806002171178546,0.5002157688140869
|
||||
phys,0.001019643036021533,0.000719820906192951,0.49989768862724304
|
||||
|
BIN
reports/feature_importance/token_importance_ee.png
Normal file
BIN
reports/feature_importance/token_importance_ee.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 54 KiB |
9
reports/feature_importance/token_importance_pdi.csv
Normal file
9
reports/feature_importance/token_importance_pdi.csv
Normal file
@ -0,0 +1,9 @@
|
||||
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
|
||||
mpnn,0.29288250773521896,0.2140725741095052,0.5024935007095337
|
||||
desc,0.2764989421166159,0.2020975603328644,0.5030443072319031
|
||||
morgan,0.2578500855439716,0.18846680866532511,0.5045571327209473
|
||||
maccs,0.2471777274240024,0.18066620905892686,0.5030479431152344
|
||||
help,0.10363027887841153,0.07574505123823679,0.49689680337905884
|
||||
comp,0.09781463069792998,0.07149429968008479,0.5007365345954895
|
||||
exp,0.091429982122356,0.06682765650659303,0.5002157688140869
|
||||
phys,0.0008617135523839594,0.0006298404084638948,0.49989768862724304
|
||||
|
BIN
reports/feature_importance/token_importance_pdi.png
Normal file
BIN
reports/feature_importance/token_importance_pdi.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 54 KiB |
9
reports/feature_importance/token_importance_size.csv
Normal file
9
reports/feature_importance/token_importance_size.csv
Normal file
@ -0,0 +1,9 @@
|
||||
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
|
||||
desc,0.43124767207505,0.1973483310882186,0.5030443072319031
|
||||
mpnn,0.4250195774970647,0.1944982193069518,0.5024935007095337
|
||||
morgan,0.41677666295297405,0.19072608200879151,0.5045571327209473
|
||||
maccs,0.40606126033119555,0.1858224802938626,0.5030479431152344
|
||||
help,0.17471509961619794,0.07995343640757792,0.49689680337905884
|
||||
comp,0.16962298878652332,0.07762317554120124,0.5007365345954895
|
||||
exp,0.16035562746707094,0.07338222907722322,0.5002157688140869
|
||||
phys,0.0014117471939899774,0.0006460462761730848,0.49989768862724304
|
||||
|
BIN
reports/feature_importance/token_importance_size.png
Normal file
BIN
reports/feature_importance/token_importance_size.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 55 KiB |
9
reports/feature_importance/token_importance_toxic.csv
Normal file
9
reports/feature_importance/token_importance_toxic.csv
Normal file
@ -0,0 +1,9 @@
|
||||
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
|
||||
mpnn,0.17711989597355013,0.21006733816477163,0.5024935007095337
|
||||
desc,0.17162019114321028,0.20354459068882486,0.5030443072319031
|
||||
morgan,0.16095514462538474,0.19089565635488848,0.5045571327209473
|
||||
maccs,0.15742516053846328,0.18670903261717847,0.5030479431152344
|
||||
comp,0.05954181250470194,0.07061764571178654,0.5007365345954895
|
||||
help,0.05892528920918569,0.06988643814815398,0.49689680337905884
|
||||
exp,0.05708834082906645,0.06770778478774685,0.5002157688140869
|
||||
phys,0.0004818760368551862,0.0005715135266490605,0.49989768862724304
|
||||
|
BIN
reports/feature_importance/token_importance_toxic.png
Normal file
BIN
reports/feature_importance/token_importance_toxic.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 55 KiB |
Loading…
x
Reference in New Issue
Block a user