Remove unrelated tasks

RYDE-WORK 2026-03-03 15:04:13 +08:00
parent 3b38727053
commit ac5f598484
15 changed files with 312 additions and 245 deletions


@@ -123,25 +123,36 @@ train: requirements
 	$(N_TRIALS_FLAG) $(EPOCHS_PER_TRIAL_FLAG) $(MIN_STRATUM_FLAG) $(OUTPUT_DIR_FLAG) $(USE_SWA_FLAG)
 
 #################################################################################
-# INTERPRETABILITY                                                              #
+# INTERPRETABILITY (biodistribution feature importance)                         #
 #################################################################################
 # Parameters:
-#   TASK        target task (delivery, size, pdi, ee, biodist, toxic, all; default: delivery)
-#               if 'all' is given, all 6 tasks are computed in sequence
-#   METHOD      method (ig, ablation, attention, all; default: ig)
-#   DATA        data path (default: data/interim/internal.csv, i.e. the final model's full training data)
+#   ORGAN       organ (lymph_nodes, heart, liver, spleen, lung, kidney, muscle, all; default: all)
+#   METHOD      token-level method (ig, ablation, attention, all; default: ig)
+#   DESC_IG     also compute IG over the features inside the desc token (1 to enable; off by default)
+#   DESC_TOP_K  number of top-K features shown in the plot (default: 30)
+#   DATA        data path (default: data/interim/internal.csv)
 #   MODEL       model path (default: models/final/model.pt)
-TASK_FLAG = $(if $(TASK),--task $(TASK),)
 METHOD_FLAG = $(if $(METHOD),--method $(METHOD),)
 DATA_FLAG = $(if $(DATA),--data-path $(DATA),)
 MODEL_FLAG = $(if $(MODEL),--model-path $(MODEL),)
+ORGAN_FLAG = $(if $(ORGAN),--organ $(ORGAN),)
+DESC_IG_FLAG = $(if $(DESC_IG),--desc-ig,)
+DESC_TOP_K_FLAG = $(if $(DESC_TOP_K),--desc-top-k $(DESC_TOP_K),)
 
-## Compute token-level feature importance
+## Compute biodistribution feature importance (token-level)
 .PHONY: feature_importance
 feature_importance: requirements
 	$(PYTHON_INTERPRETER) -m lnp_ml.interpretability.token_importance \
-	$(TASK_FLAG) $(METHOD_FLAG) $(DATA_FLAG) $(MODEL_FLAG) $(DEVICE_FLAG)
+	$(ORGAN_FLAG) $(METHOD_FLAG) $(DATA_FLAG) $(MODEL_FLAG) $(DEVICE_FLAG) \
+	$(DESC_IG_FLAG) $(DESC_TOP_K_FLAG)
+
+## Compute biodistribution feature importance (token + descriptor-level)
+.PHONY: desc_importance
+desc_importance: requirements
+	$(PYTHON_INTERPRETER) -m lnp_ml.interpretability.token_importance \
+	--desc-ig $(ORGAN_FLAG) $(METHOD_FLAG) $(DATA_FLAG) $(MODEL_FLAG) \
+	$(DEVICE_FLAG) $(DESC_TOP_K_FLAG)
 
 #################################################################################
 # SERVING & DEPLOYMENT                                                          #
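Usage sketch for the updated targets (values illustrative; unset variables fall back to the defaults listed in the comment block above):

    make feature_importance ORGAN=liver METHOD=all
    make desc_importance DESC_TOP_K=20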


@@ -1,25 +1,26 @@
 """
-Token-level feature importance for the LNP multi-task model.
+Biodistribution feature importance for the LNP multi-task model.
 
-Three complementary methods (inspired by AGILE / Captum):
-1. Integrated Gradients: gradient-based attribution via Captum
-2. Token Ablation: zero-out each token and measure prediction change
-3. Fusion Attention Weights: extract learned attention pooling weights
+Two levels of analysis, both targeting the biodist head (per organ):
+1. Token-level: which of the 8 input tokens matters most
+2. Descriptor-level: which RDKit descriptors inside the "desc" token matter most
+
+Token-level methods:
+- Integrated Gradients (Captum)
+- Token Ablation (zero-out)
+- Fusion Attention Weights
 
 Usage:
-    python -m lnp_ml.interpretability.token_importance \
-        --model-path models/model.pt \
-        --data-path data/processed/train.parquet \
-        --task delivery \
-        --method all
+    python -m lnp_ml.interpretability.token_importance                 # all organs, token-level IG
+    python -m lnp_ml.interpretability.token_importance --organ liver   # liver only
+    python -m lnp_ml.interpretability.token_importance --desc-ig       # + desc-level IG
 """
 from __future__ import annotations
 
 import argparse
-import os
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Union
 
 import numpy as np
 import pandas as pd
@@ -35,7 +36,7 @@ from lnp_ml.modeling.predict import load_model
 from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
 
-TASKS = ["delivery", "size", "pdi", "ee", "biodist", "toxic"]
+BIODIST_ORGANS = ["lymph_nodes", "heart", "liver", "spleen", "lung", "kidney", "muscle"]
 
 
 def get_token_names(model: Union[LNPModel, LNPModelWithoutMPNN]) -> List[str]:
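BIODIST_ORGANS also fixes the column order of the biodist head's output, so selecting an organ means selecting one output column. A minimal sketch of the indexing convention the wrappers below rely on (the random tensor is a stand-in for real head output):

    import torch

    BIODIST_ORGANS = ["lymph_nodes", "heart", "liver", "spleen", "lung", "kidney", "muscle"]
    out = torch.randn(4, len(BIODIST_ORGANS))     # stand-in for biodist head output, [B, 7]
    organ_index = BIODIST_ORGANS.index("liver")   # -> 2
    liver_pred = out[:, organ_index]              # one prediction column per organ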
@@ -45,96 +46,97 @@ def get_token_names(model: Union[LNPModel, LNPModelWithoutMPNN]) -> List[str]:
 # ──────────────────────────────────────────────────────────────────────
-# Helper: pre-compute projected tokens for the whole dataset
+# Pre-compute
 # ──────────────────────────────────────────────────────────────────────
 @torch.no_grad()
 def precompute_projected_tokens(
-    model: nn.Module,
-    loader: DataLoader,
-    device: torch.device,
+    model: nn.Module, loader: DataLoader, device: torch.device,
 ) -> torch.Tensor:
-    """
-    Run encoder + TokenProjector on all samples and stack results.
-
-    Returns:
-        all_tokens: [N, n_tokens, d_model]
-    """
+    """Returns [N, n_tokens, d_model]."""
     model.eval()
     chunks = []
     for batch in tqdm(loader, desc="Encoding tokens"):
         smiles = batch["smiles"]
         tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
-        stacked = model._encode_and_project(smiles, tabular)
-        chunks.append(stacked.cpu())
+        chunks.append(model._encode_and_project(smiles, tabular).cpu())
     return torch.cat(chunks, dim=0)
 
 
+@torch.no_grad()
+def precompute_raw_desc(
+    model: nn.Module, loader: DataLoader, device: torch.device,
+) -> torch.Tensor:
+    """Returns [N, desc_dim] raw RDKit descriptor features."""
+    model.eval()
+    chunks = []
+    for batch in tqdm(loader, desc="Encoding raw desc"):
+        smiles = batch["smiles"]
+        rdkit_features = model.rdkit_encoder(smiles)
+        target_device = next(model.parameters()).device
+        chunks.append(rdkit_features["desc"].to(target_device).cpu())
+    return torch.cat(chunks, dim=0)
+
+
+def get_desc_names_and_dim(model: nn.Module) -> tuple[List[str], int]:
+    """Get descriptor names from RDKit and verify against model dim."""
+    from rdkit.Chem import Descriptors as RDKitDesc
+    names = [name for name, _ in RDKitDesc.descList]
+    model_dim = model.token_projector.projectors["desc"][1].in_features
+    if len(names) != model_dim:
+        logger.warning(
+            f"RDKit descList has {len(names)} names but model expects "
+            f"{model_dim} features; using generic names."
+        )
+        names = [f"desc_{i}" for i in range(model_dim)]
+    return names, model_dim
+
+
 # ──────────────────────────────────────────────────────────────────────
-# Method 1: Integrated Gradients (Captum)
+# Token-level: Integrated Gradients
 # ──────────────────────────────────────────────────────────────────────
 class _ProjectedWrapper(nn.Module):
-    """Wraps model.forward_from_projected so Captum can attribute w.r.t. stacked tokens."""
-    def __init__(self, model: nn.Module, task: str) -> None:
+    def __init__(self, model: nn.Module, organ_index: int) -> None:
         super().__init__()
         self.model = model
-        self.task = task
+        self.organ_index = organ_index
 
     def forward(self, stacked: torch.Tensor) -> torch.Tensor:
-        """
-        Args:
-            stacked: [B, n_tokens, d_model]
-        Returns:
-            [B] scalar predictions (sum over output dims for non-scalar heads)
-        """
-        out = self.model.forward_from_projected(stacked, task=self.task)
-        if out.dim() > 1 and out.size(-1) > 1:
-            return out.sum(dim=-1)
-        return out.squeeze(-1)
+        out = self.model.forward_from_projected(stacked, task="biodist")
+        return out[:, self.organ_index]
 
 
 def integrated_gradients_importance(
     model: nn.Module,
     all_tokens: torch.Tensor,
     device: torch.device,
-    task: str = "delivery",
+    organ_index: int,
     batch_size: int = 64,
     n_steps: int = 50,
 ) -> np.ndarray:
-    """
-    Compute token-level Integrated Gradients.
-
-    Returns:
-        importance: [n_tokens] averaged absolute attribution per token position
-    """
+    """Returns [n_tokens] averaged absolute attribution per token."""
     from captum.attr import IntegratedGradients
 
-    wrapper = _ProjectedWrapper(model, task).to(device)
+    wrapper = _ProjectedWrapper(model, organ_index).to(device)
     wrapper.eval()
     ig = IntegratedGradients(wrapper)
 
     N = all_tokens.size(0)
     all_attrs: List[torch.Tensor] = []
-    for start in tqdm(range(0, N, batch_size), desc=f"IG ({task})"):
+    for start in tqdm(range(0, N, batch_size), desc=f"IG (organ={organ_index})"):
         end = min(start + batch_size, N)
         inp = all_tokens[start:end].to(device).requires_grad_(True)
         baseline = torch.zeros_like(inp)
         attr = ig.attribute(inp, baselines=baseline, n_steps=n_steps)
-        # attr: [B, n_tokens, d_model] → per-token L2 norm → [B, n_tokens]
-        token_attr = attr.detach().cpu().norm(dim=-1)
-        all_attrs.append(token_attr)
+        all_attrs.append(attr.detach().cpu().norm(dim=-1))  # [B, n_tokens]
 
-    attrs = torch.cat(all_attrs, dim=0)       # [N, n_tokens]
-    importance = attrs.mean(dim=0).numpy()    # [n_tokens]
-    return importance
+    return torch.cat(all_attrs, dim=0).mean(dim=0).numpy()
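For reference, the quantity the loop above approximates: Captum's IntegratedGradients integrates gradients along the straight line from the zero baseline x' to the input, per coordinate of the projected tokens; the [B, n_tokens, d_model] attribution tensor A is then collapsed to one score per token by an L2 norm over d_model and averaged over the N samples:

    \mathrm{IG}_i(x) = (x_i - x_i') \int_0^1 \frac{\partial F\big(x' + \alpha\,(x - x')\big)}{\partial x_i}\, d\alpha,
    \qquad
    \mathrm{importance}_t = \frac{1}{N} \sum_{n=1}^{N} \big\lVert A_{n,t,\cdot} \big\rVert_2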
 # ──────────────────────────────────────────────────────────────────────
-# Method 2: Token Ablation (zero-out)
+# Token-level: Ablation
 # ──────────────────────────────────────────────────────────────────────
 @torch.no_grad()
@@ -142,98 +144,113 @@ def token_ablation_importance(
     model: nn.Module,
     all_tokens: torch.Tensor,
     device: torch.device,
-    task: str = "delivery",
+    organ_index: int,
     batch_size: int = 64,
 ) -> np.ndarray:
-    """
-    For each token position, replace it with zeros and measure
-    the average absolute prediction change.
-
-    Returns:
-        importance: [n_tokens]
-    """
+    """Returns [n_tokens] importance via zero-out ablation."""
     model.eval()
     n_tokens = all_tokens.size(1)
-    N = all_tokens.size(0)
-
-    # original predictions
-    orig_preds = _batch_predict(model, all_tokens, device, task, batch_size)
+    orig_preds = _batch_predict(model, all_tokens, device, organ_index, batch_size)
 
     importance = np.zeros(n_tokens)
     for t in range(n_tokens):
         ablated = all_tokens.clone()
         ablated[:, t, :] = 0.0
-        abl_preds = _batch_predict(model, ablated, device, task, batch_size)
+        abl_preds = _batch_predict(model, ablated, device, organ_index, batch_size)
         importance[t] = np.abs(orig_preds - abl_preds).mean()
     return importance
 
 
 def _batch_predict(
-    model: nn.Module,
-    all_tokens: torch.Tensor,
-    device: torch.device,
-    task: str,
-    batch_size: int,
+    model: nn.Module, all_tokens: torch.Tensor,
+    device: torch.device, organ_index: int, batch_size: int,
 ) -> np.ndarray:
-    """Run forward_from_projected in batches, return [N] predictions."""
     model.eval()
     preds = []
-    N = all_tokens.size(0)
-    for start in range(0, N, batch_size):
-        end = min(start + batch_size, N)
-        inp = all_tokens[start:end].to(device)
-        out = model.forward_from_projected(inp, task=task)
-        if out.dim() > 1 and out.size(-1) > 1:
-            out = out.sum(dim=-1)
-        else:
-            out = out.squeeze(-1)
-        preds.append(out.cpu().numpy())
-    return np.concatenate(preds, axis=0)
+    for start in range(0, all_tokens.size(0), batch_size):
+        end = min(start + batch_size, all_tokens.size(0))
+        out = model.forward_from_projected(all_tokens[start:end].to(device), task="biodist")
+        preds.append(out[:, organ_index].cpu().numpy())
+    return np.concatenate(preds)
 
 
 # ──────────────────────────────────────────────────────────────────────
-# Method 3: Fusion Attention Weights
+# Token-level: Fusion Attention Weights (task-agnostic)
 # ──────────────────────────────────────────────────────────────────────
 @torch.no_grad()
 def fusion_attention_importance(
-    model: nn.Module,
-    all_tokens: torch.Tensor,
-    device: torch.device,
-    batch_size: int = 64,
+    model: nn.Module, all_tokens: torch.Tensor,
+    device: torch.device, batch_size: int = 64,
 ) -> Optional[np.ndarray]:
-    """
-    Extract FusionLayer attention weights (only works if strategy="attention").
-
-    Returns:
-        importance: [n_tokens] averaged attention weights, or None if not applicable.
-    """
     model.eval()
     if model.fusion.strategy != "attention":
-        logger.warning("Fusion strategy is not 'attention'; skipping attention weight extraction.")
+        logger.warning("Fusion strategy is not 'attention'; skipping.")
         return None
 
-    N = all_tokens.size(0)
     all_weights: List[torch.Tensor] = []
-    for start in range(0, N, batch_size):
-        end = min(start + batch_size, N)
-        inp = all_tokens[start:end].to(device)
-        attended = model.cross_attention(inp)
+    for start in range(0, all_tokens.size(0), batch_size):
+        end = min(start + batch_size, all_tokens.size(0))
+        attended = model.cross_attention(all_tokens[start:end].to(device))
         _, weights = model.fusion(attended, return_attn_weights=True)
         all_weights.append(weights.cpu())
 
-    weights = torch.cat(all_weights, dim=0)   # [N, n_tokens]
-    return weights.mean(dim=0).numpy()
+    return torch.cat(all_weights, dim=0).mean(dim=0).numpy()
 
 
 # ──────────────────────────────────────────────────────────────────────
-# Bonus: TokenProjector gate values (static)
+# Descriptor-level: Integrated Gradients
+# ──────────────────────────────────────────────────────────────────────
+class _ReplacingDescWrapper(nn.Module):
+    def __init__(self, model: nn.Module, base_projected: torch.Tensor, organ_index: int) -> None:
+        super().__init__()
+        self.model = model
+        self.base_projected = base_projected
+        self.organ_index = organ_index
+
+    def forward(self, raw_desc: torch.Tensor) -> torch.Tensor:
+        base = self.base_projected[:raw_desc.size(0)].to(raw_desc.device)
+        out = self.model.forward_replacing_token(raw_desc, "desc", base, task="biodist")
+        return out[:, self.organ_index]
+
+
+def descriptor_ig_importance(
+    model: nn.Module,
+    all_tokens: torch.Tensor,
+    raw_desc: torch.Tensor,
+    device: torch.device,
+    organ_index: int,
+    batch_size: int = 64,
+    n_steps: int = 50,
+) -> np.ndarray:
+    """Returns [desc_dim] averaged absolute attribution per descriptor."""
+    from captum.attr import IntegratedGradients
+
+    wrapper = _ReplacingDescWrapper(model, all_tokens, organ_index).to(device)
+    wrapper.eval()
+    ig = IntegratedGradients(wrapper)
+
+    N = raw_desc.size(0)
+    all_attrs: List[torch.Tensor] = []
+    for start in tqdm(range(0, N, batch_size), desc=f"Desc-IG (organ={organ_index})"):
+        end = min(start + batch_size, N)
+        inp = raw_desc[start:end].to(device).requires_grad_(True)
+        baseline = torch.zeros_like(inp)
+        wrapper.base_projected = all_tokens[start:end]
+        attr = ig.attribute(inp, baselines=baseline, n_steps=n_steps)
+        all_attrs.append(attr.detach().cpu().abs())
+
+    return torch.cat(all_attrs, dim=0).mean(dim=0).numpy()
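Two notes on the block above. Token ablation scores position t by the mean absolute change in the organ prediction when that token's embedding z_t is zeroed:

    \mathrm{importance}_t = \frac{1}{N} \sum_{n=1}^{N} \big|\, f(x_n) - f(x_n \mid z_t \leftarrow 0) \,\big|

The descriptor-level wrapper exists because all_tokens is pre-computed under torch.no_grad(), so no gradient path reaches the raw descriptors. _ReplacingDescWrapper re-projects only the desc token from raw_desc inside the differentiable graph (via forward_replacing_token, added to the model below), keeping the other tokens fixed at their pre-computed values; reassigning wrapper.base_projected before each ig.attribute call keeps those fixed tokens aligned with the current batch slice.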
+
+# ──────────────────────────────────────────────────────────────────────
+# Gate values
 # ──────────────────────────────────────────────────────────────────────
 def gate_values(model: nn.Module) -> Dict[str, float]:
-    """Read sigmoid(weight) from TokenProjector for each token."""
     gates = {}
     for key in model.token_projector.keys:
         w = model.token_projector.weights[key].detach().cpu()
@@ -242,7 +259,7 @@ def gate_values(model: nn.Module) -> Dict[str, float]:
 # ──────────────────────────────────────────────────────────────────────
-# Visualization
+# Visualization & IO
 # ──────────────────────────────────────────────────────────────────────
 
 def normalize(arr: np.ndarray) -> np.ndarray:
@@ -251,10 +268,8 @@
 def plot_token_importance(
-    results: Dict[str, np.ndarray],
-    token_names: List[str],
-    task: str,
-    out_dir: Path,
+    results: Dict[str, np.ndarray], token_names: List[str],
+    organ: str, out_dir: Path,
 ) -> Path:
     import matplotlib
     matplotlib.use("Agg")
@@ -265,20 +280,18 @@ def plot_token_importance(
     if n_methods == 1:
         axes = [axes]
 
-    channel_a_color = "#5a448e"
-    channel_b_color = "#e07b39"
+    color_a, color_b = "#5a448e", "#e07b39"
 
     for ax, (method_name, importance) in zip(axes, results.items()):
         normed = normalize(importance)
         order = np.argsort(normed)
         names_sorted = [token_names[i] for i in order]
         vals_sorted = normed[order]
 
         n_tokens = len(token_names)
         split_idx = 4 if n_tokens == 8 else 3
         channel_a_set = set(token_names[:split_idx])
-        colors = [channel_a_color if n in channel_a_set else channel_b_color for n in names_sorted]
+        colors = [color_a if n in channel_a_set else color_b for n in names_sorted]
 
         ax.barh(range(len(names_sorted)), vals_sorted, color=colors)
         ax.set_yticks(range(len(names_sorted)))
@@ -288,22 +301,20 @@ def plot_token_importance(
         ax.spines["top"].set_visible(False)
         ax.spines["right"].set_visible(False)
 
-    fig.suptitle(f"Token Importance — task: {task}", fontsize=14, y=1.02)
+    fig.suptitle(f"Token Importance — biodist: {organ}", fontsize=14, y=1.02)
     fig.tight_layout()
 
     out_dir.mkdir(parents=True, exist_ok=True)
-    fig_path = out_dir / f"token_importance_{task}.png"
+    fig_path = out_dir / f"token_importance_biodist_{organ}.png"
     fig.savefig(fig_path, dpi=200, bbox_inches="tight")
     plt.close(fig)
     logger.info(f"Figure saved to {fig_path}")
     return fig_path
 
 
-def save_csv(
-    results: Dict[str, np.ndarray],
-    token_names: List[str],
-    task: str,
-    out_dir: Path,
+def save_token_csv(
+    results: Dict[str, np.ndarray], token_names: List[str],
+    organ: str, out_dir: Path,
     gate_vals: Optional[Dict[str, float]] = None,
 ) -> Path:
     out_dir.mkdir(parents=True, exist_ok=True)
@@ -312,38 +323,98 @@ def save_csv(
         normed = normalize(importance)
         df[f"{method_name}_raw"] = importance
         df[f"{method_name}_normalized"] = normed
 
     if gate_vals is not None:
         df["gate_sigmoid"] = [gate_vals.get(t, float("nan")) for t in token_names]
 
     df = df.sort_values(
         by=[c for c in df.columns if c.endswith("_normalized")][-1],
         ascending=False,
     )
-    csv_path = out_dir / f"token_importance_{task}.csv"
+    csv_path = out_dir / f"token_importance_biodist_{organ}.csv"
     df.to_csv(csv_path, index=False)
     logger.info(f"CSV saved to {csv_path}")
     return csv_path
 
 
+def plot_descriptor_importance(
+    importance: np.ndarray, feature_names: List[str],
+    organ: str, out_dir: Path, top_k: int = 30,
+) -> Path:
+    import matplotlib
+    matplotlib.use("Agg")
+    import matplotlib.pyplot as plt
+
+    normed = normalize(importance)
+    order = np.argsort(-normed)
+    top_indices = order[:top_k]
+    names = [feature_names[i] for i in reversed(top_indices)]
+    vals = [normed[i] for i in reversed(top_indices)]
+
+    fig, ax = plt.subplots(figsize=(8, max(5, top_k * 0.35)))
+    ax.barh(range(len(names)), vals, color="#5a448e")
+    ax.set_yticks(range(len(names)))
+    ax.set_yticklabels(names, fontsize=9)
+    ax.set_xlabel("Normalized Importance", fontsize=11)
+    ax.set_title(
+        f"Top-{top_k} desc Feature Importance — biodist: {organ}\n"
+        f"(total features: {len(feature_names)})",
+        fontsize=13,
+    )
+    ax.spines["top"].set_visible(False)
+    ax.spines["right"].set_visible(False)
+    fig.tight_layout()
+
+    out_dir.mkdir(parents=True, exist_ok=True)
+    fig_path = out_dir / f"desc_importance_biodist_{organ}.png"
+    fig.savefig(fig_path, dpi=200, bbox_inches="tight")
+    plt.close(fig)
+    logger.info(f"Descriptor figure saved to {fig_path}")
+    return fig_path
+
+
+def save_descriptor_csv(
+    importance: np.ndarray, feature_names: List[str],
+    organ: str, out_dir: Path,
+) -> Path:
+    out_dir.mkdir(parents=True, exist_ok=True)
+    normed = normalize(importance)
+    df = pd.DataFrame({
+        "feature": feature_names,
+        "ig_raw": importance,
+        "ig_normalized": normed,
+    })
+    df = df.sort_values("ig_normalized", ascending=False).reset_index(drop=True)
+    df.index.name = "rank"
+    csv_path = out_dir / f"desc_importance_biodist_{organ}.csv"
+    df.to_csv(csv_path)
+    logger.info(f"Descriptor CSV saved to {csv_path}")
+    return csv_path
+
+
 # ──────────────────────────────────────────────────────────────────────
 # Main
 # ──────────────────────────────────────────────────────────────────────
 def main() -> None:
-    parser = argparse.ArgumentParser(description="Token-level Feature Importance")
-    parser.add_argument("--model-path", type=str, default=str(MODELS_DIR / "final" / "model.pt"),
-                        help="Path to trained model checkpoint")
-    parser.add_argument("--data-path", type=str, default=str(INTERIM_DATA_DIR / "internal.csv"),
-                        help="Path to data (.csv or .parquet) for computing importance")
-    parser.add_argument("--task", type=str, default="all", choices=TASKS + ["all"],
-                        help="Target task for importance computation ('all' to run on all tasks)")
+    parser = argparse.ArgumentParser(
+        description="Biodistribution feature importance (token-level & descriptor-level)")
+    parser.add_argument("--model-path", type=str,
+                        default=str(MODELS_DIR / "final" / "model.pt"))
+    parser.add_argument("--data-path", type=str,
+                        default=str(INTERIM_DATA_DIR / "internal.csv"))
+    parser.add_argument("--organ", type=str, default="all",
+                        choices=BIODIST_ORGANS + ["all"],
+                        help="Organ to analyze (default: all 7 organs)")
     parser.add_argument("--method", type=str, default="ig",
                         choices=["ig", "ablation", "attention", "all"],
-                        help="Which method(s) to run")
+                        help="Token-level method(s)")
+    parser.add_argument("--desc-ig", action="store_true",
+                        help="Also compute descriptor-level IG within the desc token")
+    parser.add_argument("--desc-top-k", type=int, default=30,
+                        help="Top-K descriptors to show in plot")
     parser.add_argument("--batch-size", type=int, default=64)
     parser.add_argument("--n-steps", type=int, default=50,
-                        help="Number of interpolation steps for Integrated Gradients")
+                        help="Interpolation steps for Integrated Gradients")
     parser.add_argument("--device", type=str,
                         default="cuda" if torch.cuda.is_available() else "cpu")
     parser.add_argument("--output-dir", type=str,
@@ -352,17 +423,21 @@ def main() -> None:
     device = torch.device(args.device)
     out_dir = Path(args.output_dir)
 
+    # ── Resolve organs ──
+    organs = BIODIST_ORGANS if args.organ == "all" else [args.organ]
+    organ_indices = {organ: BIODIST_ORGANS.index(organ) for organ in organs}
+
     logger.info(f"Device: {device}")
     logger.info(f"Model: {args.model_path}")
     logger.info(f"Data: {args.data_path}")
-    logger.info(f"Task: {args.task}")
+    logger.info(f"Organs: {organs}")
 
-    # ── Load model ──
+    # ── Load model & data ──
     model = load_model(Path(args.model_path), device)
     token_names = get_token_names(model)
     logger.info(f"Tokens ({len(token_names)}): {token_names}")
 
-    # ── Load data ──
     data_path = Path(args.data_path)
     if data_path.suffix == ".csv":
         df = pd.read_csv(data_path)
@@ -373,76 +448,78 @@ def main() -> None:
     loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
     logger.info(f"Samples: {len(dataset)}")
 
-    # ── Pre-compute projected tokens ──
+    # ── Pre-compute ──
     all_tokens = precompute_projected_tokens(model, loader, device)
     logger.info(f"Projected tokens shape: {all_tokens.shape}")
 
-    # ── Gate values (always) ──
     gv = gate_values(model)
-    logger.info(f"TokenProjector gate values: {gv}")
+    logger.info(f"Gate values: {gv}")
 
-    # ── Run selected methods ──
-    methods = (
-        ["ig", "ablation", "attention"] if args.method == "all"
-        else [args.method]
-    )
+    methods = ["ig", "ablation", "attention"] if args.method == "all" else [args.method]
 
-    # Determine tasks to process
-    tasks_to_run = TASKS if args.task == "all" else [args.task]
+    # ── Pre-compute desc features if needed ──
+    raw_desc = None
+    desc_names = None
+    if args.desc_ig:
+        desc_names, desc_dim = get_desc_names_and_dim(model)
+        raw_desc = precompute_raw_desc(model, loader, device)
+        logger.info(f"Raw desc shape: {raw_desc.shape} (dim={desc_dim})")
+        assert raw_desc.size(1) == desc_dim
 
-    for task in tasks_to_run:
+    # ── Per-organ loop ──
+    for organ, organ_idx in organ_indices.items():
         logger.info(f"\n{'#'*60}")
-        logger.info(f"# Processing task: {task}")
+        logger.info(f"# Organ: {organ} (index={organ_idx})")
         logger.info(f"{'#'*60}")
 
+        # Token-level
         results: Dict[str, np.ndarray] = {}
         for method in methods:
-            logger.info(f"\n{'='*60}")
-            logger.info(f"Computing: {method} (task={task})")
-            logger.info(f"{'='*60}")
+            logger.info(f"  Computing: {method}")
             if method == "ig":
-                imp = integrated_gradients_importance(
-                    model, all_tokens, device,
-                    task=task, batch_size=args.batch_size, n_steps=args.n_steps,
+                results["Integrated Gradients"] = integrated_gradients_importance(
+                    model, all_tokens, device, organ_idx,
+                    batch_size=args.batch_size, n_steps=args.n_steps,
                 )
-                results["Integrated Gradients"] = imp
             elif method == "ablation":
-                imp = token_ablation_importance(
-                    model, all_tokens, device,
-                    task=task, batch_size=args.batch_size,
-                )
-                results["Token Ablation"] = imp
-            elif method == "attention":
-                imp = fusion_attention_importance(
-                    model, all_tokens, device,
+                results["Token Ablation"] = token_ablation_importance(
+                    model, all_tokens, device, organ_idx,
                     batch_size=args.batch_size,
                 )
+            elif method == "attention":
+                imp = fusion_attention_importance(model, all_tokens, device, args.batch_size)
                 if imp is not None:
                     results["Fusion Attention"] = imp
 
-        # ── Print summary ──
-        logger.info(f"\n{'='*60}")
-        logger.info(f"Token Importance Summary (task={task})")
-        logger.info(f"{'='*60}")
+        # Print token summary
         for method_name, importance in results.items():
             normed = normalize(importance)
             order = np.argsort(-normed)
-            logger.info(f"\n  {method_name}:")
+            logger.info(f"\n  {method_name} (biodist_{organ}):")
             for rank, idx in enumerate(order, 1):
                 logger.info(f"    {rank:>2d}. {token_names[idx]:<10s} {normed[idx]:.4f}")
 
-        logger.info(f"\n  Gate values (sigmoid):")
-        for name, val in sorted(gv.items(), key=lambda x: -x[1]):
-            logger.info(f"    {name:<10s} {val:.4f}")
-
-        # ── Save results ──
         if results:
-            plot_token_importance(results, token_names, task, out_dir)
-            save_csv(results, token_names, task, out_dir, gate_vals=gv)
+            plot_token_importance(results, token_names, organ, out_dir)
+            save_token_csv(results, token_names, organ, out_dir, gate_vals=gv)
+
+        # Descriptor-level
+        if args.desc_ig and raw_desc is not None and desc_names is not None:
+            logger.info(f"\n  Descriptor IG for {organ}...")
+            desc_imp = descriptor_ig_importance(
+                model, all_tokens, raw_desc, device, organ_idx,
+                batch_size=args.batch_size, n_steps=args.n_steps,
+            )
+            normed = normalize(desc_imp)
+            top_k = min(args.desc_top_k, len(desc_names))
+            order = np.argsort(-normed)
+            logger.info(f"\n  Top-{top_k} desc features (biodist_{organ}):")
+            for rank, idx in enumerate(order[:top_k], 1):
+                logger.info(f"    {rank:>3d}. {desc_names[idx]:<30s} {normed[idx]:.6f}")
+            desc_out_dir = out_dir / "desc"
+            plot_descriptor_importance(desc_imp, desc_names, organ, desc_out_dir, top_k)
+            save_descriptor_csv(desc_imp, desc_names, organ, desc_out_dir)
 
     logger.info("\nDone!")
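For each organ, a run writes token_importance_biodist_<organ>.png and .csv to --output-dir; with --desc-ig it additionally writes desc_importance_biodist_<organ>.png and .csv to a desc/ subdirectory. A representative invocation (values illustrative):

    python -m lnp_ml.interpretability.token_importance \
        --organ liver --method all --desc-ig --desc-top-k 20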


@@ -201,6 +201,39 @@ class LNPModel(nn.Module):
         }
         return task_heads[task](fused)
 
+    def forward_replacing_token(
+        self,
+        raw_feature: torch.Tensor,
+        feature_key: str,
+        base_projected: torch.Tensor,
+        task: Optional[str] = None,
+    ) -> torch.Tensor:
+        """
+        Replace the projection of the given token inside base_projected with a
+        projection computed from raw features, then run the forward pass.
+        Used for Captum attribution over the features inside a single token
+        (e.g. the 210 features of "desc").
+
+        Args:
+            raw_feature: [B, input_dim] raw features of one token
+            feature_key: token name, e.g. "desc"
+            base_projected: [B, n_tokens, d_model] already-projected tokens
+            task: task name
+
+        Returns:
+            Predictions for the given task.
+        """
+        projected = self.token_projector.projectors[feature_key](raw_feature)
+        gate = torch.sigmoid(self.token_projector.weights[feature_key])
+        projected = projected * gate  # [B, d_model]
+
+        token_order = list(self.token_projector.keys)
+        token_idx = token_order.index(feature_key)
+
+        stacked = base_projected.clone()
+        stacked[:, token_idx, :] = projected
+        return self.forward_from_projected(stacked, task=task)
+
     def forward_backbone(
         self,
         smiles: List[str],
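A minimal sanity check for the new method (a sketch, assuming a loaded model and one collated batch providing smiles and tabular, everything on a single device): feeding forward_replacing_token the model's own raw descriptors should reproduce the standard forward pass, which is what makes attributions over raw_desc meaningful.

    import torch

    with torch.no_grad():
        stacked = model._encode_and_project(smiles, tabular)   # [B, n_tokens, d_model], projected + gated
        raw_desc = model.rdkit_encoder(smiles)["desc"]         # [B, desc_dim] raw RDKit descriptors
        ref = model.forward_from_projected(stacked, task="biodist")
        rep = model.forward_replacing_token(raw_desc, "desc", stacked, task="biodist")
    # the re-projected desc token must match the one inside `stacked` exactly
    assert torch.allclose(ref, rep, atol=1e-6)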


@@ -1,9 +0,0 @@
-token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
-desc,3.2013970842762367e-09,0.17906809140868585,0.5030443072319031
-mpnn,3.109777150387161e-09,0.17394338920379962,0.5024935007095337
-maccs,3.0657202874248063e-09,0.1714790968475434,0.5030479431152344
-morgan,3.020539287877718e-09,0.16895192663283606,0.5045571327209473
-help,1.937320535640997e-09,0.10836278088337024,0.49689680337905884
-comp,1.876732953403087e-09,0.10497385335304418,0.5007365345954895
-exp,1.6503372406931126e-09,0.09231055445232395,0.5002157688140869
-phys,1.6274562664095572e-11,0.0009103072183966515,0.49989768862724304

Binary file removed (image, 55 KiB).


@@ -1,9 +0,0 @@
-token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
-desc,0.48309860429978513,0.1830685579048154,0.5030443072319031
-mpnn,0.4800125818662255,0.1818991202961257,0.5024935007095337
-morgan,0.4681746999619525,0.17741319558101734,0.5045571327209473
-maccs,0.4642216987644718,0.1759152193455701,0.5030479431152344
-help,0.2636187040391749,0.09989740304701922,0.49689680337905884
-comp,0.2503036292270154,0.0948517011498054,0.5007365345954895
-exp,0.22730758172642437,0.08613742788142016,0.5002157688140869
-phys,0.0021569658208921167,0.0008173747942266034,0.49989768862724304

Binary file removed (image, 56 KiB).


@@ -1,9 +0,0 @@
-token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
-desc,0.28693426746338735,0.2025623449841757,0.5030443072319031
-mpnn,0.2846832493908164,0.2009732301551499,0.5024935007095337
-morgan,0.2646961202937205,0.18686323982460912,0.5045571327209473
-maccs,0.25845640337994125,0.1824582877731647,0.5030479431152344
-help,0.11619691669315178,0.082029658337337,0.49689680337905884
-comp,0.10812802919624785,0.07633339630758515,0.5007365345954895
-exp,0.09640860187977682,0.06806002171178546,0.5002157688140869
-phys,0.001019643036021533,0.000719820906192951,0.49989768862724304

Binary file removed (image, 54 KiB).


@@ -1,9 +0,0 @@
-token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
-mpnn,0.29288250773521896,0.2140725741095052,0.5024935007095337
-desc,0.2764989421166159,0.2020975603328644,0.5030443072319031
-morgan,0.2578500855439716,0.18846680866532511,0.5045571327209473
-maccs,0.2471777274240024,0.18066620905892686,0.5030479431152344
-help,0.10363027887841153,0.07574505123823679,0.49689680337905884
-comp,0.09781463069792998,0.07149429968008479,0.5007365345954895
-exp,0.091429982122356,0.06682765650659303,0.5002157688140869
-phys,0.0008617135523839594,0.0006298404084638948,0.49989768862724304

Binary file removed (image, 54 KiB).


@@ -1,9 +0,0 @@
-token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
-desc,0.43124767207505,0.1973483310882186,0.5030443072319031
-mpnn,0.4250195774970647,0.1944982193069518,0.5024935007095337
-morgan,0.41677666295297405,0.19072608200879151,0.5045571327209473
-maccs,0.40606126033119555,0.1858224802938626,0.5030479431152344
-help,0.17471509961619794,0.07995343640757792,0.49689680337905884
-comp,0.16962298878652332,0.07762317554120124,0.5007365345954895
-exp,0.16035562746707094,0.07338222907722322,0.5002157688140869
-phys,0.0014117471939899774,0.0006460462761730848,0.49989768862724304

Binary file removed (image, 55 KiB).


@@ -1,9 +0,0 @@
-token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
-mpnn,0.17711989597355013,0.21006733816477163,0.5024935007095337
-desc,0.17162019114321028,0.20354459068882486,0.5030443072319031
-morgan,0.16095514462538474,0.19089565635488848,0.5045571327209473
-maccs,0.15742516053846328,0.18670903261717847,0.5030479431152344
-comp,0.05954181250470194,0.07061764571178654,0.5007365345954895
-help,0.05892528920918569,0.06988643814815398,0.49689680337905884
-exp,0.05708834082906645,0.06770778478774685,0.5002157688140869
-phys,0.0004818760368551862,0.0005715135266490605,0.49989768862724304

Binary file removed (image, 55 KiB).