移除无关任务

This commit is contained in:
RYDE-WORK 2026-03-03 15:04:13 +08:00
parent 3b38727053
commit ac5f598484
15 changed files with 312 additions and 245 deletions

View File

@ -123,25 +123,36 @@ train: requirements
$(N_TRIALS_FLAG) $(EPOCHS_PER_TRIAL_FLAG) $(MIN_STRATUM_FLAG) $(OUTPUT_DIR_FLAG) $(USE_SWA_FLAG)
#################################################################################
# INTERPRETABILITY #
# INTERPRETABILITY (biodistribution feature importance) #
#################################################################################
# 参数:
# TASK 目标任务 (delivery, size, pdi, ee, biodist, toxic, all; 默认: delivery)
# 如果指定 'all',将依次计算所有 6 个任务
# METHOD 方法 (ig, ablation, attention, all; 默认: ig)
# DATA 数据路径 (默认: data/interim/internal.csv即最终模型的全量训练数据)
# ORGAN 器官 (lymph_nodes, heart, liver, spleen, lung, kidney, muscle, all; 默认: all)
# METHOD token 级方法 (ig, ablation, attention, all; 默认: ig)
# DESC_IG 同时计算 desc 内部特征 IG (1 启用; 默认不启用)
# DESC_TOP_K 可视化展示的 top-K 特征数 (默认: 30)
# DATA 数据路径 (默认: data/interim/internal.csv)
# MODEL 模型路径 (默认: models/final/model.pt)
TASK_FLAG = $(if $(TASK),--task $(TASK),)
METHOD_FLAG = $(if $(METHOD),--method $(METHOD),)
DATA_FLAG = $(if $(DATA),--data-path $(DATA),)
MODEL_FLAG = $(if $(MODEL),--model-path $(MODEL),)
ORGAN_FLAG = $(if $(ORGAN),--organ $(ORGAN),)
DESC_IG_FLAG = $(if $(DESC_IG),--desc-ig,)
DESC_TOP_K_FLAG = $(if $(DESC_TOP_K),--desc-top-k $(DESC_TOP_K),)
## Compute token-level feature importance
## Compute biodistribution feature importance (token-level)
.PHONY: feature_importance
feature_importance: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.interpretability.token_importance \
$(TASK_FLAG) $(METHOD_FLAG) $(DATA_FLAG) $(MODEL_FLAG) $(DEVICE_FLAG)
$(ORGAN_FLAG) $(METHOD_FLAG) $(DATA_FLAG) $(MODEL_FLAG) $(DEVICE_FLAG) \
$(DESC_IG_FLAG) $(DESC_TOP_K_FLAG)
## Compute biodistribution feature importance (token + descriptor-level)
.PHONY: desc_importance
desc_importance: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.interpretability.token_importance \
--desc-ig $(ORGAN_FLAG) $(METHOD_FLAG) $(DATA_FLAG) $(MODEL_FLAG) \
$(DEVICE_FLAG) $(DESC_TOP_K_FLAG)
#################################################################################
# SERVING & DEPLOYMENT #

View File

@ -1,25 +1,26 @@
"""
Token-level feature importance for the LNP multi-task model.
Biodistribution feature importance for the LNP multi-task model.
Three complementary methods (inspired by AGILE / Captum):
1. Integrated Gradients gradient-based attribution via Captum
2. Token Ablation zero-out each token and measure prediction change
3. Fusion Attention Weights extract learned attention pooling weights
Two levels of analysis, both targeting the biodist head (per organ):
1. Token-level which of the 8 input tokens matters most
2. Descriptor-level which RDKit descriptors inside the "desc" token matter most
Token-level methods:
- Integrated Gradients (Captum)
- Token Ablation (zero-out)
- Fusion Attention Weights
Usage:
python -m lnp_ml.interpretability.token_importance \
--model-path models/model.pt \
--data-path data/processed/train.parquet \
--task delivery \
--method all
python -m lnp_ml.interpretability.token_importance # all organs, token-level IG
python -m lnp_ml.interpretability.token_importance --organ liver # liver only
python -m lnp_ml.interpretability.token_importance --desc-ig # + desc-level IG
"""
from __future__ import annotations
import argparse
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Union
import numpy as np
import pandas as pd
@ -35,7 +36,7 @@ from lnp_ml.modeling.predict import load_model
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
TASKS = ["delivery", "size", "pdi", "ee", "biodist", "toxic"]
BIODIST_ORGANS = ["lymph_nodes", "heart", "liver", "spleen", "lung", "kidney", "muscle"]
def get_token_names(model: Union[LNPModel, LNPModelWithoutMPNN]) -> List[str]:
@ -45,96 +46,97 @@ def get_token_names(model: Union[LNPModel, LNPModelWithoutMPNN]) -> List[str]:
# ──────────────────────────────────────────────────────────────────────
# Helper: pre-compute projected tokens for the whole dataset
# Pre-compute
# ──────────────────────────────────────────────────────────────────────
@torch.no_grad()
def precompute_projected_tokens(
model: nn.Module,
loader: DataLoader,
device: torch.device,
model: nn.Module, loader: DataLoader, device: torch.device,
) -> torch.Tensor:
"""
Run encoder + TokenProjector on all samples and stack results.
Returns:
all_tokens: [N, n_tokens, d_model]
"""
"""Returns [N, n_tokens, d_model]."""
model.eval()
chunks = []
for batch in tqdm(loader, desc="Encoding tokens"):
smiles = batch["smiles"]
tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
stacked = model._encode_and_project(smiles, tabular)
chunks.append(stacked.cpu())
chunks.append(model._encode_and_project(smiles, tabular).cpu())
return torch.cat(chunks, dim=0)
@torch.no_grad()
def precompute_raw_desc(
    model: nn.Module, loader: DataLoader, device: torch.device,
) -> torch.Tensor:
    """Collect raw RDKit descriptor features for the whole dataset.

    Iterates the loader once, runs the model's RDKit encoder on each
    batch of SMILES strings, and stacks the raw "desc" features on CPU.

    Args:
        model: model exposing ``rdkit_encoder(smiles) -> Dict[str, Tensor]``.
        loader: yields batches containing a "smiles" key.
        device: unused (kept for signature parity with
            ``precompute_projected_tokens``); results are gathered on CPU.

    Returns:
        Tensor of shape [N, desc_dim].
    """
    model.eval()
    chunks: List[torch.Tensor] = []
    for batch in tqdm(loader, desc="Encoding raw desc"):
        rdkit_features = model.rdkit_encoder(batch["smiles"])
        # The raw descriptors are consumed on CPU by the IG driver, so
        # there is no need to round-trip them through the model's device
        # (the original `.to(target_device).cpu()` was a wasted transfer).
        chunks.append(rdkit_features["desc"].detach().cpu())
    return torch.cat(chunks, dim=0)
def get_desc_names_and_dim(model: nn.Module) -> tuple[List[str], int]:
    """Resolve RDKit descriptor names and the model's expected desc width.

    Reads the input width of the model's "desc" projector and matches it
    against the installed RDKit ``descList``. If the counts disagree
    (e.g. a different RDKit version than the one used at training time),
    falls back to generic ``desc_i`` placeholder names.

    Returns:
        (feature_names, model_desc_dim)
    """
    from rdkit.Chem import Descriptors as RDKitDesc

    model_dim = model.token_projector.projectors["desc"][1].in_features
    names = [name for name, _ in RDKitDesc.descList]
    if len(names) == model_dim:
        return names, model_dim

    logger.warning(
        f"RDKit descList has {len(names)} names but model expects "
        f"{model_dim} features; using generic names."
    )
    return [f"desc_{i}" for i in range(model_dim)], model_dim
# ──────────────────────────────────────────────────────────────────────
# Method 1: Integrated Gradients (Captum)
# Token-level: Integrated Gradients
# ──────────────────────────────────────────────────────────────────────
class _ProjectedWrapper(nn.Module):
"""Wraps model.forward_from_projected so Captum can attribute w.r.t. stacked tokens."""
def __init__(self, model: nn.Module, task: str) -> None:
def __init__(self, model: nn.Module, organ_index: int) -> None:
super().__init__()
self.model = model
self.task = task
self.organ_index = organ_index
def forward(self, stacked: torch.Tensor) -> torch.Tensor:
"""
Args:
stacked: [B, n_tokens, d_model]
Returns:
[B] scalar predictions (sum over output dims for non-scalar heads)
"""
out = self.model.forward_from_projected(stacked, task=self.task)
if out.dim() > 1 and out.size(-1) > 1:
return out.sum(dim=-1)
return out.squeeze(-1)
out = self.model.forward_from_projected(stacked, task="biodist")
return out[:, self.organ_index]
def integrated_gradients_importance(
model: nn.Module,
all_tokens: torch.Tensor,
device: torch.device,
task: str = "delivery",
organ_index: int,
batch_size: int = 64,
n_steps: int = 50,
) -> np.ndarray:
"""
Compute token-level Integrated Gradients.
Returns:
importance: [n_tokens] averaged absolute attribution per token position
"""
"""Returns [n_tokens] averaged absolute attribution per token."""
from captum.attr import IntegratedGradients
wrapper = _ProjectedWrapper(model, task).to(device)
wrapper = _ProjectedWrapper(model, organ_index).to(device)
wrapper.eval()
ig = IntegratedGradients(wrapper)
N = all_tokens.size(0)
all_attrs: List[torch.Tensor] = []
for start in tqdm(range(0, N, batch_size), desc=f"IG ({task})"):
for start in tqdm(range(0, N, batch_size), desc=f"IG (organ={organ_index})"):
end = min(start + batch_size, N)
inp = all_tokens[start:end].to(device).requires_grad_(True)
baseline = torch.zeros_like(inp)
attr = ig.attribute(inp, baselines=baseline, n_steps=n_steps)
# attr: [B, n_tokens, d_model] → per-token L2 norm → [B, n_tokens]
token_attr = attr.detach().cpu().norm(dim=-1)
all_attrs.append(token_attr)
all_attrs.append(attr.detach().cpu().norm(dim=-1)) # [B, n_tokens]
attrs = torch.cat(all_attrs, dim=0) # [N, n_tokens]
importance = attrs.mean(dim=0).numpy() # [n_tokens]
return importance
return torch.cat(all_attrs, dim=0).mean(dim=0).numpy()
# ──────────────────────────────────────────────────────────────────────
# Method 2: Token Ablation (zero-out)
# Token-level: Ablation
# ──────────────────────────────────────────────────────────────────────
@torch.no_grad()
@ -142,98 +144,113 @@ def token_ablation_importance(
model: nn.Module,
all_tokens: torch.Tensor,
device: torch.device,
task: str = "delivery",
organ_index: int,
batch_size: int = 64,
) -> np.ndarray:
"""
For each token position, replace it with zeros and measure
the average absolute prediction change.
Returns:
importance: [n_tokens]
"""
"""Returns [n_tokens] importance via zero-out ablation."""
model.eval()
n_tokens = all_tokens.size(1)
N = all_tokens.size(0)
# original predictions
orig_preds = _batch_predict(model, all_tokens, device, task, batch_size)
orig_preds = _batch_predict(model, all_tokens, device, organ_index, batch_size)
importance = np.zeros(n_tokens)
for t in range(n_tokens):
ablated = all_tokens.clone()
ablated[:, t, :] = 0.0
abl_preds = _batch_predict(model, ablated, device, task, batch_size)
abl_preds = _batch_predict(model, ablated, device, organ_index, batch_size)
importance[t] = np.abs(orig_preds - abl_preds).mean()
return importance
def _batch_predict(
model: nn.Module,
all_tokens: torch.Tensor,
device: torch.device,
task: str,
batch_size: int,
model: nn.Module, all_tokens: torch.Tensor,
device: torch.device, organ_index: int, batch_size: int,
) -> np.ndarray:
"""Run forward_from_projected in batches, return [N] predictions."""
model.eval()
preds = []
N = all_tokens.size(0)
for start in range(0, N, batch_size):
end = min(start + batch_size, N)
inp = all_tokens[start:end].to(device)
out = model.forward_from_projected(inp, task=task)
if out.dim() > 1 and out.size(-1) > 1:
out = out.sum(dim=-1)
else:
out = out.squeeze(-1)
preds.append(out.cpu().numpy())
return np.concatenate(preds, axis=0)
for start in range(0, all_tokens.size(0), batch_size):
end = min(start + batch_size, all_tokens.size(0))
out = model.forward_from_projected(all_tokens[start:end].to(device), task="biodist")
preds.append(out[:, organ_index].cpu().numpy())
return np.concatenate(preds)
# ──────────────────────────────────────────────────────────────────────
# Method 3: Fusion Attention Weights
# Token-level: Fusion Attention Weights (task-agnostic)
# ──────────────────────────────────────────────────────────────────────
@torch.no_grad()
def fusion_attention_importance(
model: nn.Module,
all_tokens: torch.Tensor,
device: torch.device,
batch_size: int = 64,
model: nn.Module, all_tokens: torch.Tensor,
device: torch.device, batch_size: int = 64,
) -> Optional[np.ndarray]:
"""
Extract FusionLayer attention weights (only works if strategy="attention").
Returns:
importance: [n_tokens] averaged attention weights, or None if not applicable.
"""
model.eval()
if model.fusion.strategy != "attention":
logger.warning("Fusion strategy is not 'attention'; skipping attention weight extraction.")
logger.warning("Fusion strategy is not 'attention'; skipping.")
return None
N = all_tokens.size(0)
all_weights: List[torch.Tensor] = []
for start in range(0, N, batch_size):
end = min(start + batch_size, N)
inp = all_tokens[start:end].to(device)
attended = model.cross_attention(inp)
for start in range(0, all_tokens.size(0), batch_size):
end = min(start + batch_size, all_tokens.size(0))
attended = model.cross_attention(all_tokens[start:end].to(device))
_, weights = model.fusion(attended, return_attn_weights=True)
all_weights.append(weights.cpu())
weights = torch.cat(all_weights, dim=0) # [N, n_tokens]
return weights.mean(dim=0).numpy()
return torch.cat(all_weights, dim=0).mean(dim=0).numpy()
# ──────────────────────────────────────────────────────────────────────
# Bonus: TokenProjector gate values (static)
# Descriptor-level: Integrated Gradients
# ──────────────────────────────────────────────────────────────────────
class _ReplacingDescWrapper(nn.Module):
    """Captum wrapper attributing the biodist head w.r.t. raw desc features.

    Re-projects the raw "desc" features through the model's TokenProjector
    path, splices the result into pre-computed token projections, and
    returns the prediction for one biodist organ.
    """

    def __init__(self, model: nn.Module, base_projected: torch.Tensor, organ_index: int) -> None:
        super().__init__()
        self.model = model
        # [B, n_tokens, d_model] cached projections; the driver loop
        # re-assigns this per mini-batch before each attribute() call.
        self.base_projected = base_projected
        self.organ_index = organ_index

    def forward(self, raw_desc: torch.Tensor) -> torch.Tensor:
        n_in = raw_desc.size(0)
        base = self.base_projected.to(raw_desc.device)
        n_base = base.size(0)
        if n_in > n_base and n_in % n_base == 0:
            # BUGFIX: Captum's IntegratedGradients (internal_batch_size=None)
            # calls forward with n_steps * B interpolated rows laid out
            # step-major. The old `base[:n_in]` slice could only shrink the
            # cache, so the splice in forward_replacing_token crashed on a
            # batch-dim mismatch. Tile the cached projections to line the
            # batch dimensions up.
            base = base.repeat(n_in // n_base, 1, 1)
        elif n_in < n_base:
            base = base[:n_in]
        out = self.model.forward_replacing_token(raw_desc, "desc", base, task="biodist")
        return out[:, self.organ_index]
def descriptor_ig_importance(
    model: nn.Module,
    all_tokens: torch.Tensor,
    raw_desc: torch.Tensor,
    device: torch.device,
    organ_index: int,
    batch_size: int = 64,
    n_steps: int = 50,
) -> np.ndarray:
    """Returns [desc_dim] averaged absolute attribution per descriptor.

    Runs Integrated Gradients against the raw RDKit descriptor vector
    (zero baseline) for a single biodist organ, batching over samples.

    Args:
        model: trained model exposing ``forward_replacing_token``.
        all_tokens: [N, n_tokens, d_model] pre-projected tokens.
        raw_desc: [N, desc_dim] raw descriptor features (same row order
            as ``all_tokens``).
        device: device on which the wrapper and inputs are evaluated.
        organ_index: column of the biodist head to attribute.
        batch_size: samples per attribution batch.
        n_steps: IG interpolation steps.
    """
    from captum.attr import IntegratedGradients

    wrapper = _ReplacingDescWrapper(model, all_tokens, organ_index).to(device)
    wrapper.eval()
    ig = IntegratedGradients(wrapper)
    N = raw_desc.size(0)
    all_attrs: List[torch.Tensor] = []
    for start in tqdm(range(0, N, batch_size), desc=f"Desc-IG (organ={organ_index})"):
        end = min(start + batch_size, N)
        inp = raw_desc[start:end].to(device).requires_grad_(True)
        baseline = torch.zeros_like(inp)
        # Keep the cached projections aligned with this mini-batch; the
        # wrapper reads base_projected inside its forward.
        # NOTE(review): Captum expands the batch to n_steps * B internally
        # (when internal_batch_size is unset) — verify the wrapper handles
        # that expanded batch, or pass internal_batch_size=end - start.
        wrapper.base_projected = all_tokens[start:end]
        attr = ig.attribute(inp, baselines=baseline, n_steps=n_steps)
        # Mean |attribution| over samples → one score per descriptor.
        all_attrs.append(attr.detach().cpu().abs())
    return torch.cat(all_attrs, dim=0).mean(dim=0).numpy()
# ──────────────────────────────────────────────────────────────────────
# Gate values
# ──────────────────────────────────────────────────────────────────────
def gate_values(model: nn.Module) -> Dict[str, float]:
"""Read sigmoid(weight) from TokenProjector for each token."""
gates = {}
for key in model.token_projector.keys:
w = model.token_projector.weights[key].detach().cpu()
@ -242,7 +259,7 @@ def gate_values(model: nn.Module) -> Dict[str, float]:
# ──────────────────────────────────────────────────────────────────────
# Visualization
# Visualization & IO
# ──────────────────────────────────────────────────────────────────────
def normalize(arr: np.ndarray) -> np.ndarray:
@ -251,10 +268,8 @@ def normalize(arr: np.ndarray) -> np.ndarray:
def plot_token_importance(
results: Dict[str, np.ndarray],
token_names: List[str],
task: str,
out_dir: Path,
results: Dict[str, np.ndarray], token_names: List[str],
organ: str, out_dir: Path,
) -> Path:
import matplotlib
matplotlib.use("Agg")
@ -265,20 +280,18 @@ def plot_token_importance(
if n_methods == 1:
axes = [axes]
channel_a_color = "#5a448e"
channel_b_color = "#e07b39"
color_a, color_b = "#5a448e", "#e07b39"
for ax, (method_name, importance) in zip(axes, results.items()):
normed = normalize(importance)
order = np.argsort(normed)
names_sorted = [token_names[i] for i in order]
vals_sorted = normed[order]
n_tokens = len(token_names)
split_idx = 4 if n_tokens == 8 else 3
channel_a_set = set(token_names[:split_idx])
colors = [channel_a_color if n in channel_a_set else channel_b_color for n in names_sorted]
colors = [color_a if n in channel_a_set else color_b for n in names_sorted]
ax.barh(range(len(names_sorted)), vals_sorted, color=colors)
ax.set_yticks(range(len(names_sorted)))
@ -288,22 +301,20 @@ def plot_token_importance(
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
fig.suptitle(f"Token Importance — task: {task}", fontsize=14, y=1.02)
fig.suptitle(f"Token Importance — biodist: {organ}", fontsize=14, y=1.02)
fig.tight_layout()
out_dir.mkdir(parents=True, exist_ok=True)
fig_path = out_dir / f"token_importance_{task}.png"
fig_path = out_dir / f"token_importance_biodist_{organ}.png"
fig.savefig(fig_path, dpi=200, bbox_inches="tight")
plt.close(fig)
logger.info(f"Figure saved to {fig_path}")
return fig_path
def save_csv(
results: Dict[str, np.ndarray],
token_names: List[str],
task: str,
out_dir: Path,
def save_token_csv(
results: Dict[str, np.ndarray], token_names: List[str],
organ: str, out_dir: Path,
gate_vals: Optional[Dict[str, float]] = None,
) -> Path:
out_dir.mkdir(parents=True, exist_ok=True)
@ -312,38 +323,98 @@ def save_csv(
normed = normalize(importance)
df[f"{method_name}_raw"] = importance
df[f"{method_name}_normalized"] = normed
if gate_vals is not None:
df["gate_sigmoid"] = [gate_vals.get(t, float("nan")) for t in token_names]
df = df.sort_values(
by=[c for c in df.columns if c.endswith("_normalized")][-1],
ascending=False,
)
csv_path = out_dir / f"token_importance_{task}.csv"
csv_path = out_dir / f"token_importance_biodist_{organ}.csv"
df.to_csv(csv_path, index=False)
logger.info(f"CSV saved to {csv_path}")
return csv_path
def plot_descriptor_importance(
    importance: np.ndarray, feature_names: List[str],
    organ: str, out_dir: Path, top_k: int = 30,
) -> Path:
    """Render a horizontal bar chart of the top-K descriptor importances.

    Bars are drawn in ascending order so the highest-ranked feature sits
    at the top of the chart. Saves a PNG under ``out_dir`` and returns
    its path.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    scores = normalize(importance)
    # Rank highest-first, then flip so barh places the largest bar on top.
    ranked = np.argsort(-scores)[:top_k]
    shown = ranked[::-1]
    labels = [feature_names[i] for i in shown]
    heights = [scores[i] for i in shown]

    fig, ax = plt.subplots(figsize=(8, max(5, top_k * 0.35)))
    ax.barh(range(len(labels)), heights, color="#5a448e")
    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels, fontsize=9)
    ax.set_xlabel("Normalized Importance", fontsize=11)
    ax.set_title(
        f"Top-{top_k} desc Feature Importance — biodist: {organ}\n"
        f"(total features: {len(feature_names)})",
        fontsize=13,
    )
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    fig.tight_layout()

    out_dir.mkdir(parents=True, exist_ok=True)
    fig_path = out_dir / f"desc_importance_biodist_{organ}.png"
    fig.savefig(fig_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    logger.info(f"Descriptor figure saved to {fig_path}")
    return fig_path
def save_descriptor_csv(
    importance: np.ndarray, feature_names: List[str],
    organ: str, out_dir: Path,
) -> Path:
    """Write ranked descriptor importances (raw + normalized IG) to CSV.

    Rows are sorted by normalized importance, descending; the CSV index
    column is named "rank". Returns the path of the written file.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    table = pd.DataFrame(
        {
            "feature": feature_names,
            "ig_raw": importance,
            "ig_normalized": normalize(importance),
        }
    )
    table = table.sort_values("ig_normalized", ascending=False).reset_index(drop=True)
    table.index.name = "rank"

    csv_path = out_dir / f"desc_importance_biodist_{organ}.csv"
    table.to_csv(csv_path)
    logger.info(f"Descriptor CSV saved to {csv_path}")
    return csv_path
# ──────────────────────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────────────────────
def main() -> None:
parser = argparse.ArgumentParser(description="Token-level Feature Importance")
parser.add_argument("--model-path", type=str, default=str(MODELS_DIR / "final" / "model.pt"),
help="Path to trained model checkpoint")
parser.add_argument("--data-path", type=str, default=str(INTERIM_DATA_DIR / "internal.csv"),
help="Path to data (.csv or .parquet) for computing importance")
parser.add_argument("--task", type=str, default="all", choices=TASKS + ["all"],
help="Target task for importance computation ('all' to run on all tasks)")
parser = argparse.ArgumentParser(
description="Biodistribution feature importance (token-level & descriptor-level)")
parser.add_argument("--model-path", type=str,
default=str(MODELS_DIR / "final" / "model.pt"))
parser.add_argument("--data-path", type=str,
default=str(INTERIM_DATA_DIR / "internal.csv"))
parser.add_argument("--organ", type=str, default="all",
choices=BIODIST_ORGANS + ["all"],
help="Organ to analyze (default: all 7 organs)")
parser.add_argument("--method", type=str, default="ig",
choices=["ig", "ablation", "attention", "all"],
help="Which method(s) to run")
help="Token-level method(s)")
parser.add_argument("--desc-ig", action="store_true",
help="Also compute descriptor-level IG within the desc token")
parser.add_argument("--desc-top-k", type=int, default=30,
help="Top-K descriptors to show in plot")
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--n-steps", type=int, default=50,
help="Number of interpolation steps for Integrated Gradients")
help="Interpolation steps for Integrated Gradients")
parser.add_argument("--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu")
parser.add_argument("--output-dir", type=str,
@ -352,17 +423,21 @@ def main() -> None:
device = torch.device(args.device)
out_dir = Path(args.output_dir)
# ── Resolve organs ──
organs = BIODIST_ORGANS if args.organ == "all" else [args.organ]
organ_indices = {organ: BIODIST_ORGANS.index(organ) for organ in organs}
logger.info(f"Device: {device}")
logger.info(f"Model: {args.model_path}")
logger.info(f"Data: {args.data_path}")
logger.info(f"Task: {args.task}")
logger.info(f"Organs: {organs}")
# ── Load model ──
# ── Load model & data ──
model = load_model(Path(args.model_path), device)
token_names = get_token_names(model)
logger.info(f"Tokens ({len(token_names)}): {token_names}")
# ── Load data ──
data_path = Path(args.data_path)
if data_path.suffix == ".csv":
df = pd.read_csv(data_path)
@ -373,76 +448,78 @@ def main() -> None:
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
logger.info(f"Samples: {len(dataset)}")
# ── Pre-compute projected tokens ──
# ── Pre-compute ──
all_tokens = precompute_projected_tokens(model, loader, device)
logger.info(f"Projected tokens shape: {all_tokens.shape}")
# ── Gate values (always) ──
gv = gate_values(model)
logger.info(f"TokenProjector gate values: {gv}")
logger.info(f"Gate values: {gv}")
# ── Run selected methods ──
methods = (
["ig", "ablation", "attention"] if args.method == "all"
else [args.method]
)
methods = ["ig", "ablation", "attention"] if args.method == "all" else [args.method]
# Determine tasks to process
tasks_to_run = TASKS if args.task == "all" else [args.task]
# ── Pre-compute desc features if needed ──
raw_desc = None
desc_names = None
if args.desc_ig:
desc_names, desc_dim = get_desc_names_and_dim(model)
raw_desc = precompute_raw_desc(model, loader, device)
logger.info(f"Raw desc shape: {raw_desc.shape} (dim={desc_dim})")
assert raw_desc.size(1) == desc_dim
for task in tasks_to_run:
# ── Per-organ loop ──
for organ, organ_idx in organ_indices.items():
logger.info(f"\n{'#'*60}")
logger.info(f"# Processing task: {task}")
logger.info(f"# Organ: {organ} (index={organ_idx})")
logger.info(f"{'#'*60}")
# Token-level
results: Dict[str, np.ndarray] = {}
for method in methods:
logger.info(f"\n{'='*60}")
logger.info(f"Computing: {method} (task={task})")
logger.info(f"{'='*60}")
logger.info(f" Computing: {method}")
if method == "ig":
imp = integrated_gradients_importance(
model, all_tokens, device,
task=task, batch_size=args.batch_size, n_steps=args.n_steps,
results["Integrated Gradients"] = integrated_gradients_importance(
model, all_tokens, device, organ_idx,
batch_size=args.batch_size, n_steps=args.n_steps,
)
results["Integrated Gradients"] = imp
elif method == "ablation":
imp = token_ablation_importance(
model, all_tokens, device,
task=task, batch_size=args.batch_size,
)
results["Token Ablation"] = imp
elif method == "attention":
imp = fusion_attention_importance(
model, all_tokens, device,
results["Token Ablation"] = token_ablation_importance(
model, all_tokens, device, organ_idx,
batch_size=args.batch_size,
)
elif method == "attention":
imp = fusion_attention_importance(model, all_tokens, device, args.batch_size)
if imp is not None:
results["Fusion Attention"] = imp
# ── Print summary ──
logger.info(f"\n{'='*60}")
logger.info(f"Token Importance Summary (task={task})")
logger.info(f"{'='*60}")
# Print token summary
for method_name, importance in results.items():
normed = normalize(importance)
order = np.argsort(-normed)
logger.info(f"\n {method_name}:")
logger.info(f"\n {method_name} (biodist_{organ}):")
for rank, idx in enumerate(order, 1):
logger.info(f" {rank:>2d}. {token_names[idx]:<10s} {normed[idx]:.4f}")
logger.info(f"\n Gate values (sigmoid):")
for name, val in sorted(gv.items(), key=lambda x: -x[1]):
logger.info(f" {name:<10s} {val:.4f}")
# ── Save results ──
if results:
plot_token_importance(results, token_names, task, out_dir)
save_csv(results, token_names, task, out_dir, gate_vals=gv)
plot_token_importance(results, token_names, organ, out_dir)
save_token_csv(results, token_names, organ, out_dir, gate_vals=gv)
# Descriptor-level
if args.desc_ig and raw_desc is not None and desc_names is not None:
logger.info(f"\n Descriptor IG for {organ}...")
desc_imp = descriptor_ig_importance(
model, all_tokens, raw_desc, device, organ_idx,
batch_size=args.batch_size, n_steps=args.n_steps,
)
normed = normalize(desc_imp)
top_k = min(args.desc_top_k, len(desc_names))
order = np.argsort(-normed)
logger.info(f"\n Top-{top_k} desc features (biodist_{organ}):")
for rank, idx in enumerate(order[:top_k], 1):
logger.info(f" {rank:>3d}. {desc_names[idx]:<30s} {normed[idx]:.6f}")
desc_out_dir = out_dir / "desc"
plot_descriptor_importance(desc_imp, desc_names, organ, desc_out_dir, top_k)
save_descriptor_csv(desc_imp, desc_names, organ, desc_out_dir)
logger.info("\nDone!")

View File

@ -201,6 +201,39 @@ class LNPModel(nn.Module):
}
return task_heads[task](fused)
def forward_replacing_token(
    self,
    raw_feature: torch.Tensor,
    feature_key: str,
    base_projected: torch.Tensor,
    task: Optional[str] = None,
) -> torch.Tensor:
    """Replace one token's projection with a freshly projected raw feature, then forward.

    Used for Captum attribution on the internal features of a single
    token (e.g. the 210-dim RDKit descriptors inside the "desc" token):
    gradients flow from the task head back to ``raw_feature`` through
    the projector.

    Args:
        raw_feature: [B, input_dim] raw features for one token.
        feature_key: token name, e.g. "desc".
        base_projected: [B, n_tokens, d_model] already-projected tensor
            for all tokens.
        task: task name forwarded to ``forward_from_projected``.

    Returns:
        Prediction output of the selected task head.
    """
    # Re-run the projection + sigmoid gate for this one token, mirroring
    # what TokenProjector does on the normal forward path.
    projected = self.token_projector.projectors[feature_key](raw_feature)
    gate = torch.sigmoid(self.token_projector.weights[feature_key])
    projected = projected * gate  # [B, d_model]
    # Splice the recomputed token into a copy of the cached projections
    # (clone keeps the caller's tensor untouched).
    token_order = list(self.token_projector.keys)
    token_idx = token_order.index(feature_key)
    stacked = base_projected.clone()
    stacked[:, token_idx, :] = projected
    return self.forward_from_projected(stacked, task=task)
def forward_backbone(
self,
smiles: List[str],

View File

@ -1,9 +0,0 @@
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
desc,3.2013970842762367e-09,0.17906809140868585,0.5030443072319031
mpnn,3.109777150387161e-09,0.17394338920379962,0.5024935007095337
maccs,3.0657202874248063e-09,0.1714790968475434,0.5030479431152344
morgan,3.020539287877718e-09,0.16895192663283606,0.5045571327209473
help,1.937320535640997e-09,0.10836278088337024,0.49689680337905884
comp,1.876732953403087e-09,0.10497385335304418,0.5007365345954895
exp,1.6503372406931126e-09,0.09231055445232395,0.5002157688140869
phys,1.6274562664095572e-11,0.0009103072183966515,0.49989768862724304
1 token Integrated Gradients_raw Integrated Gradients_normalized gate_sigmoid
2 desc 3.2013970842762367e-09 0.17906809140868585 0.5030443072319031
3 mpnn 3.109777150387161e-09 0.17394338920379962 0.5024935007095337
4 maccs 3.0657202874248063e-09 0.1714790968475434 0.5030479431152344
5 morgan 3.020539287877718e-09 0.16895192663283606 0.5045571327209473
6 help 1.937320535640997e-09 0.10836278088337024 0.49689680337905884
7 comp 1.876732953403087e-09 0.10497385335304418 0.5007365345954895
8 exp 1.6503372406931126e-09 0.09231055445232395 0.5002157688140869
9 phys 1.6274562664095572e-11 0.0009103072183966515 0.49989768862724304

Binary file not shown.

Before

Width:  |  Height:  |  Size: 55 KiB

View File

@ -1,9 +0,0 @@
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
desc,0.48309860429978513,0.1830685579048154,0.5030443072319031
mpnn,0.4800125818662255,0.1818991202961257,0.5024935007095337
morgan,0.4681746999619525,0.17741319558101734,0.5045571327209473
maccs,0.4642216987644718,0.1759152193455701,0.5030479431152344
help,0.2636187040391749,0.09989740304701922,0.49689680337905884
comp,0.2503036292270154,0.0948517011498054,0.5007365345954895
exp,0.22730758172642437,0.08613742788142016,0.5002157688140869
phys,0.0021569658208921167,0.0008173747942266034,0.49989768862724304
1 token Integrated Gradients_raw Integrated Gradients_normalized gate_sigmoid
2 desc 0.48309860429978513 0.1830685579048154 0.5030443072319031
3 mpnn 0.4800125818662255 0.1818991202961257 0.5024935007095337
4 morgan 0.4681746999619525 0.17741319558101734 0.5045571327209473
5 maccs 0.4642216987644718 0.1759152193455701 0.5030479431152344
6 help 0.2636187040391749 0.09989740304701922 0.49689680337905884
7 comp 0.2503036292270154 0.0948517011498054 0.5007365345954895
8 exp 0.22730758172642437 0.08613742788142016 0.5002157688140869
9 phys 0.0021569658208921167 0.0008173747942266034 0.49989768862724304

Binary file not shown.

Before

Width:  |  Height:  |  Size: 56 KiB

View File

@ -1,9 +0,0 @@
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
desc,0.28693426746338735,0.2025623449841757,0.5030443072319031
mpnn,0.2846832493908164,0.2009732301551499,0.5024935007095337
morgan,0.2646961202937205,0.18686323982460912,0.5045571327209473
maccs,0.25845640337994125,0.1824582877731647,0.5030479431152344
help,0.11619691669315178,0.082029658337337,0.49689680337905884
comp,0.10812802919624785,0.07633339630758515,0.5007365345954895
exp,0.09640860187977682,0.06806002171178546,0.5002157688140869
phys,0.001019643036021533,0.000719820906192951,0.49989768862724304
1 token Integrated Gradients_raw Integrated Gradients_normalized gate_sigmoid
2 desc 0.28693426746338735 0.2025623449841757 0.5030443072319031
3 mpnn 0.2846832493908164 0.2009732301551499 0.5024935007095337
4 morgan 0.2646961202937205 0.18686323982460912 0.5045571327209473
5 maccs 0.25845640337994125 0.1824582877731647 0.5030479431152344
6 help 0.11619691669315178 0.082029658337337 0.49689680337905884
7 comp 0.10812802919624785 0.07633339630758515 0.5007365345954895
8 exp 0.09640860187977682 0.06806002171178546 0.5002157688140869
9 phys 0.001019643036021533 0.000719820906192951 0.49989768862724304

Binary file not shown.

Before

Width:  |  Height:  |  Size: 54 KiB

View File

@ -1,9 +0,0 @@
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
mpnn,0.29288250773521896,0.2140725741095052,0.5024935007095337
desc,0.2764989421166159,0.2020975603328644,0.5030443072319031
morgan,0.2578500855439716,0.18846680866532511,0.5045571327209473
maccs,0.2471777274240024,0.18066620905892686,0.5030479431152344
help,0.10363027887841153,0.07574505123823679,0.49689680337905884
comp,0.09781463069792998,0.07149429968008479,0.5007365345954895
exp,0.091429982122356,0.06682765650659303,0.5002157688140869
phys,0.0008617135523839594,0.0006298404084638948,0.49989768862724304
1 token Integrated Gradients_raw Integrated Gradients_normalized gate_sigmoid
2 mpnn 0.29288250773521896 0.2140725741095052 0.5024935007095337
3 desc 0.2764989421166159 0.2020975603328644 0.5030443072319031
4 morgan 0.2578500855439716 0.18846680866532511 0.5045571327209473
5 maccs 0.2471777274240024 0.18066620905892686 0.5030479431152344
6 help 0.10363027887841153 0.07574505123823679 0.49689680337905884
7 comp 0.09781463069792998 0.07149429968008479 0.5007365345954895
8 exp 0.091429982122356 0.06682765650659303 0.5002157688140869
9 phys 0.0008617135523839594 0.0006298404084638948 0.49989768862724304

Binary file not shown.

Before

Width:  |  Height:  |  Size: 54 KiB

View File

@ -1,9 +0,0 @@
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
desc,0.43124767207505,0.1973483310882186,0.5030443072319031
mpnn,0.4250195774970647,0.1944982193069518,0.5024935007095337
morgan,0.41677666295297405,0.19072608200879151,0.5045571327209473
maccs,0.40606126033119555,0.1858224802938626,0.5030479431152344
help,0.17471509961619794,0.07995343640757792,0.49689680337905884
comp,0.16962298878652332,0.07762317554120124,0.5007365345954895
exp,0.16035562746707094,0.07338222907722322,0.5002157688140869
phys,0.0014117471939899774,0.0006460462761730848,0.49989768862724304
1 token Integrated Gradients_raw Integrated Gradients_normalized gate_sigmoid
2 desc 0.43124767207505 0.1973483310882186 0.5030443072319031
3 mpnn 0.4250195774970647 0.1944982193069518 0.5024935007095337
4 morgan 0.41677666295297405 0.19072608200879151 0.5045571327209473
5 maccs 0.40606126033119555 0.1858224802938626 0.5030479431152344
6 help 0.17471509961619794 0.07995343640757792 0.49689680337905884
7 comp 0.16962298878652332 0.07762317554120124 0.5007365345954895
8 exp 0.16035562746707094 0.07338222907722322 0.5002157688140869
9 phys 0.0014117471939899774 0.0006460462761730848 0.49989768862724304

Binary file not shown.

Before

Width:  |  Height:  |  Size: 55 KiB

View File

@ -1,9 +0,0 @@
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
mpnn,0.17711989597355013,0.21006733816477163,0.5024935007095337
desc,0.17162019114321028,0.20354459068882486,0.5030443072319031
morgan,0.16095514462538474,0.19089565635488848,0.5045571327209473
maccs,0.15742516053846328,0.18670903261717847,0.5030479431152344
comp,0.05954181250470194,0.07061764571178654,0.5007365345954895
help,0.05892528920918569,0.06988643814815398,0.49689680337905884
exp,0.05708834082906645,0.06770778478774685,0.5002157688140869
phys,0.0004818760368551862,0.0005715135266490605,0.49989768862724304
1 token Integrated Gradients_raw Integrated Gradients_normalized gate_sigmoid
2 mpnn 0.17711989597355013 0.21006733816477163 0.5024935007095337
3 desc 0.17162019114321028 0.20354459068882486 0.5030443072319031
4 morgan 0.16095514462538474 0.19089565635488848 0.5045571327209473
5 maccs 0.15742516053846328 0.18670903261717847 0.5030479431152344
6 comp 0.05954181250470194 0.07061764571178654 0.5007365345954895
7 help 0.05892528920918569 0.06988643814815398 0.49689680337905884
8 exp 0.05708834082906645 0.06770778478774685 0.5002157688140869
9 phys 0.0004818760368551862 0.0005715135266490605 0.49989768862724304

Binary file not shown.

Before

Width:  |  Height:  |  Size: 55 KiB