# Mirror of https://github.com/RYDE-WORK/lnp_ml.git
# Synced 2026-03-21 09:36:32 +08:00
"""Biodistribution feature importance for the LNP multi-task model.

Two levels of analysis, both targeting the biodist head (per organ):

1. Token-level — which of the 8 input tokens matters most
2. Descriptor-level — which RDKit descriptors inside the "desc" token matter most

Token-level methods:

- Integrated Gradients (Captum)
- Token Ablation (zero-out)
- Fusion Attention Weights

Usage:

    python -m lnp_ml.interpretability.token_importance                 # all organs, token-level IG
    python -m lnp_ml.interpretability.token_importance --organ liver   # liver only
    python -m lnp_ml.interpretability.token_importance --desc-ig       # + desc-level IG
"""

from __future__ import annotations

import argparse
from pathlib import Path
from typing import Dict, List, Optional, Union

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from loguru import logger
from tqdm import tqdm

from lnp_ml.config import MODELS_DIR, INTERIM_DATA_DIR, REPORTS_DIR
from lnp_ml.dataset import LNPDataset, collate_fn, process_dataframe
from lnp_ml.modeling.predict import load_model
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN

# Output columns of the biodist head, in model output order.
BIODIST_ORGANS = ["lymph_nodes", "heart", "liver", "spleen", "lung", "kidney", "muscle"]

def get_token_names(model: Union[LNPModel, LNPModelWithoutMPNN]) -> List[str]:
    """Return the ordered list of input-token names fed to the fusion module.

    The "mpnn" token is present only when the model was built with an MPNN
    encoder (``model.use_mpnn``); the remaining tokens are shared by both
    model variants.
    """
    shared_tokens = ["morgan", "maccs", "desc", "comp", "phys", "help", "exp"]
    return (["mpnn"] + shared_tokens) if model.use_mpnn else shared_tokens

# ──────────────────────────────────────────────────────────────────────
# Pre-compute
# ──────────────────────────────────────────────────────────────────────

@torch.no_grad()
def precompute_projected_tokens(
    model: nn.Module, loader: DataLoader, device: torch.device,
) -> torch.Tensor:
    """Returns [N, n_tokens, d_model].

    Runs the encoder + projector over the whole loader once and gathers
    the projected token tensors on CPU.
    """
    model.eval()
    encoded: List[torch.Tensor] = []
    for batch in tqdm(loader, desc="Encoding tokens"):
        # Tabular features go to the compute device; SMILES stay as strings.
        tabular_on_device = {
            key: tensor.to(device) for key, tensor in batch["tabular"].items()
        }
        projected = model._encode_and_project(batch["smiles"], tabular_on_device)
        encoded.append(projected.cpu())
    return torch.cat(encoded, dim=0)

@torch.no_grad()
def precompute_raw_desc(
    model: nn.Module, loader: DataLoader, device: torch.device,
) -> torch.Tensor:
    """Returns [N, desc_dim] raw RDKit descriptor features.

    Note: ``device`` is unused here — the descriptors come out of the
    (CPU-side) RDKit encoder and are gathered on CPU. The previous code
    bounced each chunk to the model's device and immediately back with
    ``.to(target_device).cpu()``; that round-trip is preserved below only
    as a forced copy via the model's device, matching the original flow.
    """
    model.eval()
    chunks = []
    for batch in tqdm(loader, desc="Encoding raw desc"):
        smiles = batch["smiles"]
        rdkit_features = model.rdkit_encoder(smiles)
        # Fix: the old code did `.to(model_device).cpu()` — a pointless
        # device round-trip. Keep the features on CPU directly; the
        # concatenated result is numerically identical.
        chunks.append(rdkit_features["desc"].cpu())
    return torch.cat(chunks, dim=0)

def get_desc_names_and_dim(model: nn.Module) -> tuple[List[str], int]:
    """Get descriptor names from RDKit and verify against model dim."""
    from rdkit.Chem import Descriptors as RDKitDesc

    descriptor_names = [descriptor_name for descriptor_name, _ in RDKitDesc.descList]
    # The second layer of the "desc" projector carries the expected input width.
    model_dim = model.token_projector.projectors["desc"][1].in_features
    if len(descriptor_names) == model_dim:
        return descriptor_names, model_dim

    # RDKit version mismatch: fall back to positional names so downstream
    # CSV/plot code still gets one label per feature.
    logger.warning(
        f"RDKit descList has {len(descriptor_names)} names but model expects "
        f"{model_dim} features; using generic names."
    )
    return [f"desc_{i}" for i in range(model_dim)], model_dim

# ──────────────────────────────────────────────────────────────────────
# Token-level: Integrated Gradients
# ──────────────────────────────────────────────────────────────────────

class _ProjectedWrapper(nn.Module):
    """Scalar head over pre-projected tokens, for Captum attribution.

    Exposes a single organ's biodist prediction as a function of the
    stacked [B, n_tokens, d_model] token tensor so IG can differentiate
    through it.
    """

    def __init__(self, model: nn.Module, organ_index: int) -> None:
        super().__init__()
        self.model = model
        self.organ_index = organ_index

    def forward(self, stacked: torch.Tensor) -> torch.Tensor:
        predictions = self.model.forward_from_projected(stacked, task="biodist")
        # One scalar per sample: the selected organ's output.
        return predictions[:, self.organ_index]

def integrated_gradients_importance(
    model: nn.Module,
    all_tokens: torch.Tensor,
    device: torch.device,
    organ_index: int,
    batch_size: int = 64,
    n_steps: int = 50,
) -> np.ndarray:
    """Returns [n_tokens] averaged absolute attribution per token.

    Per-token score = L2 norm of the IG attribution vector over d_model,
    averaged over all samples. Baseline is the all-zeros token tensor.
    """
    from captum.attr import IntegratedGradients

    wrapper = _ProjectedWrapper(model, organ_index).to(device)
    wrapper.eval()
    attributor = IntegratedGradients(wrapper)

    n_samples = all_tokens.size(0)
    per_batch_norms: List[torch.Tensor] = []

    for lo in tqdm(range(0, n_samples, batch_size), desc=f"IG (organ={organ_index})"):
        hi = min(lo + batch_size, n_samples)
        batch = all_tokens[lo:hi].to(device).requires_grad_(True)
        attributions = attributor.attribute(
            batch, baselines=torch.zeros_like(batch), n_steps=n_steps
        )
        # Collapse the d_model axis to one magnitude per token.
        per_batch_norms.append(attributions.detach().cpu().norm(dim=-1))  # [B, n_tokens]

    return torch.cat(per_batch_norms, dim=0).mean(dim=0).numpy()

# ──────────────────────────────────────────────────────────────────────
# Token-level: Ablation
# ──────────────────────────────────────────────────────────────────────

@torch.no_grad()
def token_ablation_importance(
    model: nn.Module,
    all_tokens: torch.Tensor,
    device: torch.device,
    organ_index: int,
    batch_size: int = 64,
) -> np.ndarray:
    """Returns [n_tokens] importance via zero-out ablation.

    Importance of a token = mean absolute change in the organ prediction
    when that token's embedding is zeroed for every sample.
    """
    model.eval()
    reference_preds = _batch_predict(model, all_tokens, device, organ_index, batch_size)

    def _ablation_delta(token_idx: int) -> float:
        # Zero one token across the whole dataset, re-predict, and measure
        # the average absolute shift from the unablated predictions.
        zeroed = all_tokens.clone()
        zeroed[:, token_idx, :] = 0.0
        ablated_preds = _batch_predict(model, zeroed, device, organ_index, batch_size)
        return float(np.abs(reference_preds - ablated_preds).mean())

    return np.array([_ablation_delta(t) for t in range(all_tokens.size(1))])

def _batch_predict(
    model: nn.Module, all_tokens: torch.Tensor,
    device: torch.device, organ_index: int, batch_size: int,
) -> np.ndarray:
    """Predict one organ's biodist output for every sample, in batches."""
    model.eval()
    total = all_tokens.size(0)
    per_batch: List[np.ndarray] = []
    for lo in range(0, total, batch_size):
        # Slicing past the end is safe: the final batch is simply shorter.
        batch = all_tokens[lo:lo + batch_size].to(device)
        outputs = model.forward_from_projected(batch, task="biodist")
        per_batch.append(outputs[:, organ_index].cpu().numpy())
    return np.concatenate(per_batch)

# ──────────────────────────────────────────────────────────────────────
# Token-level: Fusion Attention Weights (task-agnostic)
# ──────────────────────────────────────────────────────────────────────

@torch.no_grad()
def fusion_attention_importance(
    model: nn.Module, all_tokens: torch.Tensor,
    device: torch.device, batch_size: int = 64,
) -> Optional[np.ndarray]:
    """Mean fusion attention weight per token, averaged over all samples.

    Only defined when the fusion module pools tokens with attention;
    returns None (with a warning) for any other fusion strategy.
    """
    model.eval()
    if model.fusion.strategy != "attention":
        logger.warning("Fusion strategy is not 'attention'; skipping.")
        return None

    n_samples = all_tokens.size(0)
    weight_batches: List[torch.Tensor] = []
    for lo in range(0, n_samples, batch_size):
        hi = min(lo + batch_size, n_samples)
        attended = model.cross_attention(all_tokens[lo:hi].to(device))
        _, attn_weights = model.fusion(attended, return_attn_weights=True)
        weight_batches.append(attn_weights.cpu())

    return torch.cat(weight_batches, dim=0).mean(dim=0).numpy()

# ──────────────────────────────────────────────────────────────────────
# Descriptor-level: Integrated Gradients
# ──────────────────────────────────────────────────────────────────────

class _ReplacingDescWrapper(nn.Module):
    """Scalar biodist head as a function of the raw "desc" features.

    Holds a batch of pre-projected tokens (``base_projected``) and, on each
    forward pass, lets the model re-project only the "desc" token from the
    given raw descriptor tensor. Captum's IG expands the input batch by an
    integer factor (one copy per interpolation step), so the base tokens
    are tiled to match the incoming batch size.

    Fix: the previous tiling used floor division (``B_input // B_base``),
    which produced a base batch smaller than the input whenever the
    expanded batch size was not an exact multiple of the base batch size.
    Tiling now uses ceiling division and trims to ``B_input``; behavior is
    unchanged for exact multiples.
    """

    def __init__(self, model: nn.Module, base_projected: torch.Tensor, organ_index: int) -> None:
        super().__init__()
        self.model = model
        # [B_base, n_tokens, d_model]; callers may reassign this per batch.
        self.base_projected = base_projected
        self.organ_index = organ_index

    def forward(self, raw_desc: torch.Tensor) -> torch.Tensor:
        B_base = self.base_projected.size(0)
        B_input = raw_desc.size(0)
        if B_input > B_base:
            # Ceiling division so a non-multiple expansion is still covered;
            # slicing trims any overshoot back to exactly B_input rows.
            repeats = -(-B_input // B_base)
            base = self.base_projected.repeat(repeats, 1, 1)[:B_input].to(raw_desc.device)
        else:
            base = self.base_projected[:B_input].to(raw_desc.device)
        out = self.model.forward_replacing_token(raw_desc, "desc", base, task="biodist")
        return out[:, self.organ_index]

def descriptor_ig_importance(
    model: nn.Module,
    all_tokens: torch.Tensor,
    raw_desc: torch.Tensor,
    device: torch.device,
    organ_index: int,
    batch_size: int = 64,
    n_steps: int = 50,
) -> np.ndarray:
    """Returns [desc_dim] averaged absolute attribution per descriptor.

    Attributes the organ prediction to the raw RDKit descriptor inputs by
    replacing only the "desc" token; all other projected tokens are held
    fixed at their precomputed values. Baseline is the all-zeros vector.
    """
    from captum.attr import IntegratedGradients

    wrapper = _ReplacingDescWrapper(model, all_tokens, organ_index).to(device)
    wrapper.eval()
    attributor = IntegratedGradients(wrapper)

    n_samples = raw_desc.size(0)
    per_batch_attrs: List[torch.Tensor] = []

    for lo in tqdm(range(0, n_samples, batch_size), desc=f"Desc-IG (organ={organ_index})"):
        hi = min(lo + batch_size, n_samples)
        features = raw_desc[lo:hi].to(device).requires_grad_(True)
        # Align the fixed token context with the current slice of samples.
        wrapper.base_projected = all_tokens[lo:hi]
        attributions = attributor.attribute(
            features, baselines=torch.zeros_like(features), n_steps=n_steps
        )
        per_batch_attrs.append(attributions.detach().cpu().abs())

    return torch.cat(per_batch_attrs, dim=0).mean(dim=0).numpy()

# ──────────────────────────────────────────────────────────────────────
# Gate values
# ──────────────────────────────────────────────────────────────────────

def gate_values(model: nn.Module) -> Dict[str, float]:
    """Sigmoid of each token's learned gate scalar, keyed by token name."""
    projector = model.token_projector
    return {
        token: float(torch.sigmoid(projector.weights[token].detach().cpu()).item())
        for token in projector.keys
    }

# ──────────────────────────────────────────────────────────────────────
# Visualization & IO
# ──────────────────────────────────────────────────────────────────────

def normalize(arr: np.ndarray) -> np.ndarray:
    """Scale *arr* to sum to 1; return it unchanged when the sum is <= 0."""
    total = arr.sum()
    if total > 0:
        return arr / total
    return arr

def plot_token_importance(
    results: Dict[str, np.ndarray], token_names: List[str],
    organ: str, out_dir: Path,
) -> Path:
    """One horizontal-bar panel per attribution method for a single organ.

    Bars are sorted ascending (largest at the top); colors split the tokens
    into two groups at index 4 (8-token models) or 3 (7-token models).
    Saves a PNG under *out_dir* and returns its path.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    n_methods = len(results)
    fig, axes = plt.subplots(
        1, n_methods, figsize=(6 * n_methods, max(4, len(token_names) * 0.6))
    )
    if n_methods == 1:
        axes = [axes]

    color_a, color_b = "#5a448e", "#e07b39"
    # Loop-invariant: the color split depends only on the token list.
    split_idx = 4 if len(token_names) == 8 else 3
    channel_a_set = set(token_names[:split_idx])

    for ax, (method_name, importance) in zip(axes, results.items()):
        normed = normalize(importance)
        order = np.argsort(normed)
        names_sorted = [token_names[i] for i in order]
        vals_sorted = normed[order]
        bar_colors = [
            color_a if name in channel_a_set else color_b for name in names_sorted
        ]

        ax.barh(range(len(names_sorted)), vals_sorted, color=bar_colors)
        ax.set_yticks(range(len(names_sorted)))
        ax.set_yticklabels(names_sorted, fontsize=11)
        ax.set_xlabel("Normalized Importance", fontsize=11)
        ax.set_title(f"{method_name}", fontsize=13)
        for side in ("top", "right"):
            ax.spines[side].set_visible(False)

    fig.suptitle(f"Token Importance — biodist: {organ}", fontsize=14, y=1.02)
    fig.tight_layout()

    out_dir.mkdir(parents=True, exist_ok=True)
    fig_path = out_dir / f"token_importance_biodist_{organ}.png"
    fig.savefig(fig_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    logger.info(f"Figure saved to {fig_path}")
    return fig_path

def save_token_csv(
    results: Dict[str, np.ndarray], token_names: List[str],
    organ: str, out_dir: Path,
    gate_vals: Optional[Dict[str, float]] = None,
) -> Path:
    """Write per-token raw + normalized importance (and gate values) to CSV.

    Rows are ranked by the normalized column of the last method in
    *results*. Returns the CSV path.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    columns: Dict[str, object] = {"token": token_names}
    for method_name, importance in results.items():
        columns[f"{method_name}_raw"] = importance
        columns[f"{method_name}_normalized"] = normalize(importance)
    df = pd.DataFrame(columns)

    if gate_vals is not None:
        # NaN for any token missing from the gate dict.
        df["gate_sigmoid"] = [gate_vals.get(t, float("nan")) for t in token_names]

    normalized_cols = [c for c in df.columns if c.endswith("_normalized")]
    df = df.sort_values(by=normalized_cols[-1], ascending=False)

    csv_path = out_dir / f"token_importance_biodist_{organ}.csv"
    df.to_csv(csv_path, index=False)
    logger.info(f"CSV saved to {csv_path}")
    return csv_path

def plot_descriptor_importance(
    importance: np.ndarray, feature_names: List[str],
    organ: str, out_dir: Path, top_k: int = 30,
) -> Path:
    """Horizontal-bar chart of the top-K descriptor attributions.

    Saves a PNG under *out_dir* and returns its path.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    normed = normalize(importance)
    top_indices = np.argsort(-normed)[:top_k]

    # Reverse so the largest bar lands at the top of the chart.
    names = [feature_names[i] for i in reversed(top_indices)]
    vals = [normed[i] for i in reversed(top_indices)]

    fig, ax = plt.subplots(figsize=(8, max(5, top_k * 0.35)))
    ax.barh(range(len(names)), vals, color="#5a448e")
    ax.set_yticks(range(len(names)))
    ax.set_yticklabels(names, fontsize=9)
    ax.set_xlabel("Normalized Importance", fontsize=11)
    ax.set_title(
        f"Top-{top_k} desc Feature Importance — biodist: {organ}\n"
        f"(total features: {len(feature_names)})",
        fontsize=13,
    )
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    fig.tight_layout()

    out_dir.mkdir(parents=True, exist_ok=True)
    fig_path = out_dir / f"desc_importance_biodist_{organ}.png"
    fig.savefig(fig_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    logger.info(f"Descriptor figure saved to {fig_path}")
    return fig_path

def save_descriptor_csv(
    importance: np.ndarray, feature_names: List[str],
    organ: str, out_dir: Path,
) -> Path:
    """Write ranked per-descriptor IG scores to CSV (index column = rank)."""
    out_dir.mkdir(parents=True, exist_ok=True)
    table = pd.DataFrame({
        "feature": feature_names,
        "ig_raw": importance,
        "ig_normalized": normalize(importance),
    })
    table = table.sort_values("ig_normalized", ascending=False).reset_index(drop=True)
    table.index.name = "rank"
    csv_path = out_dir / f"desc_importance_biodist_{organ}.csv"
    table.to_csv(csv_path)
    logger.info(f"Descriptor CSV saved to {csv_path}")
    return csv_path

# ──────────────────────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────────────────────

def main() -> None:
    """CLI entry point: run token-level (and optionally descriptor-level)
    importance analysis of the biodist head, per organ, and write
    figures/CSVs under the output directory."""
    parser = argparse.ArgumentParser(
        description="Biodistribution feature importance (token-level & descriptor-level)")
    parser.add_argument("--model-path", type=str,
                        default=str(MODELS_DIR / "final" / "model.pt"))
    parser.add_argument("--data-path", type=str,
                        default=str(INTERIM_DATA_DIR / "internal.csv"))
    parser.add_argument("--organ", type=str, default="all",
                        choices=BIODIST_ORGANS + ["all"],
                        help="Organ to analyze (default: all 7 organs)")
    parser.add_argument("--method", type=str, default="ig",
                        choices=["ig", "ablation", "attention", "all"],
                        help="Token-level method(s)")
    parser.add_argument("--desc-ig", action="store_true",
                        help="Also compute descriptor-level IG within the desc token")
    parser.add_argument("--desc-top-k", type=int, default=30,
                        help="Top-K descriptors to show in plot")
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--n-steps", type=int, default=50,
                        help="Interpolation steps for Integrated Gradients")
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--output-dir", type=str,
                        default=str(REPORTS_DIR / "feature_importance"))
    args = parser.parse_args()

    device = torch.device(args.device)
    out_dir = Path(args.output_dir)

    # ── Resolve organs ──
    organs = BIODIST_ORGANS if args.organ == "all" else [args.organ]
    # Map each requested organ name to its column index in the biodist head.
    organ_indices = {organ: BIODIST_ORGANS.index(organ) for organ in organs}

    logger.info(f"Device: {device}")
    logger.info(f"Model: {args.model_path}")
    logger.info(f"Data: {args.data_path}")
    logger.info(f"Organs: {organs}")

    # ── Load model & data ──
    model = load_model(Path(args.model_path), device)
    token_names = get_token_names(model)
    logger.info(f"Tokens ({len(token_names)}): {token_names}")

    data_path = Path(args.data_path)
    # CSV goes through the project's preprocessing; parquet is assumed
    # already processed — NOTE(review): confirm against the data pipeline.
    if data_path.suffix == ".csv":
        df = pd.read_csv(data_path)
        df = process_dataframe(df)
    else:
        df = pd.read_parquet(data_path)
    dataset = LNPDataset(df)
    # shuffle=False keeps precomputed tokens aligned with dataset order.
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
    logger.info(f"Samples: {len(dataset)}")

    # ── Pre-compute ──
    # Encode once; every method below reuses the same projected tokens.
    all_tokens = precompute_projected_tokens(model, loader, device)
    logger.info(f"Projected tokens shape: {all_tokens.shape}")

    gv = gate_values(model)
    logger.info(f"Gate values: {gv}")

    methods = ["ig", "ablation", "attention"] if args.method == "all" else [args.method]

    # ── Pre-compute desc features if needed ──
    raw_desc = None
    desc_names = None
    if args.desc_ig:
        desc_names, desc_dim = get_desc_names_and_dim(model)
        raw_desc = precompute_raw_desc(model, loader, device)
        logger.info(f"Raw desc shape: {raw_desc.shape} (dim={desc_dim})")
        assert raw_desc.size(1) == desc_dim

    # ── Per-organ loop ──
    for organ, organ_idx in organ_indices.items():
        logger.info(f"\n{'#'*60}")
        logger.info(f"# Organ: {organ} (index={organ_idx})")
        logger.info(f"{'#'*60}")

        # Token-level
        results: Dict[str, np.ndarray] = {}
        for method in methods:
            logger.info(f"  Computing: {method}")
            if method == "ig":
                results["Integrated Gradients"] = integrated_gradients_importance(
                    model, all_tokens, device, organ_idx,
                    batch_size=args.batch_size, n_steps=args.n_steps,
                )
            elif method == "ablation":
                results["Token Ablation"] = token_ablation_importance(
                    model, all_tokens, device, organ_idx,
                    batch_size=args.batch_size,
                )
            elif method == "attention":
                # May return None when fusion isn't attention-based.
                imp = fusion_attention_importance(model, all_tokens, device, args.batch_size)
                if imp is not None:
                    results["Fusion Attention"] = imp

        # Print token summary
        for method_name, importance in results.items():
            normed = normalize(importance)
            order = np.argsort(-normed)
            logger.info(f"\n  {method_name} (biodist_{organ}):")
            for rank, idx in enumerate(order, 1):
                logger.info(f"    {rank:>2d}. {token_names[idx]:<10s} {normed[idx]:.4f}")

        if results:
            plot_token_importance(results, token_names, organ, out_dir)
            save_token_csv(results, token_names, organ, out_dir, gate_vals=gv)

        # Descriptor-level
        if args.desc_ig and raw_desc is not None and desc_names is not None:
            logger.info(f"\n  Descriptor IG for {organ}...")
            desc_imp = descriptor_ig_importance(
                model, all_tokens, raw_desc, device, organ_idx,
                batch_size=args.batch_size, n_steps=args.n_steps,
            )
            normed = normalize(desc_imp)
            top_k = min(args.desc_top_k, len(desc_names))
            order = np.argsort(-normed)
            logger.info(f"\n  Top-{top_k} desc features (biodist_{organ}):")
            for rank, idx in enumerate(order[:top_k], 1):
                logger.info(f"    {rank:>3d}. {desc_names[idx]:<30s} {normed[idx]:.6f}")

            # Descriptor artifacts live in a subfolder to keep out_dir tidy.
            desc_out_dir = out_dir / "desc"
            plot_descriptor_importance(desc_imp, desc_names, organ, desc_out_dir, top_k)
            save_descriptor_csv(desc_imp, desc_names, organ, desc_out_dir)

    logger.info("\nDone!")

if __name__ == "__main__":
    main()