mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 01:27:00 +08:00
移除无关任务
This commit is contained in:
parent
3b38727053
commit
ac5f598484
27
Makefile
27
Makefile
@ -123,25 +123,36 @@ train: requirements
|
||||
$(N_TRIALS_FLAG) $(EPOCHS_PER_TRIAL_FLAG) $(MIN_STRATUM_FLAG) $(OUTPUT_DIR_FLAG) $(USE_SWA_FLAG)
|
||||
|
||||
#################################################################################
|
||||
# INTERPRETABILITY #
|
||||
# INTERPRETABILITY (biodistribution feature importance) #
|
||||
#################################################################################
|
||||
# 参数:
|
||||
# TASK 目标任务 (delivery, size, pdi, ee, biodist, toxic, all; 默认: delivery)
|
||||
# 如果指定 'all',将依次计算所有 6 个任务
|
||||
# METHOD 方法 (ig, ablation, attention, all; 默认: ig)
|
||||
# DATA 数据路径 (默认: data/interim/internal.csv,即最终模型的全量训练数据)
|
||||
# ORGAN 器官 (lymph_nodes, heart, liver, spleen, lung, kidney, muscle, all; 默认: all)
|
||||
# METHOD token 级方法 (ig, ablation, attention, all; 默认: ig)
|
||||
# DESC_IG 同时计算 desc 内部特征 IG (1 启用; 默认不启用)
|
||||
# DESC_TOP_K 可视化展示的 top-K 特征数 (默认: 30)
|
||||
# DATA 数据路径 (默认: data/interim/internal.csv)
|
||||
# MODEL 模型路径 (默认: models/final/model.pt)
|
||||
|
||||
TASK_FLAG = $(if $(TASK),--task $(TASK),)
|
||||
METHOD_FLAG = $(if $(METHOD),--method $(METHOD),)
|
||||
DATA_FLAG = $(if $(DATA),--data-path $(DATA),)
|
||||
MODEL_FLAG = $(if $(MODEL),--model-path $(MODEL),)
|
||||
ORGAN_FLAG = $(if $(ORGAN),--organ $(ORGAN),)
|
||||
DESC_IG_FLAG = $(if $(DESC_IG),--desc-ig,)
|
||||
DESC_TOP_K_FLAG = $(if $(DESC_TOP_K),--desc-top-k $(DESC_TOP_K),)
|
||||
|
||||
## Compute token-level feature importance
|
||||
## Compute biodistribution feature importance (token-level)
|
||||
.PHONY: feature_importance
|
||||
feature_importance: requirements
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.interpretability.token_importance \
|
||||
$(TASK_FLAG) $(METHOD_FLAG) $(DATA_FLAG) $(MODEL_FLAG) $(DEVICE_FLAG)
|
||||
$(ORGAN_FLAG) $(METHOD_FLAG) $(DATA_FLAG) $(MODEL_FLAG) $(DEVICE_FLAG) \
|
||||
$(DESC_IG_FLAG) $(DESC_TOP_K_FLAG)
|
||||
|
||||
## Compute biodistribution feature importance (token + descriptor-level)
|
||||
.PHONY: desc_importance
|
||||
desc_importance: requirements
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.interpretability.token_importance \
|
||||
--desc-ig $(ORGAN_FLAG) $(METHOD_FLAG) $(DATA_FLAG) $(MODEL_FLAG) \
|
||||
$(DEVICE_FLAG) $(DESC_TOP_K_FLAG)
|
||||
|
||||
#################################################################################
|
||||
# SERVING & DEPLOYMENT #
|
||||
|
||||
@ -1,25 +1,26 @@
|
||||
"""
|
||||
Token-level feature importance for the LNP multi-task model.
|
||||
Biodistribution feature importance for the LNP multi-task model.
|
||||
|
||||
Three complementary methods (inspired by AGILE / Captum):
|
||||
1. Integrated Gradients — gradient-based attribution via Captum
|
||||
2. Token Ablation — zero-out each token and measure prediction change
|
||||
3. Fusion Attention Weights — extract learned attention pooling weights
|
||||
Two levels of analysis, both targeting the biodist head (per organ):
|
||||
1. Token-level — which of the 8 input tokens matters most
|
||||
2. Descriptor-level — which RDKit descriptors inside the "desc" token matter most
|
||||
|
||||
Token-level methods:
|
||||
- Integrated Gradients (Captum)
|
||||
- Token Ablation (zero-out)
|
||||
- Fusion Attention Weights
|
||||
|
||||
Usage:
|
||||
python -m lnp_ml.interpretability.token_importance \
|
||||
--model-path models/model.pt \
|
||||
--data-path data/processed/train.parquet \
|
||||
--task delivery \
|
||||
--method all
|
||||
python -m lnp_ml.interpretability.token_importance # all organs, token-level IG
|
||||
python -m lnp_ml.interpretability.token_importance --organ liver # liver only
|
||||
python -m lnp_ml.interpretability.token_importance --desc-ig # + desc-level IG
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -35,7 +36,7 @@ from lnp_ml.modeling.predict import load_model
|
||||
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
|
||||
|
||||
|
||||
TASKS = ["delivery", "size", "pdi", "ee", "biodist", "toxic"]
|
||||
BIODIST_ORGANS = ["lymph_nodes", "heart", "liver", "spleen", "lung", "kidney", "muscle"]
|
||||
|
||||
|
||||
def get_token_names(model: Union[LNPModel, LNPModelWithoutMPNN]) -> List[str]:
|
||||
@ -45,96 +46,97 @@ def get_token_names(model: Union[LNPModel, LNPModelWithoutMPNN]) -> List[str]:
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Helper: pre-compute projected tokens for the whole dataset
|
||||
# Pre-compute
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
@torch.no_grad()
|
||||
def precompute_projected_tokens(
|
||||
model: nn.Module,
|
||||
loader: DataLoader,
|
||||
device: torch.device,
|
||||
model: nn.Module, loader: DataLoader, device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Run encoder + TokenProjector on all samples and stack results.
|
||||
|
||||
Returns:
|
||||
all_tokens: [N, n_tokens, d_model]
|
||||
"""
|
||||
"""Returns [N, n_tokens, d_model]."""
|
||||
model.eval()
|
||||
chunks = []
|
||||
for batch in tqdm(loader, desc="Encoding tokens"):
|
||||
smiles = batch["smiles"]
|
||||
tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
|
||||
stacked = model._encode_and_project(smiles, tabular)
|
||||
chunks.append(stacked.cpu())
|
||||
chunks.append(model._encode_and_project(smiles, tabular).cpu())
|
||||
return torch.cat(chunks, dim=0)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def precompute_raw_desc(
|
||||
model: nn.Module, loader: DataLoader, device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
"""Returns [N, desc_dim] raw RDKit descriptor features."""
|
||||
model.eval()
|
||||
chunks = []
|
||||
for batch in tqdm(loader, desc="Encoding raw desc"):
|
||||
smiles = batch["smiles"]
|
||||
rdkit_features = model.rdkit_encoder(smiles)
|
||||
target_device = next(model.parameters()).device
|
||||
chunks.append(rdkit_features["desc"].to(target_device).cpu())
|
||||
return torch.cat(chunks, dim=0)
|
||||
|
||||
|
||||
def get_desc_names_and_dim(model: nn.Module) -> tuple[List[str], int]:
|
||||
"""Get descriptor names from RDKit and verify against model dim."""
|
||||
from rdkit.Chem import Descriptors as RDKitDesc
|
||||
names = [name for name, _ in RDKitDesc.descList]
|
||||
model_dim = model.token_projector.projectors["desc"][1].in_features
|
||||
if len(names) != model_dim:
|
||||
logger.warning(
|
||||
f"RDKit descList has {len(names)} names but model expects "
|
||||
f"{model_dim} features; using generic names."
|
||||
)
|
||||
names = [f"desc_{i}" for i in range(model_dim)]
|
||||
return names, model_dim
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Method 1: Integrated Gradients (Captum)
|
||||
# Token-level: Integrated Gradients
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
class _ProjectedWrapper(nn.Module):
|
||||
"""Wraps model.forward_from_projected so Captum can attribute w.r.t. stacked tokens."""
|
||||
|
||||
def __init__(self, model: nn.Module, task: str) -> None:
|
||||
def __init__(self, model: nn.Module, organ_index: int) -> None:
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.task = task
|
||||
self.organ_index = organ_index
|
||||
|
||||
def forward(self, stacked: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
stacked: [B, n_tokens, d_model]
|
||||
Returns:
|
||||
[B] scalar predictions (sum over output dims for non-scalar heads)
|
||||
"""
|
||||
out = self.model.forward_from_projected(stacked, task=self.task)
|
||||
if out.dim() > 1 and out.size(-1) > 1:
|
||||
return out.sum(dim=-1)
|
||||
return out.squeeze(-1)
|
||||
out = self.model.forward_from_projected(stacked, task="biodist")
|
||||
return out[:, self.organ_index]
|
||||
|
||||
|
||||
def integrated_gradients_importance(
|
||||
model: nn.Module,
|
||||
all_tokens: torch.Tensor,
|
||||
device: torch.device,
|
||||
task: str = "delivery",
|
||||
organ_index: int,
|
||||
batch_size: int = 64,
|
||||
n_steps: int = 50,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute token-level Integrated Gradients.
|
||||
|
||||
Returns:
|
||||
importance: [n_tokens] averaged absolute attribution per token position
|
||||
"""
|
||||
"""Returns [n_tokens] averaged absolute attribution per token."""
|
||||
from captum.attr import IntegratedGradients
|
||||
|
||||
wrapper = _ProjectedWrapper(model, task).to(device)
|
||||
wrapper = _ProjectedWrapper(model, organ_index).to(device)
|
||||
wrapper.eval()
|
||||
ig = IntegratedGradients(wrapper)
|
||||
|
||||
N = all_tokens.size(0)
|
||||
all_attrs: List[torch.Tensor] = []
|
||||
|
||||
for start in tqdm(range(0, N, batch_size), desc=f"IG ({task})"):
|
||||
for start in tqdm(range(0, N, batch_size), desc=f"IG (organ={organ_index})"):
|
||||
end = min(start + batch_size, N)
|
||||
inp = all_tokens[start:end].to(device).requires_grad_(True)
|
||||
baseline = torch.zeros_like(inp)
|
||||
|
||||
attr = ig.attribute(inp, baselines=baseline, n_steps=n_steps)
|
||||
# attr: [B, n_tokens, d_model] → per-token L2 norm → [B, n_tokens]
|
||||
token_attr = attr.detach().cpu().norm(dim=-1)
|
||||
all_attrs.append(token_attr)
|
||||
all_attrs.append(attr.detach().cpu().norm(dim=-1)) # [B, n_tokens]
|
||||
|
||||
attrs = torch.cat(all_attrs, dim=0) # [N, n_tokens]
|
||||
importance = attrs.mean(dim=0).numpy() # [n_tokens]
|
||||
return importance
|
||||
return torch.cat(all_attrs, dim=0).mean(dim=0).numpy()
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Method 2: Token Ablation (zero-out)
|
||||
# Token-level: Ablation
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
@torch.no_grad()
|
||||
@ -142,98 +144,113 @@ def token_ablation_importance(
|
||||
model: nn.Module,
|
||||
all_tokens: torch.Tensor,
|
||||
device: torch.device,
|
||||
task: str = "delivery",
|
||||
organ_index: int,
|
||||
batch_size: int = 64,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
For each token position, replace it with zeros and measure
|
||||
the average absolute prediction change.
|
||||
|
||||
Returns:
|
||||
importance: [n_tokens]
|
||||
"""
|
||||
"""Returns [n_tokens] importance via zero-out ablation."""
|
||||
model.eval()
|
||||
n_tokens = all_tokens.size(1)
|
||||
N = all_tokens.size(0)
|
||||
|
||||
# original predictions
|
||||
orig_preds = _batch_predict(model, all_tokens, device, task, batch_size)
|
||||
orig_preds = _batch_predict(model, all_tokens, device, organ_index, batch_size)
|
||||
|
||||
importance = np.zeros(n_tokens)
|
||||
for t in range(n_tokens):
|
||||
ablated = all_tokens.clone()
|
||||
ablated[:, t, :] = 0.0
|
||||
abl_preds = _batch_predict(model, ablated, device, task, batch_size)
|
||||
abl_preds = _batch_predict(model, ablated, device, organ_index, batch_size)
|
||||
importance[t] = np.abs(orig_preds - abl_preds).mean()
|
||||
|
||||
return importance
|
||||
|
||||
|
||||
def _batch_predict(
|
||||
model: nn.Module,
|
||||
all_tokens: torch.Tensor,
|
||||
device: torch.device,
|
||||
task: str,
|
||||
batch_size: int,
|
||||
model: nn.Module, all_tokens: torch.Tensor,
|
||||
device: torch.device, organ_index: int, batch_size: int,
|
||||
) -> np.ndarray:
|
||||
"""Run forward_from_projected in batches, return [N] predictions."""
|
||||
model.eval()
|
||||
preds = []
|
||||
N = all_tokens.size(0)
|
||||
for start in range(0, N, batch_size):
|
||||
end = min(start + batch_size, N)
|
||||
inp = all_tokens[start:end].to(device)
|
||||
out = model.forward_from_projected(inp, task=task)
|
||||
if out.dim() > 1 and out.size(-1) > 1:
|
||||
out = out.sum(dim=-1)
|
||||
else:
|
||||
out = out.squeeze(-1)
|
||||
preds.append(out.cpu().numpy())
|
||||
return np.concatenate(preds, axis=0)
|
||||
for start in range(0, all_tokens.size(0), batch_size):
|
||||
end = min(start + batch_size, all_tokens.size(0))
|
||||
out = model.forward_from_projected(all_tokens[start:end].to(device), task="biodist")
|
||||
preds.append(out[:, organ_index].cpu().numpy())
|
||||
return np.concatenate(preds)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Method 3: Fusion Attention Weights
|
||||
# Token-level: Fusion Attention Weights (task-agnostic)
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
@torch.no_grad()
|
||||
def fusion_attention_importance(
|
||||
model: nn.Module,
|
||||
all_tokens: torch.Tensor,
|
||||
device: torch.device,
|
||||
batch_size: int = 64,
|
||||
model: nn.Module, all_tokens: torch.Tensor,
|
||||
device: torch.device, batch_size: int = 64,
|
||||
) -> Optional[np.ndarray]:
|
||||
"""
|
||||
Extract FusionLayer attention weights (only works if strategy="attention").
|
||||
|
||||
Returns:
|
||||
importance: [n_tokens] averaged attention weights, or None if not applicable.
|
||||
"""
|
||||
model.eval()
|
||||
if model.fusion.strategy != "attention":
|
||||
logger.warning("Fusion strategy is not 'attention'; skipping attention weight extraction.")
|
||||
logger.warning("Fusion strategy is not 'attention'; skipping.")
|
||||
return None
|
||||
|
||||
N = all_tokens.size(0)
|
||||
all_weights: List[torch.Tensor] = []
|
||||
|
||||
for start in range(0, N, batch_size):
|
||||
end = min(start + batch_size, N)
|
||||
inp = all_tokens[start:end].to(device)
|
||||
attended = model.cross_attention(inp)
|
||||
for start in range(0, all_tokens.size(0), batch_size):
|
||||
end = min(start + batch_size, all_tokens.size(0))
|
||||
attended = model.cross_attention(all_tokens[start:end].to(device))
|
||||
_, weights = model.fusion(attended, return_attn_weights=True)
|
||||
all_weights.append(weights.cpu())
|
||||
|
||||
weights = torch.cat(all_weights, dim=0) # [N, n_tokens]
|
||||
return weights.mean(dim=0).numpy()
|
||||
return torch.cat(all_weights, dim=0).mean(dim=0).numpy()
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Bonus: TokenProjector gate values (static)
|
||||
# Descriptor-level: Integrated Gradients
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
class _ReplacingDescWrapper(nn.Module):
|
||||
def __init__(self, model: nn.Module, base_projected: torch.Tensor, organ_index: int) -> None:
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.base_projected = base_projected
|
||||
self.organ_index = organ_index
|
||||
|
||||
def forward(self, raw_desc: torch.Tensor) -> torch.Tensor:
|
||||
base = self.base_projected[:raw_desc.size(0)].to(raw_desc.device)
|
||||
out = self.model.forward_replacing_token(raw_desc, "desc", base, task="biodist")
|
||||
return out[:, self.organ_index]
|
||||
|
||||
|
||||
def descriptor_ig_importance(
|
||||
model: nn.Module,
|
||||
all_tokens: torch.Tensor,
|
||||
raw_desc: torch.Tensor,
|
||||
device: torch.device,
|
||||
organ_index: int,
|
||||
batch_size: int = 64,
|
||||
n_steps: int = 50,
|
||||
) -> np.ndarray:
|
||||
"""Returns [desc_dim] averaged absolute attribution per descriptor."""
|
||||
from captum.attr import IntegratedGradients
|
||||
|
||||
wrapper = _ReplacingDescWrapper(model, all_tokens, organ_index).to(device)
|
||||
wrapper.eval()
|
||||
ig = IntegratedGradients(wrapper)
|
||||
|
||||
N = raw_desc.size(0)
|
||||
all_attrs: List[torch.Tensor] = []
|
||||
|
||||
for start in tqdm(range(0, N, batch_size), desc=f"Desc-IG (organ={organ_index})"):
|
||||
end = min(start + batch_size, N)
|
||||
inp = raw_desc[start:end].to(device).requires_grad_(True)
|
||||
baseline = torch.zeros_like(inp)
|
||||
wrapper.base_projected = all_tokens[start:end]
|
||||
attr = ig.attribute(inp, baselines=baseline, n_steps=n_steps)
|
||||
all_attrs.append(attr.detach().cpu().abs())
|
||||
|
||||
return torch.cat(all_attrs, dim=0).mean(dim=0).numpy()
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Gate values
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def gate_values(model: nn.Module) -> Dict[str, float]:
|
||||
"""Read sigmoid(weight) from TokenProjector for each token."""
|
||||
gates = {}
|
||||
for key in model.token_projector.keys:
|
||||
w = model.token_projector.weights[key].detach().cpu()
|
||||
@ -242,7 +259,7 @@ def gate_values(model: nn.Module) -> Dict[str, float]:
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Visualization
|
||||
# Visualization & IO
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def normalize(arr: np.ndarray) -> np.ndarray:
|
||||
@ -251,10 +268,8 @@ def normalize(arr: np.ndarray) -> np.ndarray:
|
||||
|
||||
|
||||
def plot_token_importance(
|
||||
results: Dict[str, np.ndarray],
|
||||
token_names: List[str],
|
||||
task: str,
|
||||
out_dir: Path,
|
||||
results: Dict[str, np.ndarray], token_names: List[str],
|
||||
organ: str, out_dir: Path,
|
||||
) -> Path:
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
@ -265,20 +280,18 @@ def plot_token_importance(
|
||||
if n_methods == 1:
|
||||
axes = [axes]
|
||||
|
||||
channel_a_color = "#5a448e"
|
||||
channel_b_color = "#e07b39"
|
||||
color_a, color_b = "#5a448e", "#e07b39"
|
||||
|
||||
for ax, (method_name, importance) in zip(axes, results.items()):
|
||||
normed = normalize(importance)
|
||||
order = np.argsort(normed)
|
||||
|
||||
names_sorted = [token_names[i] for i in order]
|
||||
vals_sorted = normed[order]
|
||||
|
||||
n_tokens = len(token_names)
|
||||
split_idx = 4 if n_tokens == 8 else 3
|
||||
channel_a_set = set(token_names[:split_idx])
|
||||
colors = [channel_a_color if n in channel_a_set else channel_b_color for n in names_sorted]
|
||||
colors = [color_a if n in channel_a_set else color_b for n in names_sorted]
|
||||
|
||||
ax.barh(range(len(names_sorted)), vals_sorted, color=colors)
|
||||
ax.set_yticks(range(len(names_sorted)))
|
||||
@ -288,22 +301,20 @@ def plot_token_importance(
|
||||
ax.spines["top"].set_visible(False)
|
||||
ax.spines["right"].set_visible(False)
|
||||
|
||||
fig.suptitle(f"Token Importance — task: {task}", fontsize=14, y=1.02)
|
||||
fig.suptitle(f"Token Importance — biodist: {organ}", fontsize=14, y=1.02)
|
||||
fig.tight_layout()
|
||||
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
fig_path = out_dir / f"token_importance_{task}.png"
|
||||
fig_path = out_dir / f"token_importance_biodist_{organ}.png"
|
||||
fig.savefig(fig_path, dpi=200, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
logger.info(f"Figure saved to {fig_path}")
|
||||
return fig_path
|
||||
|
||||
|
||||
def save_csv(
|
||||
results: Dict[str, np.ndarray],
|
||||
token_names: List[str],
|
||||
task: str,
|
||||
out_dir: Path,
|
||||
def save_token_csv(
|
||||
results: Dict[str, np.ndarray], token_names: List[str],
|
||||
organ: str, out_dir: Path,
|
||||
gate_vals: Optional[Dict[str, float]] = None,
|
||||
) -> Path:
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
@ -312,38 +323,98 @@ def save_csv(
|
||||
normed = normalize(importance)
|
||||
df[f"{method_name}_raw"] = importance
|
||||
df[f"{method_name}_normalized"] = normed
|
||||
|
||||
if gate_vals is not None:
|
||||
df["gate_sigmoid"] = [gate_vals.get(t, float("nan")) for t in token_names]
|
||||
|
||||
df = df.sort_values(
|
||||
by=[c for c in df.columns if c.endswith("_normalized")][-1],
|
||||
ascending=False,
|
||||
)
|
||||
csv_path = out_dir / f"token_importance_{task}.csv"
|
||||
csv_path = out_dir / f"token_importance_biodist_{organ}.csv"
|
||||
df.to_csv(csv_path, index=False)
|
||||
logger.info(f"CSV saved to {csv_path}")
|
||||
return csv_path
|
||||
|
||||
|
||||
def plot_descriptor_importance(
|
||||
importance: np.ndarray, feature_names: List[str],
|
||||
organ: str, out_dir: Path, top_k: int = 30,
|
||||
) -> Path:
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
normed = normalize(importance)
|
||||
order = np.argsort(-normed)
|
||||
top_indices = order[:top_k]
|
||||
|
||||
names = [feature_names[i] for i in reversed(top_indices)]
|
||||
vals = [normed[i] for i in reversed(top_indices)]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, max(5, top_k * 0.35)))
|
||||
ax.barh(range(len(names)), vals, color="#5a448e")
|
||||
ax.set_yticks(range(len(names)))
|
||||
ax.set_yticklabels(names, fontsize=9)
|
||||
ax.set_xlabel("Normalized Importance", fontsize=11)
|
||||
ax.set_title(
|
||||
f"Top-{top_k} desc Feature Importance — biodist: {organ}\n"
|
||||
f"(total features: {len(feature_names)})",
|
||||
fontsize=13,
|
||||
)
|
||||
ax.spines["top"].set_visible(False)
|
||||
ax.spines["right"].set_visible(False)
|
||||
fig.tight_layout()
|
||||
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
fig_path = out_dir / f"desc_importance_biodist_{organ}.png"
|
||||
fig.savefig(fig_path, dpi=200, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
logger.info(f"Descriptor figure saved to {fig_path}")
|
||||
return fig_path
|
||||
|
||||
|
||||
def save_descriptor_csv(
|
||||
importance: np.ndarray, feature_names: List[str],
|
||||
organ: str, out_dir: Path,
|
||||
) -> Path:
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
normed = normalize(importance)
|
||||
df = pd.DataFrame({
|
||||
"feature": feature_names,
|
||||
"ig_raw": importance,
|
||||
"ig_normalized": normed,
|
||||
})
|
||||
df = df.sort_values("ig_normalized", ascending=False).reset_index(drop=True)
|
||||
df.index.name = "rank"
|
||||
csv_path = out_dir / f"desc_importance_biodist_{organ}.csv"
|
||||
df.to_csv(csv_path)
|
||||
logger.info(f"Descriptor CSV saved to {csv_path}")
|
||||
return csv_path
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Main
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Token-level Feature Importance")
|
||||
parser.add_argument("--model-path", type=str, default=str(MODELS_DIR / "final" / "model.pt"),
|
||||
help="Path to trained model checkpoint")
|
||||
parser.add_argument("--data-path", type=str, default=str(INTERIM_DATA_DIR / "internal.csv"),
|
||||
help="Path to data (.csv or .parquet) for computing importance")
|
||||
parser.add_argument("--task", type=str, default="all", choices=TASKS + ["all"],
|
||||
help="Target task for importance computation ('all' to run on all tasks)")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Biodistribution feature importance (token-level & descriptor-level)")
|
||||
parser.add_argument("--model-path", type=str,
|
||||
default=str(MODELS_DIR / "final" / "model.pt"))
|
||||
parser.add_argument("--data-path", type=str,
|
||||
default=str(INTERIM_DATA_DIR / "internal.csv"))
|
||||
parser.add_argument("--organ", type=str, default="all",
|
||||
choices=BIODIST_ORGANS + ["all"],
|
||||
help="Organ to analyze (default: all 7 organs)")
|
||||
parser.add_argument("--method", type=str, default="ig",
|
||||
choices=["ig", "ablation", "attention", "all"],
|
||||
help="Which method(s) to run")
|
||||
help="Token-level method(s)")
|
||||
parser.add_argument("--desc-ig", action="store_true",
|
||||
help="Also compute descriptor-level IG within the desc token")
|
||||
parser.add_argument("--desc-top-k", type=int, default=30,
|
||||
help="Top-K descriptors to show in plot")
|
||||
parser.add_argument("--batch-size", type=int, default=64)
|
||||
parser.add_argument("--n-steps", type=int, default=50,
|
||||
help="Number of interpolation steps for Integrated Gradients")
|
||||
help="Interpolation steps for Integrated Gradients")
|
||||
parser.add_argument("--device", type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu")
|
||||
parser.add_argument("--output-dir", type=str,
|
||||
@ -352,17 +423,21 @@ def main() -> None:
|
||||
|
||||
device = torch.device(args.device)
|
||||
out_dir = Path(args.output_dir)
|
||||
|
||||
# ── Resolve organs ──
|
||||
organs = BIODIST_ORGANS if args.organ == "all" else [args.organ]
|
||||
organ_indices = {organ: BIODIST_ORGANS.index(organ) for organ in organs}
|
||||
|
||||
logger.info(f"Device: {device}")
|
||||
logger.info(f"Model: {args.model_path}")
|
||||
logger.info(f"Data: {args.data_path}")
|
||||
logger.info(f"Task: {args.task}")
|
||||
logger.info(f"Organs: {organs}")
|
||||
|
||||
# ── Load model ──
|
||||
# ── Load model & data ──
|
||||
model = load_model(Path(args.model_path), device)
|
||||
token_names = get_token_names(model)
|
||||
logger.info(f"Tokens ({len(token_names)}): {token_names}")
|
||||
|
||||
# ── Load data ──
|
||||
data_path = Path(args.data_path)
|
||||
if data_path.suffix == ".csv":
|
||||
df = pd.read_csv(data_path)
|
||||
@ -373,76 +448,78 @@ def main() -> None:
|
||||
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
|
||||
logger.info(f"Samples: {len(dataset)}")
|
||||
|
||||
# ── Pre-compute projected tokens ──
|
||||
# ── Pre-compute ──
|
||||
all_tokens = precompute_projected_tokens(model, loader, device)
|
||||
logger.info(f"Projected tokens shape: {all_tokens.shape}")
|
||||
|
||||
# ── Gate values (always) ──
|
||||
gv = gate_values(model)
|
||||
logger.info(f"TokenProjector gate values: {gv}")
|
||||
logger.info(f"Gate values: {gv}")
|
||||
|
||||
# ── Run selected methods ──
|
||||
methods = (
|
||||
["ig", "ablation", "attention"] if args.method == "all"
|
||||
else [args.method]
|
||||
)
|
||||
methods = ["ig", "ablation", "attention"] if args.method == "all" else [args.method]
|
||||
|
||||
# Determine tasks to process
|
||||
tasks_to_run = TASKS if args.task == "all" else [args.task]
|
||||
# ── Pre-compute desc features if needed ──
|
||||
raw_desc = None
|
||||
desc_names = None
|
||||
if args.desc_ig:
|
||||
desc_names, desc_dim = get_desc_names_and_dim(model)
|
||||
raw_desc = precompute_raw_desc(model, loader, device)
|
||||
logger.info(f"Raw desc shape: {raw_desc.shape} (dim={desc_dim})")
|
||||
assert raw_desc.size(1) == desc_dim
|
||||
|
||||
for task in tasks_to_run:
|
||||
# ── Per-organ loop ──
|
||||
for organ, organ_idx in organ_indices.items():
|
||||
logger.info(f"\n{'#'*60}")
|
||||
logger.info(f"# Processing task: {task}")
|
||||
logger.info(f"# Organ: {organ} (index={organ_idx})")
|
||||
logger.info(f"{'#'*60}")
|
||||
|
||||
# Token-level
|
||||
results: Dict[str, np.ndarray] = {}
|
||||
|
||||
for method in methods:
|
||||
logger.info(f"\n{'='*60}")
|
||||
logger.info(f"Computing: {method} (task={task})")
|
||||
logger.info(f"{'='*60}")
|
||||
|
||||
logger.info(f" Computing: {method}")
|
||||
if method == "ig":
|
||||
imp = integrated_gradients_importance(
|
||||
model, all_tokens, device,
|
||||
task=task, batch_size=args.batch_size, n_steps=args.n_steps,
|
||||
results["Integrated Gradients"] = integrated_gradients_importance(
|
||||
model, all_tokens, device, organ_idx,
|
||||
batch_size=args.batch_size, n_steps=args.n_steps,
|
||||
)
|
||||
results["Integrated Gradients"] = imp
|
||||
|
||||
elif method == "ablation":
|
||||
imp = token_ablation_importance(
|
||||
model, all_tokens, device,
|
||||
task=task, batch_size=args.batch_size,
|
||||
)
|
||||
results["Token Ablation"] = imp
|
||||
|
||||
elif method == "attention":
|
||||
imp = fusion_attention_importance(
|
||||
model, all_tokens, device,
|
||||
results["Token Ablation"] = token_ablation_importance(
|
||||
model, all_tokens, device, organ_idx,
|
||||
batch_size=args.batch_size,
|
||||
)
|
||||
elif method == "attention":
|
||||
imp = fusion_attention_importance(model, all_tokens, device, args.batch_size)
|
||||
if imp is not None:
|
||||
results["Fusion Attention"] = imp
|
||||
|
||||
# ── Print summary ──
|
||||
logger.info(f"\n{'='*60}")
|
||||
logger.info(f"Token Importance Summary (task={task})")
|
||||
logger.info(f"{'='*60}")
|
||||
# Print token summary
|
||||
for method_name, importance in results.items():
|
||||
normed = normalize(importance)
|
||||
order = np.argsort(-normed)
|
||||
logger.info(f"\n {method_name}:")
|
||||
logger.info(f"\n {method_name} (biodist_{organ}):")
|
||||
for rank, idx in enumerate(order, 1):
|
||||
logger.info(f" {rank:>2d}. {token_names[idx]:<10s} {normed[idx]:.4f}")
|
||||
|
||||
logger.info(f"\n Gate values (sigmoid):")
|
||||
for name, val in sorted(gv.items(), key=lambda x: -x[1]):
|
||||
logger.info(f" {name:<10s} {val:.4f}")
|
||||
|
||||
# ── Save results ──
|
||||
if results:
|
||||
plot_token_importance(results, token_names, task, out_dir)
|
||||
save_csv(results, token_names, task, out_dir, gate_vals=gv)
|
||||
plot_token_importance(results, token_names, organ, out_dir)
|
||||
save_token_csv(results, token_names, organ, out_dir, gate_vals=gv)
|
||||
|
||||
# Descriptor-level
|
||||
if args.desc_ig and raw_desc is not None and desc_names is not None:
|
||||
logger.info(f"\n Descriptor IG for {organ}...")
|
||||
desc_imp = descriptor_ig_importance(
|
||||
model, all_tokens, raw_desc, device, organ_idx,
|
||||
batch_size=args.batch_size, n_steps=args.n_steps,
|
||||
)
|
||||
normed = normalize(desc_imp)
|
||||
top_k = min(args.desc_top_k, len(desc_names))
|
||||
order = np.argsort(-normed)
|
||||
logger.info(f"\n Top-{top_k} desc features (biodist_{organ}):")
|
||||
for rank, idx in enumerate(order[:top_k], 1):
|
||||
logger.info(f" {rank:>3d}. {desc_names[idx]:<30s} {normed[idx]:.6f}")
|
||||
|
||||
desc_out_dir = out_dir / "desc"
|
||||
plot_descriptor_importance(desc_imp, desc_names, organ, desc_out_dir, top_k)
|
||||
save_descriptor_csv(desc_imp, desc_names, organ, desc_out_dir)
|
||||
|
||||
logger.info("\nDone!")
|
||||
|
||||
|
||||
@ -201,6 +201,39 @@ class LNPModel(nn.Module):
|
||||
}
|
||||
return task_heads[task](fused)
|
||||
|
||||
def forward_replacing_token(
|
||||
self,
|
||||
raw_feature: torch.Tensor,
|
||||
feature_key: str,
|
||||
base_projected: torch.Tensor,
|
||||
task: Optional[str] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
用原始特征替换 base_projected 中指定 token 的投影,然后 forward。
|
||||
|
||||
用于对单个 token 内部特征做 Captum 归因(如 desc 的 210 维)。
|
||||
|
||||
Args:
|
||||
raw_feature: [B, input_dim] 某个 token 的原始特征
|
||||
feature_key: token 名称,如 "desc"
|
||||
base_projected: [B, n_tokens, d_model] 其他 token 已投影好的张量
|
||||
task: 任务名
|
||||
|
||||
Returns:
|
||||
对应任务的预测输出
|
||||
"""
|
||||
projected = self.token_projector.projectors[feature_key](raw_feature)
|
||||
gate = torch.sigmoid(self.token_projector.weights[feature_key])
|
||||
projected = projected * gate # [B, d_model]
|
||||
|
||||
token_order = list(self.token_projector.keys)
|
||||
token_idx = token_order.index(feature_key)
|
||||
|
||||
stacked = base_projected.clone()
|
||||
stacked[:, token_idx, :] = projected
|
||||
|
||||
return self.forward_from_projected(stacked, task=task)
|
||||
|
||||
def forward_backbone(
|
||||
self,
|
||||
smiles: List[str],
|
||||
|
||||
@ -1,9 +0,0 @@
|
||||
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
|
||||
desc,3.2013970842762367e-09,0.17906809140868585,0.5030443072319031
|
||||
mpnn,3.109777150387161e-09,0.17394338920379962,0.5024935007095337
|
||||
maccs,3.0657202874248063e-09,0.1714790968475434,0.5030479431152344
|
||||
morgan,3.020539287877718e-09,0.16895192663283606,0.5045571327209473
|
||||
help,1.937320535640997e-09,0.10836278088337024,0.49689680337905884
|
||||
comp,1.876732953403087e-09,0.10497385335304418,0.5007365345954895
|
||||
exp,1.6503372406931126e-09,0.09231055445232395,0.5002157688140869
|
||||
phys,1.6274562664095572e-11,0.0009103072183966515,0.49989768862724304
|
||||
|
Binary file not shown.
|
Before Width: | Height: | Size: 55 KiB |
@ -1,9 +0,0 @@
|
||||
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
|
||||
desc,0.48309860429978513,0.1830685579048154,0.5030443072319031
|
||||
mpnn,0.4800125818662255,0.1818991202961257,0.5024935007095337
|
||||
morgan,0.4681746999619525,0.17741319558101734,0.5045571327209473
|
||||
maccs,0.4642216987644718,0.1759152193455701,0.5030479431152344
|
||||
help,0.2636187040391749,0.09989740304701922,0.49689680337905884
|
||||
comp,0.2503036292270154,0.0948517011498054,0.5007365345954895
|
||||
exp,0.22730758172642437,0.08613742788142016,0.5002157688140869
|
||||
phys,0.0021569658208921167,0.0008173747942266034,0.49989768862724304
|
||||
|
Binary file not shown.
|
Before Width: | Height: | Size: 56 KiB |
@ -1,9 +0,0 @@
|
||||
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
|
||||
desc,0.28693426746338735,0.2025623449841757,0.5030443072319031
|
||||
mpnn,0.2846832493908164,0.2009732301551499,0.5024935007095337
|
||||
morgan,0.2646961202937205,0.18686323982460912,0.5045571327209473
|
||||
maccs,0.25845640337994125,0.1824582877731647,0.5030479431152344
|
||||
help,0.11619691669315178,0.082029658337337,0.49689680337905884
|
||||
comp,0.10812802919624785,0.07633339630758515,0.5007365345954895
|
||||
exp,0.09640860187977682,0.06806002171178546,0.5002157688140869
|
||||
phys,0.001019643036021533,0.000719820906192951,0.49989768862724304
|
||||
|
Binary file not shown.
|
Before Width: | Height: | Size: 54 KiB |
@ -1,9 +0,0 @@
|
||||
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
|
||||
mpnn,0.29288250773521896,0.2140725741095052,0.5024935007095337
|
||||
desc,0.2764989421166159,0.2020975603328644,0.5030443072319031
|
||||
morgan,0.2578500855439716,0.18846680866532511,0.5045571327209473
|
||||
maccs,0.2471777274240024,0.18066620905892686,0.5030479431152344
|
||||
help,0.10363027887841153,0.07574505123823679,0.49689680337905884
|
||||
comp,0.09781463069792998,0.07149429968008479,0.5007365345954895
|
||||
exp,0.091429982122356,0.06682765650659303,0.5002157688140869
|
||||
phys,0.0008617135523839594,0.0006298404084638948,0.49989768862724304
|
||||
|
Binary file not shown.
|
Before Width: | Height: | Size: 54 KiB |
@ -1,9 +0,0 @@
|
||||
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
|
||||
desc,0.43124767207505,0.1973483310882186,0.5030443072319031
|
||||
mpnn,0.4250195774970647,0.1944982193069518,0.5024935007095337
|
||||
morgan,0.41677666295297405,0.19072608200879151,0.5045571327209473
|
||||
maccs,0.40606126033119555,0.1858224802938626,0.5030479431152344
|
||||
help,0.17471509961619794,0.07995343640757792,0.49689680337905884
|
||||
comp,0.16962298878652332,0.07762317554120124,0.5007365345954895
|
||||
exp,0.16035562746707094,0.07338222907722322,0.5002157688140869
|
||||
phys,0.0014117471939899774,0.0006460462761730848,0.49989768862724304
|
||||
|
Binary file not shown.
|
Before Width: | Height: | Size: 55 KiB |
@ -1,9 +0,0 @@
|
||||
token,Integrated Gradients_raw,Integrated Gradients_normalized,gate_sigmoid
|
||||
mpnn,0.17711989597355013,0.21006733816477163,0.5024935007095337
|
||||
desc,0.17162019114321028,0.20354459068882486,0.5030443072319031
|
||||
morgan,0.16095514462538474,0.19089565635488848,0.5045571327209473
|
||||
maccs,0.15742516053846328,0.18670903261717847,0.5030479431152344
|
||||
comp,0.05954181250470194,0.07061764571178654,0.5007365345954895
|
||||
help,0.05892528920918569,0.06988643814815398,0.49689680337905884
|
||||
exp,0.05708834082906645,0.06770778478774685,0.5002157688140869
|
||||
phys,0.0004818760368551862,0.0005715135266490605,0.49989768862724304
|
||||
|
Binary file not shown.
|
Before Width: | Height: | Size: 55 KiB |
Loading…
x
Reference in New Issue
Block a user