diff --git a/Makefile b/Makefile index 169fdb6..0664e7a 100644 --- a/Makefile +++ b/Makefile @@ -123,25 +123,36 @@ train: requirements $(N_TRIALS_FLAG) $(EPOCHS_PER_TRIAL_FLAG) $(MIN_STRATUM_FLAG) $(OUTPUT_DIR_FLAG) $(USE_SWA_FLAG) ################################################################################# -# INTERPRETABILITY # +# INTERPRETABILITY (biodistribution feature importance) # ################################################################################# # 参数: -# TASK 目标任务 (delivery, size, pdi, ee, biodist, toxic, all; 默认: delivery) -# 如果指定 'all',将依次计算所有 6 个任务 -# METHOD 方法 (ig, ablation, attention, all; 默认: ig) -# DATA 数据路径 (默认: data/interim/internal.csv,即最终模型的全量训练数据) +# ORGAN 器官 (lymph_nodes, heart, liver, spleen, lung, kidney, muscle, all; 默认: all) +# METHOD token 级方法 (ig, ablation, attention, all; 默认: ig) +# DESC_IG 同时计算 desc 内部特征 IG (1 启用; 默认不启用) +# DESC_TOP_K 可视化展示的 top-K 特征数 (默认: 30) +# DATA 数据路径 (默认: data/interim/internal.csv) # MODEL 模型路径 (默认: models/final/model.pt) -TASK_FLAG = $(if $(TASK),--task $(TASK),) METHOD_FLAG = $(if $(METHOD),--method $(METHOD),) DATA_FLAG = $(if $(DATA),--data-path $(DATA),) MODEL_FLAG = $(if $(MODEL),--model-path $(MODEL),) +ORGAN_FLAG = $(if $(ORGAN),--organ $(ORGAN),) +DESC_IG_FLAG = $(if $(DESC_IG),--desc-ig,) +DESC_TOP_K_FLAG = $(if $(DESC_TOP_K),--desc-top-k $(DESC_TOP_K),) -## Compute token-level feature importance +## Compute biodistribution feature importance (token-level) .PHONY: feature_importance feature_importance: requirements $(PYTHON_INTERPRETER) -m lnp_ml.interpretability.token_importance \ - $(TASK_FLAG) $(METHOD_FLAG) $(DATA_FLAG) $(MODEL_FLAG) $(DEVICE_FLAG) + $(ORGAN_FLAG) $(METHOD_FLAG) $(DATA_FLAG) $(MODEL_FLAG) $(DEVICE_FLAG) \ + $(DESC_IG_FLAG) $(DESC_TOP_K_FLAG) + +## Compute biodistribution feature importance (token + descriptor-level) +.PHONY: desc_importance +desc_importance: requirements + $(PYTHON_INTERPRETER) -m lnp_ml.interpretability.token_importance \ + --desc-ig $(ORGAN_FLAG) $(METHOD_FLAG) $(DATA_FLAG) $(MODEL_FLAG) \ + $(DEVICE_FLAG) $(DESC_TOP_K_FLAG) ################################################################################# # SERVING & DEPLOYMENT # diff --git a/lnp_ml/interpretability/token_importance.py b/lnp_ml/interpretability/token_importance.py index 40e597a..aa49ebd 100644 --- a/lnp_ml/interpretability/token_importance.py +++ b/lnp_ml/interpretability/token_importance.py @@ -1,25 +1,26 @@ """ -Token-level feature importance for the LNP multi-task model. +Biodistribution feature importance for the LNP multi-task model. -Three complementary methods (inspired by AGILE / Captum): - 1. Integrated Gradients — gradient-based attribution via Captum - 2. Token Ablation — zero-out each token and measure prediction change - 3. Fusion Attention Weights — extract learned attention pooling weights +Two levels of analysis, both targeting the biodist head (per organ): + 1. Token-level — which of the 8 input tokens matters most + 2. Descriptor-level — which RDKit descriptors inside the "desc" token matter most + +Token-level methods: + - Integrated Gradients (Captum) + - Token Ablation (zero-out) + - Fusion Attention Weights Usage: - python -m lnp_ml.interpretability.token_importance \ - --model-path models/model.pt \ - --data-path data/processed/train.parquet \ - --task delivery \ - --method all + python -m lnp_ml.interpretability.token_importance # all organs, token-level IG + python -m lnp_ml.interpretability.token_importance --organ liver # liver only + python -m lnp_ml.interpretability.token_importance --desc-ig # + desc-level IG """ from __future__ import annotations import argparse -import os from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union import numpy as np import pandas as pd @@ -35,7 +36,7 @@ from lnp_ml.modeling.predict import load_model from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN -TASKS = ["delivery", "size", "pdi", "ee", "biodist", "toxic"] +BIODIST_ORGANS = ["lymph_nodes", "heart", "liver", "spleen", "lung", "kidney", "muscle"] def get_token_names(model: Union[LNPModel, LNPModelWithoutMPNN]) -> List[str]: @@ -45,96 +46,97 @@ def get_token_names(model: Union[LNPModel, LNPModelWithoutMPNN]) -> List[str]: # ────────────────────────────────────────────────────────────────────── -# Helper: pre-compute projected tokens for the whole dataset +# Pre-compute # ────────────────────────────────────────────────────────────────────── @torch.no_grad() def precompute_projected_tokens( - model: nn.Module, - loader: DataLoader, - device: torch.device, + model: nn.Module, loader: DataLoader, device: torch.device, ) -> torch.Tensor: - """ - Run encoder + TokenProjector on all samples and stack results. - - Returns: - all_tokens: [N, n_tokens, d_model] - """ + """Returns [N, n_tokens, d_model].""" model.eval() chunks = [] for batch in tqdm(loader, desc="Encoding tokens"): smiles = batch["smiles"] tabular = {k: v.to(device) for k, v in batch["tabular"].items()} - stacked = model._encode_and_project(smiles, tabular) - chunks.append(stacked.cpu()) + chunks.append(model._encode_and_project(smiles, tabular).cpu()) return torch.cat(chunks, dim=0) +@torch.no_grad() +def precompute_raw_desc( + model: nn.Module, loader: DataLoader, device: torch.device, +) -> torch.Tensor: + """Returns [N, desc_dim] raw RDKit descriptor features.""" + model.eval() + chunks = [] + for batch in tqdm(loader, desc="Encoding raw desc"): + smiles = batch["smiles"] + rdkit_features = model.rdkit_encoder(smiles) + target_device = next(model.parameters()).device + chunks.append(rdkit_features["desc"].to(target_device).cpu()) + return torch.cat(chunks, dim=0) + + +def get_desc_names_and_dim(model: nn.Module) -> tuple[List[str], int]: + """Get descriptor names from RDKit and verify against model dim.""" + from rdkit.Chem import Descriptors as RDKitDesc + names = [name for name, _ in RDKitDesc.descList] + model_dim = model.token_projector.projectors["desc"][1].in_features + if len(names) != model_dim: + logger.warning( + f"RDKit descList has {len(names)} names but model expects " + f"{model_dim} features; using generic names." + ) + names = [f"desc_{i}" for i in range(model_dim)] + return names, model_dim + + # ────────────────────────────────────────────────────────────────────── -# Method 1: Integrated Gradients (Captum) +# Token-level: Integrated Gradients # ────────────────────────────────────────────────────────────────────── class _ProjectedWrapper(nn.Module): - """Wraps model.forward_from_projected so Captum can attribute w.r.t. stacked tokens.""" - - def __init__(self, model: nn.Module, task: str) -> None: + def __init__(self, model: nn.Module, organ_index: int) -> None: super().__init__() self.model = model - self.task = task + self.organ_index = organ_index def forward(self, stacked: torch.Tensor) -> torch.Tensor: - """ - Args: - stacked: [B, n_tokens, d_model] - Returns: - [B] scalar predictions (sum over output dims for non-scalar heads) - """ - out = self.model.forward_from_projected(stacked, task=self.task) - if out.dim() > 1 and out.size(-1) > 1: - return out.sum(dim=-1) - return out.squeeze(-1) + out = self.model.forward_from_projected(stacked, task="biodist") + return out[:, self.organ_index] def integrated_gradients_importance( model: nn.Module, all_tokens: torch.Tensor, device: torch.device, - task: str = "delivery", + organ_index: int, batch_size: int = 64, n_steps: int = 50, ) -> np.ndarray: - """ - Compute token-level Integrated Gradients. - - Returns: - importance: [n_tokens] averaged absolute attribution per token position - """ + """Returns [n_tokens] averaged absolute attribution per token.""" from captum.attr import IntegratedGradients - wrapper = _ProjectedWrapper(model, task).to(device) + wrapper = _ProjectedWrapper(model, organ_index).to(device) wrapper.eval() ig = IntegratedGradients(wrapper) N = all_tokens.size(0) all_attrs: List[torch.Tensor] = [] - for start in tqdm(range(0, N, batch_size), desc=f"IG ({task})"): + for start in tqdm(range(0, N, batch_size), desc=f"IG (organ={organ_index})"): end = min(start + batch_size, N) inp = all_tokens[start:end].to(device).requires_grad_(True) baseline = torch.zeros_like(inp) - attr = ig.attribute(inp, baselines=baseline, n_steps=n_steps) - # attr: [B, n_tokens, d_model] → per-token L2 norm → [B, n_tokens] - token_attr = attr.detach().cpu().norm(dim=-1) - all_attrs.append(token_attr) + all_attrs.append(attr.detach().cpu().norm(dim=-1)) # [B, n_tokens] - attrs = torch.cat(all_attrs, dim=0) # [N, n_tokens] - importance = attrs.mean(dim=0).numpy() # [n_tokens] - return importance + return torch.cat(all_attrs, dim=0).mean(dim=0).numpy() # ────────────────────────────────────────────────────────────────────── -# Method 2: Token Ablation (zero-out) +# Token-level: Ablation # ────────────────────────────────────────────────────────────────────── @torch.no_grad() @@ -142,98 +144,113 @@ def token_ablation_importance( model: nn.Module, all_tokens: torch.Tensor, device: torch.device, - task: str = "delivery", + organ_index: int, batch_size: int = 64, ) -> np.ndarray: - """ - For each token position, replace it with zeros and measure - the average absolute prediction change. - - Returns: - importance: [n_tokens] - """ + """Returns [n_tokens] importance via zero-out ablation.""" model.eval() n_tokens = all_tokens.size(1) - N = all_tokens.size(0) - - # original predictions - orig_preds = _batch_predict(model, all_tokens, device, task, batch_size) + orig_preds = _batch_predict(model, all_tokens, device, organ_index, batch_size) importance = np.zeros(n_tokens) for t in range(n_tokens): ablated = all_tokens.clone() ablated[:, t, :] = 0.0 - abl_preds = _batch_predict(model, ablated, device, task, batch_size) + abl_preds = _batch_predict(model, ablated, device, organ_index, batch_size) importance[t] = np.abs(orig_preds - abl_preds).mean() return importance def _batch_predict( - model: nn.Module, - all_tokens: torch.Tensor, - device: torch.device, - task: str, - batch_size: int, + model: nn.Module, all_tokens: torch.Tensor, + device: torch.device, organ_index: int, batch_size: int, ) -> np.ndarray: - """Run forward_from_projected in batches, return [N] predictions.""" model.eval() preds = [] - N = all_tokens.size(0) - for start in range(0, N, batch_size): - end = min(start + batch_size, N) - inp = all_tokens[start:end].to(device) - out = model.forward_from_projected(inp, task=task) - if out.dim() > 1 and out.size(-1) > 1: - out = out.sum(dim=-1) - else: - out = out.squeeze(-1) - preds.append(out.cpu().numpy()) - return np.concatenate(preds, axis=0) + for start in range(0, all_tokens.size(0), batch_size): + end = min(start + batch_size, all_tokens.size(0)) + out = model.forward_from_projected(all_tokens[start:end].to(device), task="biodist") + preds.append(out[:, organ_index].cpu().numpy()) + return np.concatenate(preds) # ────────────────────────────────────────────────────────────────────── -# Method 3: Fusion Attention Weights +# Token-level: Fusion Attention Weights (task-agnostic) # ────────────────────────────────────────────────────────────────────── @torch.no_grad() def fusion_attention_importance( - model: nn.Module, - all_tokens: torch.Tensor, - device: torch.device, - batch_size: int = 64, + model: nn.Module, all_tokens: torch.Tensor, + device: torch.device, batch_size: int = 64, ) -> Optional[np.ndarray]: - """ - Extract FusionLayer attention weights (only works if strategy="attention"). - - Returns: - importance: [n_tokens] averaged attention weights, or None if not applicable. - """ model.eval() if model.fusion.strategy != "attention": - logger.warning("Fusion strategy is not 'attention'; skipping attention weight extraction.") + logger.warning("Fusion strategy is not 'attention'; skipping.") return None - N = all_tokens.size(0) all_weights: List[torch.Tensor] = [] - - for start in range(0, N, batch_size): - end = min(start + batch_size, N) - inp = all_tokens[start:end].to(device) - attended = model.cross_attention(inp) + for start in range(0, all_tokens.size(0), batch_size): + end = min(start + batch_size, all_tokens.size(0)) + attended = model.cross_attention(all_tokens[start:end].to(device)) _, weights = model.fusion(attended, return_attn_weights=True) all_weights.append(weights.cpu()) - weights = torch.cat(all_weights, dim=0) # [N, n_tokens] - return weights.mean(dim=0).numpy() + return torch.cat(all_weights, dim=0).mean(dim=0).numpy() # ────────────────────────────────────────────────────────────────────── -# Bonus: TokenProjector gate values (static) +# Descriptor-level: Integrated Gradients +# ────────────────────────────────────────────────────────────────────── + +class _ReplacingDescWrapper(nn.Module): + def __init__(self, model: nn.Module, base_projected: torch.Tensor, organ_index: int) -> None: + super().__init__() + self.model = model + self.base_projected = base_projected + self.organ_index = organ_index + + def forward(self, raw_desc: torch.Tensor) -> torch.Tensor: + base = self.base_projected[:raw_desc.size(0)].to(raw_desc.device) + out = self.model.forward_replacing_token(raw_desc, "desc", base, task="biodist") + return out[:, self.organ_index] + + +def descriptor_ig_importance( + model: nn.Module, + all_tokens: torch.Tensor, + raw_desc: torch.Tensor, + device: torch.device, + organ_index: int, + batch_size: int = 64, + n_steps: int = 50, +) -> np.ndarray: + """Returns [desc_dim] averaged absolute attribution per descriptor.""" + from captum.attr import IntegratedGradients + + wrapper = _ReplacingDescWrapper(model, all_tokens, organ_index).to(device) + wrapper.eval() + ig = IntegratedGradients(wrapper) + + N = raw_desc.size(0) + all_attrs: List[torch.Tensor] = [] + + for start in tqdm(range(0, N, batch_size), desc=f"Desc-IG (organ={organ_index})"): + end = min(start + batch_size, N) + inp = raw_desc[start:end].to(device).requires_grad_(True) + baseline = torch.zeros_like(inp) + wrapper.base_projected = all_tokens[start:end] + attr = ig.attribute(inp, baselines=baseline, n_steps=n_steps) + all_attrs.append(attr.detach().cpu().abs()) + + return torch.cat(all_attrs, dim=0).mean(dim=0).numpy() + + +# ────────────────────────────────────────────────────────────────────── +# Gate values # ────────────────────────────────────────────────────────────────────── def gate_values(model: nn.Module) -> Dict[str, float]: - """Read sigmoid(weight) from TokenProjector for each token.""" gates = {} for key in model.token_projector.keys: w = model.token_projector.weights[key].detach().cpu() @@ -242,7 +259,7 @@ def gate_values(model: nn.Module) -> Dict[str, float]: # ────────────────────────────────────────────────────────────────────── -# Visualization +# Visualization & IO # ────────────────────────────────────────────────────────────────────── def normalize(arr: np.ndarray) -> np.ndarray: @@ -251,10 +268,8 @@ def normalize(arr: np.ndarray) -> np.ndarray: def plot_token_importance( - results: Dict[str, np.ndarray], - token_names: List[str], - task: str, - out_dir: Path, + results: Dict[str, np.ndarray], token_names: List[str], + organ: str, out_dir: Path, ) -> Path: import matplotlib matplotlib.use("Agg") @@ -265,20 +280,18 @@ def plot_token_importance( if n_methods == 1: axes = [axes] - channel_a_color = "#5a448e" - channel_b_color = "#e07b39" + color_a, color_b = "#5a448e", "#e07b39" for ax, (method_name, importance) in zip(axes, results.items()): normed = normalize(importance) order = np.argsort(normed) - names_sorted = [token_names[i] for i in order] vals_sorted = normed[order] n_tokens = len(token_names) split_idx = 4 if n_tokens == 8 else 3 channel_a_set = set(token_names[:split_idx]) - colors = [channel_a_color if n in channel_a_set else channel_b_color for n in names_sorted] + colors = [color_a if n in channel_a_set else color_b for n in names_sorted] ax.barh(range(len(names_sorted)), vals_sorted, color=colors) ax.set_yticks(range(len(names_sorted))) @@ -288,22 +301,20 @@ def plot_token_importance( ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) - fig.suptitle(f"Token Importance — task: {task}", fontsize=14, y=1.02) + fig.suptitle(f"Token Importance — biodist: {organ}", fontsize=14, y=1.02) fig.tight_layout() out_dir.mkdir(parents=True, exist_ok=True) - fig_path = out_dir / f"token_importance_{task}.png" + fig_path = out_dir / f"token_importance_biodist_{organ}.png" fig.savefig(fig_path, dpi=200, bbox_inches="tight") plt.close(fig) logger.info(f"Figure saved to {fig_path}") return fig_path -def save_csv( - results: Dict[str, np.ndarray], - token_names: List[str], - task: str, - out_dir: Path, +def save_token_csv( + results: Dict[str, np.ndarray], token_names: List[str], + organ: str, out_dir: Path, gate_vals: Optional[Dict[str, float]] = None, ) -> Path: out_dir.mkdir(parents=True, exist_ok=True) @@ -312,38 +323,98 @@ def save_csv( normed = normalize(importance) df[f"{method_name}_raw"] = importance df[f"{method_name}_normalized"] = normed - if gate_vals is not None: df["gate_sigmoid"] = [gate_vals.get(t, float("nan")) for t in token_names] - df = df.sort_values( by=[c for c in df.columns if c.endswith("_normalized")][-1], ascending=False, ) - csv_path = out_dir / f"token_importance_{task}.csv" + csv_path = out_dir / f"token_importance_biodist_{organ}.csv" df.to_csv(csv_path, index=False) logger.info(f"CSV saved to {csv_path}") return csv_path +def plot_descriptor_importance( + importance: np.ndarray, feature_names: List[str], + organ: str, out_dir: Path, top_k: int = 30, +) -> Path: + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + normed = normalize(importance) + order = np.argsort(-normed) + top_indices = order[:top_k] + + names = [feature_names[i] for i in reversed(top_indices)] + vals = [normed[i] for i in reversed(top_indices)] + + fig, ax = plt.subplots(figsize=(8, max(5, top_k * 0.35))) + ax.barh(range(len(names)), vals, color="#5a448e") + ax.set_yticks(range(len(names))) + ax.set_yticklabels(names, fontsize=9) + ax.set_xlabel("Normalized Importance", fontsize=11) + ax.set_title( + f"Top-{top_k} desc Feature Importance — biodist: {organ}\n" + f"(total features: {len(feature_names)})", + fontsize=13, + ) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + fig.tight_layout() + + out_dir.mkdir(parents=True, exist_ok=True) + fig_path = out_dir / f"desc_importance_biodist_{organ}.png" + fig.savefig(fig_path, dpi=200, bbox_inches="tight") + plt.close(fig) + logger.info(f"Descriptor figure saved to {fig_path}") + return fig_path + + +def save_descriptor_csv( + importance: np.ndarray, feature_names: List[str], + organ: str, out_dir: Path, +) -> Path: + out_dir.mkdir(parents=True, exist_ok=True) + normed = normalize(importance) + df = pd.DataFrame({ + "feature": feature_names, + "ig_raw": importance, + "ig_normalized": normed, + }) + df = df.sort_values("ig_normalized", ascending=False).reset_index(drop=True) + df.index.name = "rank" + csv_path = out_dir / f"desc_importance_biodist_{organ}.csv" + df.to_csv(csv_path) + logger.info(f"Descriptor CSV saved to {csv_path}") + return csv_path + + # ────────────────────────────────────────────────────────────────────── # Main # ────────────────────────────────────────────────────────────────────── def main() -> None: - parser = argparse.ArgumentParser(description="Token-level Feature Importance") - parser.add_argument("--model-path", type=str, default=str(MODELS_DIR / "final" / "model.pt"), - help="Path to trained model checkpoint") - parser.add_argument("--data-path", type=str, default=str(INTERIM_DATA_DIR / "internal.csv"), - help="Path to data (.csv or .parquet) for computing importance") - parser.add_argument("--task", type=str, default="all", choices=TASKS + ["all"], - help="Target task for importance computation ('all' to run on all tasks)") + parser = argparse.ArgumentParser( + description="Biodistribution feature importance (token-level & descriptor-level)") + parser.add_argument("--model-path", type=str, + default=str(MODELS_DIR / "final" / "model.pt")) + parser.add_argument("--data-path", type=str, + default=str(INTERIM_DATA_DIR / "internal.csv")) + parser.add_argument("--organ", type=str, default="all", + choices=BIODIST_ORGANS + ["all"], + help="Organ to analyze (default: all 7 organs)") parser.add_argument("--method", type=str, default="ig", choices=["ig", "ablation", "attention", "all"], - help="Which method(s) to run") + help="Token-level method(s)") + parser.add_argument("--desc-ig", action="store_true", + help="Also compute descriptor-level IG within the desc token") + parser.add_argument("--desc-top-k", type=int, default=30, + help="Top-K descriptors to show in plot") parser.add_argument("--batch-size", type=int, default=64) parser.add_argument("--n-steps", type=int, default=50, - help="Number of interpolation steps for Integrated Gradients") + help="Interpolation steps for Integrated Gradients") parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") parser.add_argument("--output-dir", type=str, @@ -352,17 +423,21 @@ def main() -> None: device = torch.device(args.device) out_dir = Path(args.output_dir) + + # ── Resolve organs ── + organs = BIODIST_ORGANS if args.organ == "all" else [args.organ] + organ_indices = {organ: BIODIST_ORGANS.index(organ) for organ in organs} + logger.info(f"Device: {device}") logger.info(f"Model: {args.model_path}") logger.info(f"Data: {args.data_path}") - logger.info(f"Task: {args.task}") + logger.info(f"Organs: {organs}") - # ── Load model ── + # ── Load model & data ── model = load_model(Path(args.model_path), device) token_names = get_token_names(model) logger.info(f"Tokens ({len(token_names)}): {token_names}") - # ── Load data ── data_path = Path(args.data_path) if data_path.suffix == ".csv": df = pd.read_csv(data_path) @@ -373,76 +448,78 @@ def main() -> None: loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn) logger.info(f"Samples: {len(dataset)}") - # ── Pre-compute projected tokens ── + # ── Pre-compute ── all_tokens = precompute_projected_tokens(model, loader, device) logger.info(f"Projected tokens shape: {all_tokens.shape}") - # ── Gate values (always) ── gv = gate_values(model) - logger.info(f"TokenProjector gate values: {gv}") + logger.info(f"Gate values: {gv}") - # ── Run selected methods ── - methods = ( - ["ig", "ablation", "attention"] if args.method == "all" - else [args.method] - ) + methods = ["ig", "ablation", "attention"] if args.method == "all" else [args.method] - # Determine tasks to process - tasks_to_run = TASKS if args.task == "all" else [args.task] + # ── Pre-compute desc features if needed ── + raw_desc = None + desc_names = None + if args.desc_ig: + desc_names, desc_dim = get_desc_names_and_dim(model) + raw_desc = precompute_raw_desc(model, loader, device) + logger.info(f"Raw desc shape: {raw_desc.shape} (dim={desc_dim})") + assert raw_desc.size(1) == desc_dim - for task in tasks_to_run: + # ── Per-organ loop ── + for organ, organ_idx in organ_indices.items(): logger.info(f"\n{'#'*60}") - logger.info(f"# Processing task: {task}") + logger.info(f"# Organ: {organ} (index={organ_idx})") logger.info(f"{'#'*60}") + # Token-level results: Dict[str, np.ndarray] = {} - for method in methods: - logger.info(f"\n{'='*60}") - logger.info(f"Computing: {method} (task={task})") - logger.info(f"{'='*60}") - + logger.info(f" Computing: {method}") if method == "ig": - imp = integrated_gradients_importance( - model, all_tokens, device, - task=task, batch_size=args.batch_size, n_steps=args.n_steps, + results["Integrated Gradients"] = integrated_gradients_importance( + model, all_tokens, device, organ_idx, + batch_size=args.batch_size, n_steps=args.n_steps, ) - results["Integrated Gradients"] = imp - elif method == "ablation": - imp = token_ablation_importance( - model, all_tokens, device, - task=task, batch_size=args.batch_size, - ) - results["Token Ablation"] = imp - - elif method == "attention": - imp = fusion_attention_importance( - model, all_tokens, device, + results["Token Ablation"] = token_ablation_importance( + model, all_tokens, device, organ_idx, batch_size=args.batch_size, ) + elif method == "attention": + imp = fusion_attention_importance(model, all_tokens, device, args.batch_size) if imp is not None: results["Fusion Attention"] = imp - # ── Print summary ── - logger.info(f"\n{'='*60}") - logger.info(f"Token Importance Summary (task={task})") - logger.info(f"{'='*60}") + # Print token summary for method_name, importance in results.items(): normed = normalize(importance) order = np.argsort(-normed) - logger.info(f"\n {method_name}:") + logger.info(f"\n {method_name} (biodist_{organ}):") for rank, idx in enumerate(order, 1): logger.info(f" {rank:>2d}. {token_names[idx]:<10s} {normed[idx]:.4f}") - logger.info(f"\n Gate values (sigmoid):") - for name, val in sorted(gv.items(), key=lambda x: -x[1]): - logger.info(f" {name:<10s} {val:.4f}") - - # ── Save results ── if results: - plot_token_importance(results, token_names, task, out_dir) - save_csv(results, token_names, task, out_dir, gate_vals=gv) + plot_token_importance(results, token_names, organ, out_dir) + save_token_csv(results, token_names, organ, out_dir, gate_vals=gv) + + # Descriptor-level + if args.desc_ig and raw_desc is not None and desc_names is not None: + logger.info(f"\n Descriptor IG for {organ}...") + desc_imp = descriptor_ig_importance( + model, all_tokens, raw_desc, device, organ_idx, + batch_size=args.batch_size, n_steps=args.n_steps, + ) + normed = normalize(desc_imp) + top_k = min(args.desc_top_k, len(desc_names)) + order = np.argsort(-normed) + logger.info(f"\n Top-{top_k} desc features (biodist_{organ}):") + for rank, idx in enumerate(order[:top_k], 1): + logger.info(f" {rank:>3d}. {desc_names[idx]:<30s} {normed[idx]:.6f}") + + desc_out_dir = out_dir / "desc" + plot_descriptor_importance(desc_imp, desc_names, organ, desc_out_dir, top_k) + save_descriptor_csv(desc_imp, desc_names, organ, desc_out_dir) logger.info("\nDone!") diff --git a/lnp_ml/modeling/models.py b/lnp_ml/modeling/models.py index 893bd62..c85a609 100644 --- a/lnp_ml/modeling/models.py +++ b/lnp_ml/modeling/models.py @@ -201,6 +201,39 @@ class LNPModel(nn.Module): } return task_heads[task](fused) + def forward_replacing_token( + self, + raw_feature: torch.Tensor, + feature_key: str, + base_projected: torch.Tensor, + task: Optional[str] = None, + ) -> torch.Tensor: + """ + 用原始特征替换 base_projected 中指定 token 的投影,然后 forward。 + + 用于对单个 token 内部特征做 Captum 归因(如 desc 的 210 维)。 + + Args: + raw_feature: [B, input_dim] 某个 token 的原始特征 + feature_key: token 名称,如 "desc" + base_projected: [B, n_tokens, d_model] 其他 token 已投影好的张量 + task: 任务名 + + Returns: + 对应任务的预测输出 + """ + projected = self.token_projector.projectors[feature_key](raw_feature) + gate = torch.sigmoid(self.token_projector.weights[feature_key]) + projected = projected * gate # [B, d_model] + + token_order = list(self.token_projector.keys) + token_idx = token_order.index(feature_key) + + stacked = base_projected.clone() + stacked[:, token_idx, :] = projected + + return self.forward_from_projected(stacked, task=task) + def forward_backbone( self, smiles: List[str], diff --git a/reports/feature_importance/token_importance_biodist.csv b/reports/feature_importance/token_importance_biodist.csv deleted file mode 100644 index 72d819f..0000000 --- a/reports/feature_importance/token_importance_biodist.csv +++ /dev/null @@ -1,9 +0,0 @@ -token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid -desc,3.2013970842762367e-09,0.17906809140868585,0.5030443072319031 -mpnn,3.109777150387161e-09,0.17394338920379962,0.5024935007095337 -maccs,3.0657202874248063e-09,0.1714790968475434,0.5030479431152344 -morgan,3.020539287877718e-09,0.16895192663283606,0.5045571327209473 -help,1.937320535640997e-09,0.10836278088337024,0.49689680337905884 -comp,1.876732953403087e-09,0.10497385335304418,0.5007365345954895 -exp,1.6503372406931126e-09,0.09231055445232395,0.5002157688140869 -phys,1.6274562664095572e-11,0.0009103072183966515,0.49989768862724304 diff --git a/reports/feature_importance/token_importance_biodist.png b/reports/feature_importance/token_importance_biodist.png deleted file mode 100644 index 8e316a5..0000000 Binary files a/reports/feature_importance/token_importance_biodist.png and /dev/null differ diff --git a/reports/feature_importance/token_importance_delivery.csv b/reports/feature_importance/token_importance_delivery.csv deleted file mode 100644 index c2aa6ec..0000000 --- a/reports/feature_importance/token_importance_delivery.csv +++ /dev/null @@ -1,9 +0,0 @@ -token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid -desc,0.48309860429978513,0.1830685579048154,0.5030443072319031 -mpnn,0.4800125818662255,0.1818991202961257,0.5024935007095337 -morgan,0.4681746999619525,0.17741319558101734,0.5045571327209473 -maccs,0.4642216987644718,0.1759152193455701,0.5030479431152344 -help,0.2636187040391749,0.09989740304701922,0.49689680337905884 -comp,0.2503036292270154,0.0948517011498054,0.5007365345954895 -exp,0.22730758172642437,0.08613742788142016,0.5002157688140869 -phys,0.0021569658208921167,0.0008173747942266034,0.49989768862724304 diff --git a/reports/feature_importance/token_importance_delivery.png b/reports/feature_importance/token_importance_delivery.png deleted file mode 100644 index f01e9bb..0000000 Binary files a/reports/feature_importance/token_importance_delivery.png and /dev/null differ diff --git a/reports/feature_importance/token_importance_ee.csv b/reports/feature_importance/token_importance_ee.csv deleted file mode 100644 index efd3a51..0000000 --- a/reports/feature_importance/token_importance_ee.csv +++ /dev/null @@ -1,9 +0,0 @@ -token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid -desc,0.28693426746338735,0.2025623449841757,0.5030443072319031 -mpnn,0.2846832493908164,0.2009732301551499,0.5024935007095337 -morgan,0.2646961202937205,0.18686323982460912,0.5045571327209473 -maccs,0.25845640337994125,0.1824582877731647,0.5030479431152344 -help,0.11619691669315178,0.082029658337337,0.49689680337905884 -comp,0.10812802919624785,0.07633339630758515,0.5007365345954895 -exp,0.09640860187977682,0.06806002171178546,0.5002157688140869 -phys,0.001019643036021533,0.000719820906192951,0.49989768862724304 diff --git a/reports/feature_importance/token_importance_ee.png b/reports/feature_importance/token_importance_ee.png deleted file mode 100644 index 43b9096..0000000 Binary files a/reports/feature_importance/token_importance_ee.png and /dev/null differ diff --git a/reports/feature_importance/token_importance_pdi.csv b/reports/feature_importance/token_importance_pdi.csv deleted file mode 100644 index 30d6c03..0000000 --- a/reports/feature_importance/token_importance_pdi.csv +++ /dev/null @@ -1,9 +0,0 @@ -token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid -mpnn,0.29288250773521896,0.2140725741095052,0.5024935007095337 -desc,0.2764989421166159,0.2020975603328644,0.5030443072319031 -morgan,0.2578500855439716,0.18846680866532511,0.5045571327209473 -maccs,0.2471777274240024,0.18066620905892686,0.5030479431152344 -help,0.10363027887841153,0.07574505123823679,0.49689680337905884 -comp,0.09781463069792998,0.07149429968008479,0.5007365345954895 -exp,0.091429982122356,0.06682765650659303,0.5002157688140869 -phys,0.0008617135523839594,0.0006298404084638948,0.49989768862724304 diff --git a/reports/feature_importance/token_importance_pdi.png b/reports/feature_importance/token_importance_pdi.png deleted file mode 100644 index d683bf3..0000000 Binary files a/reports/feature_importance/token_importance_pdi.png and /dev/null differ diff --git a/reports/feature_importance/token_importance_size.csv b/reports/feature_importance/token_importance_size.csv deleted file mode 100644 index 786b3cf..0000000 --- a/reports/feature_importance/token_importance_size.csv +++ /dev/null @@ -1,9 +0,0 @@ -token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid -desc,0.43124767207505,0.1973483310882186,0.5030443072319031 -mpnn,0.4250195774970647,0.1944982193069518,0.5024935007095337 -morgan,0.41677666295297405,0.19072608200879151,0.5045571327209473 -maccs,0.40606126033119555,0.1858224802938626,0.5030479431152344 -help,0.17471509961619794,0.07995343640757792,0.49689680337905884 -comp,0.16962298878652332,0.07762317554120124,0.5007365345954895 -exp,0.16035562746707094,0.07338222907722322,0.5002157688140869 -phys,0.0014117471939899774,0.0006460462761730848,0.49989768862724304 diff --git a/reports/feature_importance/token_importance_size.png b/reports/feature_importance/token_importance_size.png deleted file mode 100644 index bf96cab..0000000 Binary files a/reports/feature_importance/token_importance_size.png and /dev/null differ diff --git a/reports/feature_importance/token_importance_toxic.csv b/reports/feature_importance/token_importance_toxic.csv deleted file mode 100644 index 85912fb..0000000 --- a/reports/feature_importance/token_importance_toxic.csv +++ /dev/null @@ -1,9 +0,0 @@ -token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid -mpnn,0.17711989597355013,0.21006733816477163,0.5024935007095337 -desc,0.17162019114321028,0.20354459068882486,0.5030443072319031 -morgan,0.16095514462538474,0.19089565635488848,0.5045571327209473 -maccs,0.15742516053846328,0.18670903261717847,0.5030479431152344 -comp,0.05954181250470194,0.07061764571178654,0.5007365345954895 -help,0.05892528920918569,0.06988643814815398,0.49689680337905884 -exp,0.05708834082906645,0.06770778478774685,0.5002157688140869 -phys,0.0004818760368551862,0.0005715135266490605,0.49989768862724304 diff --git a/reports/feature_importance/token_importance_toxic.png b/reports/feature_importance/token_importance_toxic.png deleted file mode 100644 index 78835a2..0000000 Binary files a/reports/feature_importance/token_importance_toxic.png and /dev/null differ