增加特征重要性计算

This commit is contained in:
RYDE-WORK 2026-03-03 13:45:46 +08:00
parent 4b75b7406d
commit 3b38727053
14 changed files with 123 additions and 55 deletions

View File

@ -126,10 +126,11 @@ train: requirements
# INTERPRETABILITY #
#################################################################################
# 参数:
# TASK 目标任务 (delivery, size, pdi, ee, biodist, toxic; 默认: delivery)
# METHOD 方法 (ig, ablation, attention, all; 默认: all)
# DATA 数据路径 (默认: data/processed/train.parquet)
# MODEL 模型路径 (默认: models/model.pt)
# TASK 目标任务 (delivery, size, pdi, ee, biodist, toxic, all; 默认: all)
# 如果指定 'all',将依次计算所有 6 个任务
# METHOD 方法 (ig, ablation, attention, all; 默认: ig)
# DATA 数据路径 (默认: data/interim/internal.csv,即最终模型的全量训练数据)
# MODEL 模型路径 (默认: models/final/model.pt)
TASK_FLAG = $(if $(TASK),--task $(TASK),)
METHOD_FLAG = $(if $(METHOD),--method $(METHOD),)

View File

@ -29,8 +29,8 @@ from torch.utils.data import DataLoader
from loguru import logger
from tqdm import tqdm
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR, REPORTS_DIR
from lnp_ml.dataset import LNPDataset, collate_fn
from lnp_ml.config import MODELS_DIR, INTERIM_DATA_DIR, REPORTS_DIR
from lnp_ml.dataset import LNPDataset, collate_fn, process_dataframe
from lnp_ml.modeling.predict import load_model
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
@ -332,12 +332,12 @@ def save_csv(
def main() -> None:
parser = argparse.ArgumentParser(description="Token-level Feature Importance")
parser.add_argument("--model-path", type=str, default=str(MODELS_DIR / "model.pt"),
parser.add_argument("--model-path", type=str, default=str(MODELS_DIR / "final" / "model.pt"),
help="Path to trained model checkpoint")
parser.add_argument("--data-path", type=str, default=str(PROCESSED_DATA_DIR / "train.parquet"),
help="Path to data (parquet) for computing importance")
parser.add_argument("--task", type=str, default="delivery", choices=TASKS,
help="Target task for importance computation")
parser.add_argument("--data-path", type=str, default=str(INTERIM_DATA_DIR / "internal.csv"),
help="Path to data (.csv or .parquet) for computing importance")
parser.add_argument("--task", type=str, default="all", choices=TASKS + ["all"],
help="Target task for importance computation ('all' to run on all tasks)")
parser.add_argument("--method", type=str, default="ig",
choices=["ig", "ablation", "attention", "all"],
help="Which method(s) to run")
@ -363,7 +363,12 @@ def main() -> None:
logger.info(f"Tokens ({len(token_names)}): {token_names}")
# ── Load data ──
df = pd.read_parquet(args.data_path)
data_path = Path(args.data_path)
if data_path.suffix == ".csv":
df = pd.read_csv(data_path)
df = process_dataframe(df)
else:
df = pd.read_parquet(data_path)
dataset = LNPDataset(df)
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
logger.info(f"Samples: {len(dataset)}")
@ -382,24 +387,32 @@ def main() -> None:
else [args.method]
)
# Determine tasks to process
tasks_to_run = TASKS if args.task == "all" else [args.task]
for task in tasks_to_run:
logger.info(f"\n{'#'*60}")
logger.info(f"# Processing task: {task}")
logger.info(f"{'#'*60}")
results: Dict[str, np.ndarray] = {}
for method in methods:
logger.info(f"\n{'='*60}")
logger.info(f"Computing: {method}")
logger.info(f"Computing: {method} (task={task})")
logger.info(f"{'='*60}")
if method == "ig":
imp = integrated_gradients_importance(
model, all_tokens, device,
task=args.task, batch_size=args.batch_size, n_steps=args.n_steps,
task=task, batch_size=args.batch_size, n_steps=args.n_steps,
)
results["Integrated Gradients"] = imp
elif method == "ablation":
imp = token_ablation_importance(
model, all_tokens, device,
task=args.task, batch_size=args.batch_size,
task=task, batch_size=args.batch_size,
)
results["Token Ablation"] = imp
@ -413,7 +426,7 @@ def main() -> None:
# ── Print summary ──
logger.info(f"\n{'='*60}")
logger.info(f"Token Importance Summary (task={args.task})")
logger.info(f"Token Importance Summary (task={task})")
logger.info(f"{'='*60}")
for method_name, importance in results.items():
normed = normalize(importance)
@ -428,8 +441,8 @@ def main() -> None:
# ── Save results ──
if results:
plot_token_importance(results, token_names, args.task, out_dir)
save_csv(results, token_names, args.task, out_dir, gate_vals=gv)
plot_token_importance(results, token_names, task, out_dir)
save_csv(results, token_names, task, out_dir, gate_vals=gv)
logger.info("\nDone!")

View File

@ -0,0 +1,9 @@
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
desc,3.2013970842762367e-09,0.17906809140868585,0.5030443072319031
mpnn,3.109777150387161e-09,0.17394338920379962,0.5024935007095337
maccs,3.0657202874248063e-09,0.1714790968475434,0.5030479431152344
morgan,3.020539287877718e-09,0.16895192663283606,0.5045571327209473
help,1.937320535640997e-09,0.10836278088337024,0.49689680337905884
comp,1.876732953403087e-09,0.10497385335304418,0.5007365345954895
exp,1.6503372406931126e-09,0.09231055445232395,0.5002157688140869
phys,1.6274562664095572e-11,0.0009103072183966515,0.49989768862724304
1 token Integrated Gradients_raw Integrated Gradients_normalized gate_sigmoid
2 desc 3.2013970842762367e-09 0.17906809140868585 0.5030443072319031
3 mpnn 3.109777150387161e-09 0.17394338920379962 0.5024935007095337
4 maccs 3.0657202874248063e-09 0.1714790968475434 0.5030479431152344
5 morgan 3.020539287877718e-09 0.16895192663283606 0.5045571327209473
6 help 1.937320535640997e-09 0.10836278088337024 0.49689680337905884
7 comp 1.876732953403087e-09 0.10497385335304418 0.5007365345954895
8 exp 1.6503372406931126e-09 0.09231055445232395 0.5002157688140869
9 phys 1.6274562664095572e-11 0.0009103072183966515 0.49989768862724304

Binary file not shown.

After

Width:  |  Height:  |  Size: 55 KiB

View File

@ -0,0 +1,9 @@
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
desc,0.48309860429978513,0.1830685579048154,0.5030443072319031
mpnn,0.4800125818662255,0.1818991202961257,0.5024935007095337
morgan,0.4681746999619525,0.17741319558101734,0.5045571327209473
maccs,0.4642216987644718,0.1759152193455701,0.5030479431152344
help,0.2636187040391749,0.09989740304701922,0.49689680337905884
comp,0.2503036292270154,0.0948517011498054,0.5007365345954895
exp,0.22730758172642437,0.08613742788142016,0.5002157688140869
phys,0.0021569658208921167,0.0008173747942266034,0.49989768862724304
1 token Integrated Gradients_raw Integrated Gradients_normalized gate_sigmoid
2 desc 0.48309860429978513 0.1830685579048154 0.5030443072319031
3 mpnn 0.4800125818662255 0.1818991202961257 0.5024935007095337
4 morgan 0.4681746999619525 0.17741319558101734 0.5045571327209473
5 maccs 0.4642216987644718 0.1759152193455701 0.5030479431152344
6 help 0.2636187040391749 0.09989740304701922 0.49689680337905884
7 comp 0.2503036292270154 0.0948517011498054 0.5007365345954895
8 exp 0.22730758172642437 0.08613742788142016 0.5002157688140869
9 phys 0.0021569658208921167 0.0008173747942266034 0.49989768862724304

Binary file not shown.

After

Width:  |  Height:  |  Size: 56 KiB

View File

@ -0,0 +1,9 @@
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
desc,0.28693426746338735,0.2025623449841757,0.5030443072319031
mpnn,0.2846832493908164,0.2009732301551499,0.5024935007095337
morgan,0.2646961202937205,0.18686323982460912,0.5045571327209473
maccs,0.25845640337994125,0.1824582877731647,0.5030479431152344
help,0.11619691669315178,0.082029658337337,0.49689680337905884
comp,0.10812802919624785,0.07633339630758515,0.5007365345954895
exp,0.09640860187977682,0.06806002171178546,0.5002157688140869
phys,0.001019643036021533,0.000719820906192951,0.49989768862724304
1 token Integrated Gradients_raw Integrated Gradients_normalized gate_sigmoid
2 desc 0.28693426746338735 0.2025623449841757 0.5030443072319031
3 mpnn 0.2846832493908164 0.2009732301551499 0.5024935007095337
4 morgan 0.2646961202937205 0.18686323982460912 0.5045571327209473
5 maccs 0.25845640337994125 0.1824582877731647 0.5030479431152344
6 help 0.11619691669315178 0.082029658337337 0.49689680337905884
7 comp 0.10812802919624785 0.07633339630758515 0.5007365345954895
8 exp 0.09640860187977682 0.06806002171178546 0.5002157688140869
9 phys 0.001019643036021533 0.000719820906192951 0.49989768862724304

Binary file not shown.

After

Width:  |  Height:  |  Size: 54 KiB

View File

@ -0,0 +1,9 @@
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
mpnn,0.29288250773521896,0.2140725741095052,0.5024935007095337
desc,0.2764989421166159,0.2020975603328644,0.5030443072319031
morgan,0.2578500855439716,0.18846680866532511,0.5045571327209473
maccs,0.2471777274240024,0.18066620905892686,0.5030479431152344
help,0.10363027887841153,0.07574505123823679,0.49689680337905884
comp,0.09781463069792998,0.07149429968008479,0.5007365345954895
exp,0.091429982122356,0.06682765650659303,0.5002157688140869
phys,0.0008617135523839594,0.0006298404084638948,0.49989768862724304
1 token Integrated Gradients_raw Integrated Gradients_normalized gate_sigmoid
2 mpnn 0.29288250773521896 0.2140725741095052 0.5024935007095337
3 desc 0.2764989421166159 0.2020975603328644 0.5030443072319031
4 morgan 0.2578500855439716 0.18846680866532511 0.5045571327209473
5 maccs 0.2471777274240024 0.18066620905892686 0.5030479431152344
6 help 0.10363027887841153 0.07574505123823679 0.49689680337905884
7 comp 0.09781463069792998 0.07149429968008479 0.5007365345954895
8 exp 0.091429982122356 0.06682765650659303 0.5002157688140869
9 phys 0.0008617135523839594 0.0006298404084638948 0.49989768862724304

Binary file not shown.

After

Width:  |  Height:  |  Size: 54 KiB

View File

@ -0,0 +1,9 @@
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
desc,0.43124767207505,0.1973483310882186,0.5030443072319031
mpnn,0.4250195774970647,0.1944982193069518,0.5024935007095337
morgan,0.41677666295297405,0.19072608200879151,0.5045571327209473
maccs,0.40606126033119555,0.1858224802938626,0.5030479431152344
help,0.17471509961619794,0.07995343640757792,0.49689680337905884
comp,0.16962298878652332,0.07762317554120124,0.5007365345954895
exp,0.16035562746707094,0.07338222907722322,0.5002157688140869
phys,0.0014117471939899774,0.0006460462761730848,0.49989768862724304
1 token Integrated Gradients_raw Integrated Gradients_normalized gate_sigmoid
2 desc 0.43124767207505 0.1973483310882186 0.5030443072319031
3 mpnn 0.4250195774970647 0.1944982193069518 0.5024935007095337
4 morgan 0.41677666295297405 0.19072608200879151 0.5045571327209473
5 maccs 0.40606126033119555 0.1858224802938626 0.5030479431152344
6 help 0.17471509961619794 0.07995343640757792 0.49689680337905884
7 comp 0.16962298878652332 0.07762317554120124 0.5007365345954895
8 exp 0.16035562746707094 0.07338222907722322 0.5002157688140869
9 phys 0.0014117471939899774 0.0006460462761730848 0.49989768862724304

Binary file not shown.

After

Width:  |  Height:  |  Size: 55 KiB

View File

@ -0,0 +1,9 @@
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
mpnn,0.17711989597355013,0.21006733816477163,0.5024935007095337
desc,0.17162019114321028,0.20354459068882486,0.5030443072319031
morgan,0.16095514462538474,0.19089565635488848,0.5045571327209473
maccs,0.15742516053846328,0.18670903261717847,0.5030479431152344
comp,0.05954181250470194,0.07061764571178654,0.5007365345954895
help,0.05892528920918569,0.06988643814815398,0.49689680337905884
exp,0.05708834082906645,0.06770778478774685,0.5002157688140869
phys,0.0004818760368551862,0.0005715135266490605,0.49989768862724304
1 token Integrated Gradients_raw Integrated Gradients_normalized gate_sigmoid
2 mpnn 0.17711989597355013 0.21006733816477163 0.5024935007095337
3 desc 0.17162019114321028 0.20354459068882486 0.5030443072319031
4 morgan 0.16095514462538474 0.19089565635488848 0.5045571327209473
5 maccs 0.15742516053846328 0.18670903261717847 0.5030479431152344
6 comp 0.05954181250470194 0.07061764571178654 0.5007365345954895
7 help 0.05892528920918569 0.06988643814815398 0.49689680337905884
8 exp 0.05708834082906645 0.06770778478774685 0.5002157688140869
9 phys 0.0004818760368551862 0.0005715135266490605 0.49989768862724304

Binary file not shown.

After

Width:  |  Height:  |  Size: 55 KiB