lnp_ml/lnp_ml/interpretability/token_importance.py
2026-03-03 15:23:10 +08:00

535 lines
22 KiB
Python

"""
Biodistribution feature importance for the LNP multi-task model.
Two levels of analysis, both targeting the biodist head (per organ):
1. Token-level — which of the 8 input tokens matters most
2. Descriptor-level — which RDKit descriptors inside the "desc" token matter most
Token-level methods:
- Integrated Gradients (Captum)
- Token Ablation (zero-out)
- Fusion Attention Weights
Usage:
python -m lnp_ml.interpretability.token_importance # all organs, token-level IG
python -m lnp_ml.interpretability.token_importance --organ liver # liver only
python -m lnp_ml.interpretability.token_importance --desc-ig # + desc-level IG
"""
from __future__ import annotations
import argparse
from pathlib import Path
from typing import Dict, List, Optional, Union
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from loguru import logger
from tqdm import tqdm
from lnp_ml.config import MODELS_DIR, INTERIM_DATA_DIR, REPORTS_DIR
from lnp_ml.dataset import LNPDataset, collate_fn, process_dataframe
from lnp_ml.modeling.predict import load_model
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
BIODIST_ORGANS = ["lymph_nodes", "heart", "liver", "spleen", "lung", "kidney", "muscle"]
def get_token_names(model: Union[LNPModel, LNPModelWithoutMPNN]) -> List[str]:
    """Return the ordered input-token names for *model*.

    The leading "mpnn" token exists only when the model uses the MPNN
    molecular encoder; the remaining seven tokens are shared by both variants.
    """
    shared = ["morgan", "maccs", "desc", "comp", "phys", "help", "exp"]
    return (["mpnn"] + shared) if model.use_mpnn else shared
# ──────────────────────────────────────────────────────────────────────
# Pre-compute
# ──────────────────────────────────────────────────────────────────────
@torch.no_grad()
def precompute_projected_tokens(
    model: nn.Module, loader: DataLoader, device: torch.device,
) -> torch.Tensor:
    """Encode the whole dataset once; returns projected tokens [N, n_tokens, d_model]."""
    model.eval()
    encoded: List[torch.Tensor] = []
    for batch in tqdm(loader, desc="Encoding tokens"):
        tabular = {key: tensor.to(device) for key, tensor in batch["tabular"].items()}
        projected = model._encode_and_project(batch["smiles"], tabular)
        encoded.append(projected.cpu())
    return torch.cat(encoded, dim=0)
@torch.no_grad()
def precompute_raw_desc(
    model: nn.Module, loader: DataLoader, device: torch.device,
) -> torch.Tensor:
    """Collect raw RDKit descriptor features for the whole dataset.

    Args:
        model: Multi-task model exposing an ``rdkit_encoder`` callable over SMILES.
        loader: DataLoader yielding batches with a "smiles" key.
        device: Unused; kept for signature symmetry with
            :func:`precompute_projected_tokens`. Output is collected on CPU.

    Returns:
        Tensor of shape [N, desc_dim] with the raw descriptor features.
    """
    model.eval()
    chunks = []
    for batch in tqdm(loader, desc="Encoding raw desc"):
        # Fix: the original did `.to(next(model.parameters()).device).cpu()`,
        # copying every chunk to the model device only to copy it straight
        # back. `.cpu()` alone yields identical values with no round-trip.
        chunks.append(model.rdkit_encoder(batch["smiles"])["desc"].cpu())
    return torch.cat(chunks, dim=0)
def get_desc_names_and_dim(model: nn.Module) -> tuple[List[str], int]:
    """Return RDKit descriptor names, verified against the model's desc input width.

    Falls back to generic "desc_i" names when the RDKit descriptor list does
    not match the model's expected feature count.
    """
    from rdkit.Chem import Descriptors as RDKitDesc
    model_dim = model.token_projector.projectors["desc"][1].in_features
    names = [entry[0] for entry in RDKitDesc.descList]
    if len(names) != model_dim:
        logger.warning(
            f"RDKit descList has {len(names)} names but model expects "
            f"{model_dim} features; using generic names."
        )
        names = [f"desc_{i}" for i in range(model_dim)]
    return names, model_dim
# ──────────────────────────────────────────────────────────────────────
# Token-level: Integrated Gradients
# ──────────────────────────────────────────────────────────────────────
class _ProjectedWrapper(nn.Module):
    """Captum-friendly wrapper: projected tokens -> scalar biodist output for one organ."""
    def __init__(self, model: nn.Module, organ_index: int) -> None:
        super().__init__()
        self.model = model
        self.organ_index = organ_index
    def forward(self, stacked: torch.Tensor) -> torch.Tensor:
        # Run only the fusion + biodist head, then select the target organ column.
        biodist = self.model.forward_from_projected(stacked, task="biodist")
        return biodist[:, self.organ_index]
def integrated_gradients_importance(
    model: nn.Module,
    all_tokens: torch.Tensor,
    device: torch.device,
    organ_index: int,
    batch_size: int = 64,
    n_steps: int = 50,
) -> np.ndarray:
    """Returns [n_tokens]: L2 norm of the IG attribution over d_model, averaged over samples."""
    from captum.attr import IntegratedGradients
    wrapper = _ProjectedWrapper(model, organ_index).to(device)
    wrapper.eval()
    ig = IntegratedGradients(wrapper)
    total = all_tokens.size(0)
    per_batch: List[torch.Tensor] = []
    for start in tqdm(range(0, total, batch_size), desc=f"IG (organ={organ_index})"):
        batch = all_tokens[start:start + batch_size].to(device).requires_grad_(True)
        attributions = ig.attribute(batch, baselines=torch.zeros_like(batch), n_steps=n_steps)
        # Collapse the d_model axis into one magnitude per token: [B, n_tokens].
        per_batch.append(attributions.detach().cpu().norm(dim=-1))
    return torch.cat(per_batch, dim=0).mean(dim=0).numpy()
# ──────────────────────────────────────────────────────────────────────
# Token-level: Ablation
# ──────────────────────────────────────────────────────────────────────
@torch.no_grad()
def token_ablation_importance(
    model: nn.Module,
    all_tokens: torch.Tensor,
    device: torch.device,
    organ_index: int,
    batch_size: int = 64,
) -> np.ndarray:
    """Returns [n_tokens]: mean |Δ prediction| when each token is zeroed out in turn."""
    model.eval()
    baseline_preds = _batch_predict(model, all_tokens, device, organ_index, batch_size)
    deltas = []
    for token_idx in range(all_tokens.size(1)):
        zeroed = all_tokens.clone()
        zeroed[:, token_idx, :] = 0.0
        ablated_preds = _batch_predict(model, zeroed, device, organ_index, batch_size)
        deltas.append(np.abs(baseline_preds - ablated_preds).mean())
    return np.array(deltas)
def _batch_predict(
    model: nn.Module, all_tokens: torch.Tensor,
    device: torch.device, organ_index: int, batch_size: int,
) -> np.ndarray:
    """Mini-batched biodist predictions for one organ; returns [N] as numpy."""
    model.eval()
    total = all_tokens.size(0)
    outputs = [
        model.forward_from_projected(all_tokens[start:start + batch_size].to(device), task="biodist")[:, organ_index]
        .cpu()
        .numpy()
        for start in range(0, total, batch_size)
    ]
    return np.concatenate(outputs)
# ──────────────────────────────────────────────────────────────────────
# Token-level: Fusion Attention Weights (task-agnostic)
# ──────────────────────────────────────────────────────────────────────
@torch.no_grad()
def fusion_attention_importance(
    model: nn.Module, all_tokens: torch.Tensor,
    device: torch.device, batch_size: int = 64,
) -> Optional[np.ndarray]:
    """Mean fusion-attention weight per token ([n_tokens]); None unless fusion is 'attention'."""
    model.eval()
    if model.fusion.strategy != "attention":
        logger.warning("Fusion strategy is not 'attention'; skipping.")
        return None
    weight_chunks: List[torch.Tensor] = []
    total = all_tokens.size(0)
    for start in range(0, total, batch_size):
        chunk = all_tokens[start:start + batch_size].to(device)
        attended = model.cross_attention(chunk)
        _, attn = model.fusion(attended, return_attn_weights=True)
        weight_chunks.append(attn.cpu())
    return torch.cat(weight_chunks, dim=0).mean(dim=0).numpy()
# ──────────────────────────────────────────────────────────────────────
# Descriptor-level: Integrated Gradients
# ──────────────────────────────────────────────────────────────────────
class _ReplacingDescWrapper(nn.Module):
    """Captum wrapper that feeds raw desc features through the model while
    reusing pre-projected tokens for every other input.

    Captum's IG expands the input batch by the number of interpolation
    steps, so ``base_projected`` must be tiled to match. The original code
    used floor division for the tile count, which under-tiled (and failed)
    whenever the expanded batch was not an exact multiple of the base
    batch; we now tile with a ceiling count and slice back down.
    """
    def __init__(self, model: nn.Module, base_projected: torch.Tensor, organ_index: int) -> None:
        super().__init__()
        self.model = model
        self.base_projected = base_projected  # [B_base, n_tokens, d_model]
        self.organ_index = organ_index
    def forward(self, raw_desc: torch.Tensor) -> torch.Tensor:
        B_base = self.base_projected.size(0)
        B_input = raw_desc.size(0)
        if B_input > B_base:
            # Ceiling division so a partial final tile is still covered.
            repeats = -(-B_input // B_base)
            base = self.base_projected.repeat(repeats, 1, 1)[:B_input].to(raw_desc.device)
        else:
            base = self.base_projected[:B_input].to(raw_desc.device)
        out = self.model.forward_replacing_token(raw_desc, "desc", base, task="biodist")
        return out[:, self.organ_index]
def descriptor_ig_importance(
    model: nn.Module,
    all_tokens: torch.Tensor,
    raw_desc: torch.Tensor,
    device: torch.device,
    organ_index: int,
    batch_size: int = 64,
    n_steps: int = 50,
) -> np.ndarray:
    """Returns [desc_dim] averaged absolute attribution per descriptor."""
    from captum.attr import IntegratedGradients
    wrapper = _ReplacingDescWrapper(model, all_tokens, organ_index).to(device)
    wrapper.eval()
    ig = IntegratedGradients(wrapper)
    total = raw_desc.size(0)
    collected: List[torch.Tensor] = []
    for start in tqdm(range(0, total, batch_size), desc=f"Desc-IG (organ={organ_index})"):
        stop = min(start + batch_size, total)
        # Keep the pre-projected tokens aligned with the current input slice.
        wrapper.base_projected = all_tokens[start:stop]
        inputs = raw_desc[start:stop].to(device).requires_grad_(True)
        attr = ig.attribute(inputs, baselines=torch.zeros_like(inputs), n_steps=n_steps)
        collected.append(attr.detach().cpu().abs())
    return torch.cat(collected, dim=0).mean(dim=0).numpy()
# ──────────────────────────────────────────────────────────────────────
# Gate values
# ──────────────────────────────────────────────────────────────────────
def gate_values(model: nn.Module) -> Dict[str, float]:
    """Sigmoid-activated gate weight per token key, read off the token projector."""
    return {
        key: torch.sigmoid(model.token_projector.weights[key].detach().cpu()).item()
        for key in model.token_projector.keys
    }
# ──────────────────────────────────────────────────────────────────────
# Visualization & IO
# ──────────────────────────────────────────────────────────────────────
def normalize(arr: np.ndarray) -> np.ndarray:
    """Scale *arr* to sum to 1; returned unchanged when the sum is not positive."""
    total = arr.sum()
    if total > 0:
        return arr / total
    return arr
def plot_token_importance(
    results: Dict[str, np.ndarray], token_names: List[str],
    organ: str, out_dir: Path,
) -> Path:
    """Horizontal bar chart of per-token importance, one panel per method; returns the PNG path."""
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    n_methods = len(results)
    fig_height = max(4, len(token_names) * 0.6)
    fig, axes = plt.subplots(1, n_methods, figsize=(6 * n_methods, fig_height))
    if n_methods == 1:
        axes = [axes]
    mol_color, form_color = "#5a448e", "#e07b39"
    # Per get_token_names ordering: the first 4 tokens (3 without MPNN) are the
    # molecule-representation tokens, the rest are tabular inputs — color-split.
    split_idx = 4 if len(token_names) == 8 else 3
    mol_tokens = set(token_names[:split_idx])
    for ax, (method_name, importance) in zip(axes, results.items()):
        normed = normalize(importance)
        order = np.argsort(normed)
        sorted_names = [token_names[i] for i in order]
        bar_colors = [mol_color if name in mol_tokens else form_color for name in sorted_names]
        ax.barh(range(len(sorted_names)), normed[order], color=bar_colors)
        ax.set_yticks(range(len(sorted_names)))
        ax.set_yticklabels(sorted_names, fontsize=11)
        ax.set_xlabel("Normalized Importance", fontsize=11)
        ax.set_title(f"{method_name}", fontsize=13)
        for side in ("top", "right"):
            ax.spines[side].set_visible(False)
    fig.suptitle(f"Token Importance — biodist: {organ}", fontsize=14, y=1.02)
    fig.tight_layout()
    out_dir.mkdir(parents=True, exist_ok=True)
    fig_path = out_dir / f"token_importance_biodist_{organ}.png"
    fig.savefig(fig_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    logger.info(f"Figure saved to {fig_path}")
    return fig_path
def save_token_csv(
    results: Dict[str, np.ndarray], token_names: List[str],
    organ: str, out_dir: Path,
    gate_vals: Optional[Dict[str, float]] = None,
) -> Path:
    """Write per-token raw + normalized importances (and optional gate values) to CSV."""
    out_dir.mkdir(parents=True, exist_ok=True)
    columns: Dict[str, object] = {"token": token_names}
    for method_name, importance in results.items():
        columns[f"{method_name}_raw"] = importance
        columns[f"{method_name}_normalized"] = normalize(importance)
    df = pd.DataFrame(columns)
    if gate_vals is not None:
        df["gate_sigmoid"] = [gate_vals.get(t, float("nan")) for t in token_names]
    # Rank rows by the last method's normalized importance, highest first.
    norm_cols = [c for c in df.columns if c.endswith("_normalized")]
    df = df.sort_values(by=norm_cols[-1], ascending=False)
    csv_path = out_dir / f"token_importance_biodist_{organ}.csv"
    df.to_csv(csv_path, index=False)
    logger.info(f"CSV saved to {csv_path}")
    return csv_path
def plot_descriptor_importance(
    importance: np.ndarray, feature_names: List[str],
    organ: str, out_dir: Path, top_k: int = 30,
) -> Path:
    """Bar chart of the top-k descriptor importances for one organ; returns the PNG path."""
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    normed = normalize(importance)
    # Descending by importance, then reversed so the biggest bar sits on top.
    top_indices = list(np.argsort(-normed)[:top_k])[::-1]
    labels = [feature_names[i] for i in top_indices]
    values = [normed[i] for i in top_indices]
    fig, ax = plt.subplots(figsize=(8, max(5, top_k * 0.35)))
    ax.barh(range(len(labels)), values, color="#5a448e")
    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels, fontsize=9)
    ax.set_xlabel("Normalized Importance", fontsize=11)
    ax.set_title(
        f"Top-{top_k} desc Feature Importance — biodist: {organ}\n"
        f"(total features: {len(feature_names)})",
        fontsize=13,
    )
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    fig.tight_layout()
    out_dir.mkdir(parents=True, exist_ok=True)
    fig_path = out_dir / f"desc_importance_biodist_{organ}.png"
    fig.savefig(fig_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    logger.info(f"Descriptor figure saved to {fig_path}")
    return fig_path
def save_descriptor_csv(
    importance: np.ndarray, feature_names: List[str],
    organ: str, out_dir: Path,
) -> Path:
    """Write descriptor importances (raw + normalized, ranked) to CSV; returns the path."""
    out_dir.mkdir(parents=True, exist_ok=True)
    df = (
        pd.DataFrame({
            "feature": feature_names,
            "ig_raw": importance,
            "ig_normalized": normalize(importance),
        })
        .sort_values("ig_normalized", ascending=False)
        .reset_index(drop=True)
    )
    # Index is written as an explicit "rank" column (0 = most important).
    df.index.name = "rank"
    csv_path = out_dir / f"desc_importance_biodist_{organ}.csv"
    df.to_csv(csv_path)
    logger.info(f"Descriptor CSV saved to {csv_path}")
    return csv_path
# ──────────────────────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────────────────────
def main() -> None:
    """CLI entry point: compute token-level (and optionally descriptor-level)
    biodistribution feature importance, then write figures and CSVs per organ.
    """
    parser = argparse.ArgumentParser(
        description="Biodistribution feature importance (token-level & descriptor-level)")
    parser.add_argument("--model-path", type=str,
                        default=str(MODELS_DIR / "final" / "model.pt"))
    parser.add_argument("--data-path", type=str,
                        default=str(INTERIM_DATA_DIR / "internal.csv"))
    parser.add_argument("--organ", type=str, default="all",
                        choices=BIODIST_ORGANS + ["all"],
                        help="Organ to analyze (default: all 7 organs)")
    parser.add_argument("--method", type=str, default="ig",
                        choices=["ig", "ablation", "attention", "all"],
                        help="Token-level method(s)")
    parser.add_argument("--desc-ig", action="store_true",
                        help="Also compute descriptor-level IG within the desc token")
    parser.add_argument("--desc-top-k", type=int, default=30,
                        help="Top-K descriptors to show in plot")
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--n-steps", type=int, default=50,
                        help="Interpolation steps for Integrated Gradients")
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--output-dir", type=str,
                        default=str(REPORTS_DIR / "feature_importance"))
    args = parser.parse_args()
    device = torch.device(args.device)
    out_dir = Path(args.output_dir)
    # ── Resolve organs ──
    # Map each requested organ to its column index in the biodist head output.
    organs = BIODIST_ORGANS if args.organ == "all" else [args.organ]
    organ_indices = {organ: BIODIST_ORGANS.index(organ) for organ in organs}
    logger.info(f"Device: {device}")
    logger.info(f"Model: {args.model_path}")
    logger.info(f"Data: {args.data_path}")
    logger.info(f"Organs: {organs}")
    # ── Load model & data ──
    model = load_model(Path(args.model_path), device)
    token_names = get_token_names(model)
    logger.info(f"Tokens ({len(token_names)}): {token_names}")
    data_path = Path(args.data_path)
    # CSV goes through process_dataframe; any other suffix is read as parquet.
    if data_path.suffix == ".csv":
        df = pd.read_csv(data_path)
        df = process_dataframe(df)
    else:
        df = pd.read_parquet(data_path)
    dataset = LNPDataset(df)
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
    logger.info(f"Samples: {len(dataset)}")
    # ── Pre-compute ──
    # Encode the full dataset once; all importance methods reuse these tokens.
    all_tokens = precompute_projected_tokens(model, loader, device)
    logger.info(f"Projected tokens shape: {all_tokens.shape}")
    gv = gate_values(model)
    logger.info(f"Gate values: {gv}")
    methods = ["ig", "ablation", "attention"] if args.method == "all" else [args.method]
    # ── Pre-compute desc features if needed ──
    raw_desc = None
    desc_names = None
    if args.desc_ig:
        desc_names, desc_dim = get_desc_names_and_dim(model)
        raw_desc = precompute_raw_desc(model, loader, device)
        logger.info(f"Raw desc shape: {raw_desc.shape} (dim={desc_dim})")
        assert raw_desc.size(1) == desc_dim
    # ── Per-organ loop ──
    for organ, organ_idx in organ_indices.items():
        logger.info(f"\n{'#'*60}")
        logger.info(f"# Organ: {organ} (index={organ_idx})")
        logger.info(f"{'#'*60}")
        # Token-level
        results: Dict[str, np.ndarray] = {}
        for method in methods:
            logger.info(f" Computing: {method}")
            if method == "ig":
                results["Integrated Gradients"] = integrated_gradients_importance(
                    model, all_tokens, device, organ_idx,
                    batch_size=args.batch_size, n_steps=args.n_steps,
                )
            elif method == "ablation":
                results["Token Ablation"] = token_ablation_importance(
                    model, all_tokens, device, organ_idx,
                    batch_size=args.batch_size,
                )
            elif method == "attention":
                # Attention weights are task-agnostic; may return None when the
                # fusion strategy is not "attention".
                imp = fusion_attention_importance(model, all_tokens, device, args.batch_size)
                if imp is not None:
                    results["Fusion Attention"] = imp
        # Print token summary
        for method_name, importance in results.items():
            normed = normalize(importance)
            order = np.argsort(-normed)
            logger.info(f"\n {method_name} (biodist_{organ}):")
            for rank, idx in enumerate(order, 1):
                logger.info(f" {rank:>2d}. {token_names[idx]:<10s} {normed[idx]:.4f}")
        if results:
            plot_token_importance(results, token_names, organ, out_dir)
            save_token_csv(results, token_names, organ, out_dir, gate_vals=gv)
        # Descriptor-level
        if args.desc_ig and raw_desc is not None and desc_names is not None:
            logger.info(f"\n Descriptor IG for {organ}...")
            desc_imp = descriptor_ig_importance(
                model, all_tokens, raw_desc, device, organ_idx,
                batch_size=args.batch_size, n_steps=args.n_steps,
            )
            normed = normalize(desc_imp)
            top_k = min(args.desc_top_k, len(desc_names))
            order = np.argsort(-normed)
            logger.info(f"\n Top-{top_k} desc features (biodist_{organ}):")
            for rank, idx in enumerate(order[:top_k], 1):
                logger.info(f" {rank:>3d}. {desc_names[idx]:<30s} {normed[idx]:.6f}")
            # Descriptor artifacts land in a "desc" subdirectory.
            desc_out_dir = out_dir / "desc"
            plot_descriptor_importance(desc_imp, desc_names, organ, desc_out_dir, top_k)
            save_descriptor_csv(desc_imp, desc_names, organ, desc_out_dir)
    logger.info("\nDone!")
# Script entry point — delegates to the argparse-driven main().
if __name__ == "__main__":
    main()