diff --git a/Makefile b/Makefile index edd9f9e..169fdb6 100644 --- a/Makefile +++ b/Makefile @@ -126,10 +126,11 @@ train: requirements # INTERPRETABILITY # ################################################################################# # 参数: -# TASK 目标任务 (delivery, size, pdi, ee, biodist, toxic; 默认: delivery) -# METHOD 方法 (ig, ablation, attention, all; 默认: all) -# DATA 数据路径 (默认: data/processed/train.parquet) -# MODEL 模型路径 (默认: models/model.pt) +# TASK 目标任务 (delivery, size, pdi, ee, biodist, toxic, all; 默认: all) +# 如果指定 'all',将依次计算所有 6 个任务 +# METHOD 方法 (ig, ablation, attention, all; 默认: ig) +# DATA 数据路径 (默认: data/interim/internal.csv,即最终模型的全量训练数据) +# MODEL 模型路径 (默认: models/final/model.pt) TASK_FLAG = $(if $(TASK),--task $(TASK),) METHOD_FLAG = $(if $(METHOD),--method $(METHOD),) diff --git a/lnp_ml/interpretability/token_importance.py b/lnp_ml/interpretability/token_importance.py index 6a71963..40e597a 100644 --- a/lnp_ml/interpretability/token_importance.py +++ b/lnp_ml/interpretability/token_importance.py @@ -29,8 +29,8 @@ from torch.utils.data import DataLoader from loguru import logger from tqdm import tqdm -from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR, REPORTS_DIR -from lnp_ml.dataset import LNPDataset, collate_fn +from lnp_ml.config import MODELS_DIR, INTERIM_DATA_DIR, REPORTS_DIR +from lnp_ml.dataset import LNPDataset, collate_fn, process_dataframe from lnp_ml.modeling.predict import load_model from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN @@ -332,12 +332,12 @@ def save_csv( def main() -> None: parser = argparse.ArgumentParser(description="Token-level Feature Importance") - parser.add_argument("--model-path", type=str, default=str(MODELS_DIR / "model.pt"), + parser.add_argument("--model-path", type=str, default=str(MODELS_DIR / "final" / "model.pt"), help="Path to trained model checkpoint") - parser.add_argument("--data-path", type=str, default=str(PROCESSED_DATA_DIR / "train.parquet"), - help="Path to data (parquet) for 
computing importance") - parser.add_argument("--task", type=str, default="delivery", choices=TASKS, - help="Target task for importance computation") + parser.add_argument("--data-path", type=str, default=str(INTERIM_DATA_DIR / "internal.csv"), + help="Path to data (.csv or .parquet) for computing importance") + parser.add_argument("--task", type=str, default="all", choices=TASKS + ["all"], + help="Target task for importance computation ('all' to run on all tasks)") parser.add_argument("--method", type=str, default="ig", choices=["ig", "ablation", "attention", "all"], help="Which method(s) to run") @@ -363,7 +363,12 @@ def main() -> None: logger.info(f"Tokens ({len(token_names)}): {token_names}") # ── Load data ── - df = pd.read_parquet(args.data_path) + data_path = Path(args.data_path) + if data_path.suffix == ".csv": + df = pd.read_csv(data_path) + df = process_dataframe(df) + else: + df = pd.read_parquet(data_path) dataset = LNPDataset(df) loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn) logger.info(f"Samples: {len(dataset)}") @@ -382,54 +387,62 @@ def main() -> None: else [args.method] ) - results: Dict[str, np.ndarray] = {} + # Determine tasks to process + tasks_to_run = TASKS if args.task == "all" else [args.task] - for method in methods: + for task in tasks_to_run: + logger.info(f"\n{'#'*60}") + logger.info(f"# Processing task: {task}") + logger.info(f"{'#'*60}") + + results: Dict[str, np.ndarray] = {} + + for method in methods: + logger.info(f"\n{'='*60}") + logger.info(f"Computing: {method} (task={task})") + logger.info(f"{'='*60}") + + if method == "ig": + imp = integrated_gradients_importance( + model, all_tokens, device, + task=task, batch_size=args.batch_size, n_steps=args.n_steps, + ) + results["Integrated Gradients"] = imp + + elif method == "ablation": + imp = token_ablation_importance( + model, all_tokens, device, + task=task, batch_size=args.batch_size, + ) + results["Token Ablation"] = imp + + elif 
method == "attention": + imp = fusion_attention_importance( + model, all_tokens, device, + batch_size=args.batch_size, + ) + if imp is not None: + results["Fusion Attention"] = imp + + # ── Print summary ── logger.info(f"\n{'='*60}") - logger.info(f"Computing: {method}") + logger.info(f"Token Importance Summary (task={task})") logger.info(f"{'='*60}") + for method_name, importance in results.items(): + normed = normalize(importance) + order = np.argsort(-normed) + logger.info(f"\n {method_name}:") + for rank, idx in enumerate(order, 1): + logger.info(f" {rank:>2d}. {token_names[idx]:<10s} {normed[idx]:.4f}") - if method == "ig": - imp = integrated_gradients_importance( - model, all_tokens, device, - task=args.task, batch_size=args.batch_size, n_steps=args.n_steps, - ) - results["Integrated Gradients"] = imp + logger.info(f"\n Gate values (sigmoid):") + for name, val in sorted(gv.items(), key=lambda x: -x[1]): + logger.info(f" {name:<10s} {val:.4f}") - elif method == "ablation": - imp = token_ablation_importance( - model, all_tokens, device, - task=args.task, batch_size=args.batch_size, - ) - results["Token Ablation"] = imp - - elif method == "attention": - imp = fusion_attention_importance( - model, all_tokens, device, - batch_size=args.batch_size, - ) - if imp is not None: - results["Fusion Attention"] = imp - - # ── Print summary ── - logger.info(f"\n{'='*60}") - logger.info(f"Token Importance Summary (task={args.task})") - logger.info(f"{'='*60}") - for method_name, importance in results.items(): - normed = normalize(importance) - order = np.argsort(-normed) - logger.info(f"\n {method_name}:") - for rank, idx in enumerate(order, 1): - logger.info(f" {rank:>2d}. 
{token_names[idx]:<10s} {normed[idx]:.4f}") - - logger.info(f"\n Gate values (sigmoid):") - for name, val in sorted(gv.items(), key=lambda x: -x[1]): - logger.info(f" {name:<10s} {val:.4f}") - - # ── Save results ── - if results: - plot_token_importance(results, token_names, args.task, out_dir) - save_csv(results, token_names, args.task, out_dir, gate_vals=gv) + # ── Save results ── + if results: + plot_token_importance(results, token_names, task, out_dir) + save_csv(results, token_names, task, out_dir, gate_vals=gv) logger.info("\nDone!") diff --git a/reports/feature_importance/token_importance_biodist.csv b/reports/feature_importance/token_importance_biodist.csv new file mode 100644 index 0000000..72d819f --- /dev/null +++ b/reports/feature_importance/token_importance_biodist.csv @@ -0,0 +1,9 @@ +token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid +desc,3.2013970842762367e-09,0.17906809140868585,0.5030443072319031 +mpnn,3.109777150387161e-09,0.17394338920379962,0.5024935007095337 +maccs,3.0657202874248063e-09,0.1714790968475434,0.5030479431152344 +morgan,3.020539287877718e-09,0.16895192663283606,0.5045571327209473 +help,1.937320535640997e-09,0.10836278088337024,0.49689680337905884 +comp,1.876732953403087e-09,0.10497385335304418,0.5007365345954895 +exp,1.6503372406931126e-09,0.09231055445232395,0.5002157688140869 +phys,1.6274562664095572e-11,0.0009103072183966515,0.49989768862724304 diff --git a/reports/feature_importance/token_importance_biodist.png b/reports/feature_importance/token_importance_biodist.png new file mode 100644 index 0000000..8e316a5 Binary files /dev/null and b/reports/feature_importance/token_importance_biodist.png differ diff --git a/reports/feature_importance/token_importance_delivery.csv b/reports/feature_importance/token_importance_delivery.csv new file mode 100644 index 0000000..c2aa6ec --- /dev/null +++ b/reports/feature_importance/token_importance_delivery.csv @@ -0,0 +1,9 @@ +token,Integrated 
Gradients_raw,Integrated Gradients_normalized,gate_sigmoid +desc,0.48309860429978513,0.1830685579048154,0.5030443072319031 +mpnn,0.4800125818662255,0.1818991202961257,0.5024935007095337 +morgan,0.4681746999619525,0.17741319558101734,0.5045571327209473 +maccs,0.4642216987644718,0.1759152193455701,0.5030479431152344 +help,0.2636187040391749,0.09989740304701922,0.49689680337905884 +comp,0.2503036292270154,0.0948517011498054,0.5007365345954895 +exp,0.22730758172642437,0.08613742788142016,0.5002157688140869 +phys,0.0021569658208921167,0.0008173747942266034,0.49989768862724304 diff --git a/reports/feature_importance/token_importance_delivery.png b/reports/feature_importance/token_importance_delivery.png new file mode 100644 index 0000000..f01e9bb Binary files /dev/null and b/reports/feature_importance/token_importance_delivery.png differ diff --git a/reports/feature_importance/token_importance_ee.csv b/reports/feature_importance/token_importance_ee.csv new file mode 100644 index 0000000..efd3a51 --- /dev/null +++ b/reports/feature_importance/token_importance_ee.csv @@ -0,0 +1,9 @@ +token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid +desc,0.28693426746338735,0.2025623449841757,0.5030443072319031 +mpnn,0.2846832493908164,0.2009732301551499,0.5024935007095337 +morgan,0.2646961202937205,0.18686323982460912,0.5045571327209473 +maccs,0.25845640337994125,0.1824582877731647,0.5030479431152344 +help,0.11619691669315178,0.082029658337337,0.49689680337905884 +comp,0.10812802919624785,0.07633339630758515,0.5007365345954895 +exp,0.09640860187977682,0.06806002171178546,0.5002157688140869 +phys,0.001019643036021533,0.000719820906192951,0.49989768862724304 diff --git a/reports/feature_importance/token_importance_ee.png b/reports/feature_importance/token_importance_ee.png new file mode 100644 index 0000000..43b9096 Binary files /dev/null and b/reports/feature_importance/token_importance_ee.png differ diff --git 
a/reports/feature_importance/token_importance_pdi.csv b/reports/feature_importance/token_importance_pdi.csv new file mode 100644 index 0000000..30d6c03 --- /dev/null +++ b/reports/feature_importance/token_importance_pdi.csv @@ -0,0 +1,9 @@ +token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid +mpnn,0.29288250773521896,0.2140725741095052,0.5024935007095337 +desc,0.2764989421166159,0.2020975603328644,0.5030443072319031 +morgan,0.2578500855439716,0.18846680866532511,0.5045571327209473 +maccs,0.2471777274240024,0.18066620905892686,0.5030479431152344 +help,0.10363027887841153,0.07574505123823679,0.49689680337905884 +comp,0.09781463069792998,0.07149429968008479,0.5007365345954895 +exp,0.091429982122356,0.06682765650659303,0.5002157688140869 +phys,0.0008617135523839594,0.0006298404084638948,0.49989768862724304 diff --git a/reports/feature_importance/token_importance_pdi.png b/reports/feature_importance/token_importance_pdi.png new file mode 100644 index 0000000..d683bf3 Binary files /dev/null and b/reports/feature_importance/token_importance_pdi.png differ diff --git a/reports/feature_importance/token_importance_size.csv b/reports/feature_importance/token_importance_size.csv new file mode 100644 index 0000000..786b3cf --- /dev/null +++ b/reports/feature_importance/token_importance_size.csv @@ -0,0 +1,9 @@ +token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid +desc,0.43124767207505,0.1973483310882186,0.5030443072319031 +mpnn,0.4250195774970647,0.1944982193069518,0.5024935007095337 +morgan,0.41677666295297405,0.19072608200879151,0.5045571327209473 +maccs,0.40606126033119555,0.1858224802938626,0.5030479431152344 +help,0.17471509961619794,0.07995343640757792,0.49689680337905884 +comp,0.16962298878652332,0.07762317554120124,0.5007365345954895 +exp,0.16035562746707094,0.07338222907722322,0.5002157688140869 +phys,0.0014117471939899774,0.0006460462761730848,0.49989768862724304 diff --git 
a/reports/feature_importance/token_importance_size.png b/reports/feature_importance/token_importance_size.png new file mode 100644 index 0000000..bf96cab Binary files /dev/null and b/reports/feature_importance/token_importance_size.png differ diff --git a/reports/feature_importance/token_importance_toxic.csv b/reports/feature_importance/token_importance_toxic.csv new file mode 100644 index 0000000..85912fb --- /dev/null +++ b/reports/feature_importance/token_importance_toxic.csv @@ -0,0 +1,9 @@ +token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid +mpnn,0.17711989597355013,0.21006733816477163,0.5024935007095337 +desc,0.17162019114321028,0.20354459068882486,0.5030443072319031 +morgan,0.16095514462538474,0.19089565635488848,0.5045571327209473 +maccs,0.15742516053846328,0.18670903261717847,0.5030479431152344 +comp,0.05954181250470194,0.07061764571178654,0.5007365345954895 +help,0.05892528920918569,0.06988643814815398,0.49689680337905884 +exp,0.05708834082906645,0.06770778478774685,0.5002157688140869 +phys,0.0004818760368551862,0.0005715135266490605,0.49989768862724304 diff --git a/reports/feature_importance/token_importance_toxic.png b/reports/feature_importance/token_importance_toxic.png new file mode 100644 index 0000000..78835a2 Binary files /dev/null and b/reports/feature_importance/token_importance_toxic.png differ