mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 01:27:00 +08:00
Add feature importance
This commit is contained in:
parent
447f2543f7
commit
25938df4ce
20
Makefile
20
Makefile
@ -122,6 +122,26 @@ train: requirements
|
||||
$(DEVICE_FLAG) $(MPNN_FLAG) $(SEED_FLAG) $(INIT_PRETRAIN_FLAG) \
|
||||
$(N_TRIALS_FLAG) $(EPOCHS_PER_TRIAL_FLAG) $(MIN_STRATUM_FLAG) $(OUTPUT_DIR_FLAG) $(USE_SWA_FLAG)
|
||||
|
||||
#################################################################################
# INTERPRETABILITY                                                              #
#################################################################################
# Parameters:
#   TASK    target task (delivery, size, pdi, ee, biodist, toxic; default: delivery)
#   METHOD  attribution method (ig, ablation, attention, all; default: ig)
#   DATA    data path (default: data/processed/train.parquet)
#   MODEL   model checkpoint path (default: models/model.pt)

# Each flag expands to empty when its variable is unset, so the CLI defaults apply.
TASK_FLAG = $(if $(TASK),--task $(TASK),)
METHOD_FLAG = $(if $(METHOD),--method $(METHOD),)
DATA_FLAG = $(if $(DATA),--data-path $(DATA),)
MODEL_FLAG = $(if $(MODEL),--model-path $(MODEL),)

## Compute token-level feature importance
.PHONY: feature_importance
feature_importance: requirements
	$(PYTHON_INTERPRETER) -m lnp_ml.interpretability.token_importance \
		$(TASK_FLAG) $(METHOD_FLAG) $(DATA_FLAG) $(MODEL_FLAG) $(DEVICE_FLAG)
|
||||
|
||||
#################################################################################
|
||||
# SERVING & DEPLOYMENT #
|
||||
#################################################################################
|
||||
|
||||
0
lnp_ml/interpretability/__init__.py
Normal file
0
lnp_ml/interpretability/__init__.py
Normal file
438
lnp_ml/interpretability/token_importance.py
Normal file
438
lnp_ml/interpretability/token_importance.py
Normal file
@ -0,0 +1,438 @@
|
||||
"""
|
||||
Token-level feature importance for the LNP multi-task model.
|
||||
|
||||
Three complementary methods (inspired by AGILE / Captum):
|
||||
1. Integrated Gradients — gradient-based attribution via Captum
|
||||
2. Token Ablation — zero-out each token and measure prediction change
|
||||
3. Fusion Attention Weights — extract learned attention pooling weights
|
||||
|
||||
Usage:
|
||||
python -m lnp_ml.interpretability.token_importance \
|
||||
--model-path models/model.pt \
|
||||
--data-path data/processed/train.parquet \
|
||||
--task delivery \
|
||||
--method all
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from loguru import logger
|
||||
from tqdm import tqdm
|
||||
|
||||
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR, REPORTS_DIR
|
||||
from lnp_ml.dataset import LNPDataset, collate_fn
|
||||
from lnp_ml.modeling.predict import load_model
|
||||
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
|
||||
|
||||
|
||||
TASKS = ["delivery", "size", "pdi", "ee", "biodist", "toxic"]
|
||||
|
||||
|
||||
def get_token_names(model: Union[LNPModel, LNPModelWithoutMPNN]) -> List[str]:
    """Return the ordered token names for *model*.

    The "mpnn" token is included only when the model has an MPNN encoder
    (``model.use_mpnn`` is truthy); the remaining seven tokens are always
    present, in fixed order.
    """
    base_tokens = ["morgan", "maccs", "desc", "comp", "phys", "help", "exp"]
    if model.use_mpnn:
        return ["mpnn"] + base_tokens
    return base_tokens
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Helper: pre-compute projected tokens for the whole dataset
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
@torch.no_grad()
def precompute_projected_tokens(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
) -> torch.Tensor:
    """
    Encode and project every sample in *loader* once and stack the results.

    All importance methods operate on the same projected-token tensor, so
    the expensive encoder pass is done exactly one time here.

    Returns:
        all_tokens: [N, n_tokens, d_model] tensor, kept on CPU.
    """
    model.eval()
    encoded_chunks = []
    for batch in tqdm(loader, desc="Encoding tokens"):
        tabular_on_device = {key: tensor.to(device) for key, tensor in batch["tabular"].items()}
        projected = model._encode_and_project(batch["smiles"], tabular_on_device)
        encoded_chunks.append(projected.cpu())
    return torch.cat(encoded_chunks, dim=0)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Method 1: Integrated Gradients (Captum)
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
class _ProjectedWrapper(nn.Module):
    """
    Thin nn.Module adapter around ``model.forward_from_projected`` so that
    Captum can attribute with respect to the stacked token tensor.
    """

    def __init__(self, model: nn.Module, task: str) -> None:
        super().__init__()
        self.model = model
        self.task = task

    def forward(self, stacked: torch.Tensor) -> torch.Tensor:
        """
        Args:
            stacked: [B, n_tokens, d_model]

        Returns:
            [B] scalar predictions; heads with more than one output
            dimension are reduced by summing over the last axis.
        """
        preds = self.model.forward_from_projected(stacked, task=self.task)
        if preds.dim() > 1 and preds.size(-1) > 1:
            # Multi-output head: collapse to one scalar per sample.
            return preds.sum(dim=-1)
        return preds.squeeze(-1)
|
||||
|
||||
|
||||
def integrated_gradients_importance(
    model: nn.Module,
    all_tokens: torch.Tensor,
    device: torch.device,
    task: str = "delivery",
    batch_size: int = 64,
    n_steps: int = 50,
) -> np.ndarray:
    """
    Token-level attribution via Integrated Gradients (Captum).

    Attributions are computed against an all-zero baseline, reduced to a
    per-token score by an L2 norm over the embedding dimension, then
    averaged over all samples.

    Returns:
        importance: [n_tokens] mean attribution magnitude per token position.
    """
    from captum.attr import IntegratedGradients

    wrapper = _ProjectedWrapper(model, task).to(device)
    wrapper.eval()
    explainer = IntegratedGradients(wrapper)

    total = all_tokens.size(0)
    per_batch: List[torch.Tensor] = []

    for lo in tqdm(range(0, total, batch_size), desc=f"IG ({task})"):
        hi = min(lo + batch_size, total)
        batch_inp = all_tokens[lo:hi].to(device).requires_grad_(True)

        attributions = explainer.attribute(
            batch_inp,
            baselines=torch.zeros_like(batch_inp),
            n_steps=n_steps,
        )
        # [B, n_tokens, d_model] → L2 over embedding dim → [B, n_tokens]
        per_batch.append(attributions.detach().cpu().norm(dim=-1))

    # [N, n_tokens] → mean over samples → [n_tokens]
    return torch.cat(per_batch, dim=0).mean(dim=0).numpy()
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Method 2: Token Ablation (zero-out)
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
@torch.no_grad()
def token_ablation_importance(
    model: nn.Module,
    all_tokens: torch.Tensor,
    device: torch.device,
    task: str = "delivery",
    batch_size: int = 64,
) -> np.ndarray:
    """
    Ablation-based token importance.

    For each token position, zero out that token across all samples and
    measure the mean absolute change in the model's prediction relative
    to the unablated baseline.

    Args:
        model: model exposing ``forward_from_projected``.
        all_tokens: [N, n_tokens, d_model] pre-projected tokens.
        device: device to run inference on.
        task: task head to evaluate.
        batch_size: inference batch size.

    Returns:
        importance: [n_tokens] mean |Δ prediction| per ablated token.
    """
    model.eval()
    n_tokens = all_tokens.size(1)

    # Baseline predictions on the unmodified tokens.
    orig_preds = _batch_predict(model, all_tokens, device, task, batch_size)

    importance = np.zeros(n_tokens)
    for t in range(n_tokens):
        # Zero out token t for every sample; everything else is untouched.
        ablated = all_tokens.clone()
        ablated[:, t, :] = 0.0
        abl_preds = _batch_predict(model, ablated, device, task, batch_size)
        importance[t] = np.abs(orig_preds - abl_preds).mean()

    return importance
|
||||
|
||||
|
||||
def _batch_predict(
    model: nn.Module,
    all_tokens: torch.Tensor,
    device: torch.device,
    task: str,
    batch_size: int,
) -> np.ndarray:
    """
    Run ``forward_from_projected`` over *all_tokens* in mini-batches.

    Heads with more than one output dimension are reduced by summing over
    the last axis; scalar heads are squeezed.

    Returns:
        [N] numpy array of per-sample predictions.
    """
    model.eval()
    total = all_tokens.size(0)
    outputs = []
    for lo in range(0, total, batch_size):
        chunk = all_tokens[lo:lo + batch_size].to(device)
        preds = model.forward_from_projected(chunk, task=task)
        if preds.dim() > 1 and preds.size(-1) > 1:
            preds = preds.sum(dim=-1)
        else:
            preds = preds.squeeze(-1)
        outputs.append(preds.cpu().numpy())
    return np.concatenate(outputs, axis=0)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Method 3: Fusion Attention Weights
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
@torch.no_grad()
def fusion_attention_importance(
    model: nn.Module,
    all_tokens: torch.Tensor,
    device: torch.device,
    batch_size: int = 64,
) -> Optional[np.ndarray]:
    """
    Average the FusionLayer's learned attention-pooling weights over the data.

    Only meaningful when the fusion strategy is "attention"; for any other
    strategy a warning is logged and None is returned.

    Returns:
        importance: [n_tokens] mean attention weight, or None if not applicable.
    """
    model.eval()
    if model.fusion.strategy != "attention":
        logger.warning("Fusion strategy is not 'attention'; skipping attention weight extraction.")
        return None

    total = all_tokens.size(0)
    collected: List[torch.Tensor] = []

    for lo in range(0, total, batch_size):
        chunk = all_tokens[lo:lo + batch_size].to(device)
        # Same path as a regular forward: cross-attention, then fusion.
        _, batch_weights = model.fusion(model.cross_attention(chunk), return_attn_weights=True)
        collected.append(batch_weights.cpu())

    # [N, n_tokens] → mean over samples
    return torch.cat(collected, dim=0).mean(dim=0).numpy()
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Bonus: TokenProjector gate values (static)
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def gate_values(model: nn.Module) -> Dict[str, float]:
    """Return sigmoid(weight) of each TokenProjector gate, keyed by token name."""
    return {
        key: float(torch.sigmoid(model.token_projector.weights[key].detach().cpu()).item())
        for key in model.token_projector.keys
    }
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Visualization
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def normalize(arr: np.ndarray) -> np.ndarray:
    """Scale *arr* to sum to 1; return it unchanged when the sum is not positive."""
    total = arr.sum()
    if total > 0:
        return arr / total
    return arr
|
||||
|
||||
|
||||
def plot_token_importance(
    results: Dict[str, np.ndarray],
    token_names: List[str],
    task: str,
    out_dir: Path,
) -> Path:
    """
    Draw one horizontal bar chart per importance method and save the figure.

    Bars are sorted ascending by normalized importance. The first
    ``split_idx`` tokens get one color and the rest another; the split
    follows the ordering produced by ``get_token_names`` (4 when the
    mpnn token is present, otherwise 3).

    Returns:
        Path to the saved PNG file.
    """
    import matplotlib

    matplotlib.use("Agg")  # headless backend — no display required
    import matplotlib.pyplot as plt

    n_methods = len(results)
    fig, axes = plt.subplots(
        1, n_methods, figsize=(6 * n_methods, max(4, len(token_names) * 0.6))
    )
    if n_methods == 1:
        axes = [axes]  # plt.subplots returns a bare Axes for a single panel

    channel_a_color = "#5a448e"
    channel_b_color = "#e07b39"

    # Invariant across panels: first split_idx names form color channel A.
    n_tokens = len(token_names)
    split_idx = 4 if n_tokens == 8 else 3
    channel_a_set = set(token_names[:split_idx])

    for ax, (method_name, importance) in zip(axes, results.items()):
        normed = normalize(importance)
        order = np.argsort(normed)
        names_sorted = [token_names[i] for i in order]
        vals_sorted = normed[order]
        colors = [
            channel_a_color if name in channel_a_set else channel_b_color
            for name in names_sorted
        ]

        ax.barh(range(len(names_sorted)), vals_sorted, color=colors)
        ax.set_yticks(range(len(names_sorted)))
        ax.set_yticklabels(names_sorted, fontsize=11)
        ax.set_xlabel("Normalized Importance", fontsize=11)
        ax.set_title(f"{method_name}", fontsize=13)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    fig.suptitle(f"Token Importance — task: {task}", fontsize=14, y=1.02)
    fig.tight_layout()

    out_dir.mkdir(parents=True, exist_ok=True)
    fig_path = out_dir / f"token_importance_{task}.png"
    fig.savefig(fig_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    logger.info(f"Figure saved to {fig_path}")
    return fig_path
|
||||
|
||||
|
||||
def save_csv(
    results: Dict[str, np.ndarray],
    token_names: List[str],
    task: str,
    out_dir: Path,
    gate_vals: Optional[Dict[str, float]] = None,
) -> Path:
    """
    Write raw and normalized importances (plus optional gate values) to CSV.

    Rows are sorted descending by the normalized column of the last method
    inserted into *results*.

    Returns:
        Path to the written CSV file.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    df = pd.DataFrame({"token": token_names})
    for method_name, importance in results.items():
        df[f"{method_name}_raw"] = importance
        df[f"{method_name}_normalized"] = normalize(importance)

    if gate_vals is not None:
        # NaN for any token the gate dict does not cover.
        df["gate_sigmoid"] = [gate_vals.get(t, float("nan")) for t in token_names]

    norm_cols = [c for c in df.columns if c.endswith("_normalized")]
    df = df.sort_values(by=norm_cols[-1], ascending=False)

    csv_path = out_dir / f"token_importance_{task}.csv"
    df.to_csv(csv_path, index=False)
    logger.info(f"CSV saved to {csv_path}")
    return csv_path
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Main
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main() -> None:
    """CLI entry point: load model and data, compute token importances, save plot + CSV."""
    parser = argparse.ArgumentParser(description="Token-level Feature Importance")
    parser.add_argument("--model-path", type=str, default=str(MODELS_DIR / "model.pt"),
                        help="Path to trained model checkpoint")
    parser.add_argument("--data-path", type=str, default=str(PROCESSED_DATA_DIR / "train.parquet"),
                        help="Path to data (parquet) for computing importance")
    parser.add_argument("--task", type=str, default="delivery", choices=TASKS,
                        help="Target task for importance computation")
    # NOTE(review): the Makefile comment advertises "all" as the METHOD default,
    # but the CLI default here is "ig" — confirm which is intended.
    parser.add_argument("--method", type=str, default="ig",
                        choices=["ig", "ablation", "attention", "all"],
                        help="Which method(s) to run")
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--n-steps", type=int, default=50,
                        help="Number of interpolation steps for Integrated Gradients")
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--output-dir", type=str,
                        default=str(REPORTS_DIR / "feature_importance"))
    args = parser.parse_args()

    device = torch.device(args.device)
    out_dir = Path(args.output_dir)
    logger.info(f"Device: {device}")
    logger.info(f"Model: {args.model_path}")
    logger.info(f"Data: {args.data_path}")
    logger.info(f"Task: {args.task}")

    # ── Load model ──
    model = load_model(Path(args.model_path), device)
    token_names = get_token_names(model)
    logger.info(f"Tokens ({len(token_names)}): {token_names}")

    # ── Load data ──
    df = pd.read_parquet(args.data_path)
    dataset = LNPDataset(df)
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
    logger.info(f"Samples: {len(dataset)}")

    # ── Pre-compute projected tokens ──
    # All three methods operate on the same [N, n_tokens, d_model] tensor,
    # so the expensive encoder pass is done exactly once here.
    all_tokens = precompute_projected_tokens(model, loader, device)
    logger.info(f"Projected tokens shape: {all_tokens.shape}")

    # ── Gate values (always) ──
    gv = gate_values(model)
    logger.info(f"TokenProjector gate values: {gv}")

    # ── Run selected methods ──
    methods = (
        ["ig", "ablation", "attention"] if args.method == "all"
        else [args.method]
    )

    results: Dict[str, np.ndarray] = {}

    for method in methods:
        logger.info(f"\n{'='*60}")
        logger.info(f"Computing: {method}")
        logger.info(f"{'='*60}")

        if method == "ig":
            imp = integrated_gradients_importance(
                model, all_tokens, device,
                task=args.task, batch_size=args.batch_size, n_steps=args.n_steps,
            )
            results["Integrated Gradients"] = imp

        elif method == "ablation":
            imp = token_ablation_importance(
                model, all_tokens, device,
                task=args.task, batch_size=args.batch_size,
            )
            results["Token Ablation"] = imp

        elif method == "attention":
            # Returns None when the fusion strategy is not "attention".
            imp = fusion_attention_importance(
                model, all_tokens, device,
                batch_size=args.batch_size,
            )
            if imp is not None:
                results["Fusion Attention"] = imp

    # ── Print summary ──
    logger.info(f"\n{'='*60}")
    logger.info(f"Token Importance Summary (task={args.task})")
    logger.info(f"{'='*60}")
    for method_name, importance in results.items():
        normed = normalize(importance)
        order = np.argsort(-normed)
        logger.info(f"\n  {method_name}:")
        for rank, idx in enumerate(order, 1):
            logger.info(f"    {rank:>2d}. {token_names[idx]:<10s} {normed[idx]:.4f}")

    logger.info(f"\n  Gate values (sigmoid):")
    for name, val in sorted(gv.items(), key=lambda x: -x[1]):
        logger.info(f"    {name:<10s} {val:.4f}")

    # ── Save results ──
    if results:
        plot_token_importance(results, token_names, args.task, out_dir)
        save_csv(results, token_names, args.task, out_dir, gate_vals=gv)

    logger.info("\nDone!")
|
||||
|
||||
|
||||
# Script entry point: python -m lnp_ml.interpretability.token_importance
if __name__ == "__main__":
    main()
|
||||
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, Literal, Union
|
||||
from typing import Dict, Literal, Tuple, Union
|
||||
|
||||
|
||||
PoolingStrategy = Literal["concat", "avg", "max", "attention"]
|
||||
@ -48,52 +48,64 @@ class FusionLayer(nn.Module):
|
||||
self.attn_query = nn.Parameter(torch.randn(1, 1, d_model))
|
||||
self.attn_proj = nn.Linear(d_model, d_model)
|
||||
|
||||
def forward(
    self,
    x: Union[Dict[str, torch.Tensor], torch.Tensor],
    return_attn_weights: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Pool the per-source tokens into a single fused vector.

    Args:
        x: dict of per-source tensors, each [B, d_model], or an already
            stacked tensor [B, n_tokens, d_model].
        return_attn_weights: when True, additionally return the attention
            weights [B, n_tokens] for strategy "attention"; for the other
            strategies the second element is None.

    Returns:
        return_attn_weights=False: [B, fusion_dim]
        return_attn_weights=True:  ([B, fusion_dim], [B, n_tokens] or None)
    """
    # Stack dict inputs into [B, n_tokens, d_model] first.
    if isinstance(x, dict):
        x = torch.stack(list(x.values()), dim=1)

    if self.strategy == "concat":
        out = x.flatten(start_dim=1)  # [B, n_tokens * d_model]
        return (out, None) if return_attn_weights else out

    elif self.strategy == "avg":
        out = x.mean(dim=1)  # [B, d_model]
        return (out, None) if return_attn_weights else out

    elif self.strategy == "max":
        out = x.max(dim=1).values  # [B, d_model]
        return (out, None) if return_attn_weights else out

    elif self.strategy == "attention":
        return self._attention_pooling(x, return_attn_weights)

    else:
        raise ValueError(f"Unknown strategy: {self.strategy}")
|
||||
|
||||
def _attention_pooling(
    self, x: torch.Tensor, return_attn_weights: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Attention pooling: weighted sum over tokens using a learnable query.

    Args:
        x: [B, n_tokens, d_model]
        return_attn_weights: whether to also return the softmax weights.

    Returns:
        return_attn_weights=False: [B, d_model]
        return_attn_weights=True:  ([B, d_model], [B, n_tokens])
    """
    B = x.size(0)
    # query: [1, 1, d_model] -> [B, 1, d_model]
    query = self.attn_query.expand(B, -1, -1)

    # Scaled dot-product scores: [B, 1, n_tokens]
    keys = self.attn_proj(x)  # [B, n_tokens, d_model]
    scores = torch.bmm(query, keys.transpose(1, 2)) / (self.d_model ** 0.5)
    attn_weights = F.softmax(scores, dim=-1)  # [B, 1, n_tokens]

    # Weighted sum: [B, 1, d_model] -> [B, d_model]
    out = torch.bmm(attn_weights, x).squeeze(1)

    if return_attn_weights:
        return out, attn_weights.squeeze(1)  # weights: [B, n_tokens]
    return out
|
||||
|
||||
@ -169,6 +169,38 @@ class LNPModel(nn.Module):
|
||||
stacked = torch.stack([projected[k] for k in token_order], dim=1)
|
||||
return stacked
|
||||
|
||||
def forward_from_projected(
    self,
    stacked: torch.Tensor,
    task: Optional[str] = None,
) -> torch.Tensor:
    """
    Forward pass starting from already-projected, stacked tokens.

    Used for Captum attribution, where gradients are taken with respect
    to the stacked token tensor rather than the raw inputs.

    Args:
        stacked: [B, n_tokens, d_model] tensor stacked from the
            TokenProjector outputs.
        task: single task name ("size", "pdi", "ee", "delivery",
            "biodist", "toxic"). When None, the delivery head is used.

    Returns:
        [B, 1] or [B, num_classes] prediction from the selected task head.
    """
    attended = self.cross_attention(stacked)
    fused = self.fusion(attended)

    if task is None:
        task = "delivery"

    # Map task name to its prediction head; unknown names raise KeyError.
    task_heads = {
        "size": self.head.size_head,
        "pdi": self.head.pdi_head,
        "ee": self.head.ee_head,
        "delivery": self.head.delivery_head,
        "biodist": self.head.biodist_head,
        "toxic": self.head.toxic_head,
    }
    return task_heads[task](fused)
|
||||
|
||||
def forward_backbone(
|
||||
self,
|
||||
smiles: List[str],
|
||||
|
||||
45
pixi.lock
45
pixi.lock
@ -62,6 +62,7 @@ environments:
|
||||
- pypi: https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/bb/2a/10164ed1f31196a2f7f3799368a821765c62851ead0e630ab52b8e14b4d0/blinker-1.8.2-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/72/76/20fa66124dbe6be5cafeb312ece67de6b61dd91a0247d1ea13db4ebb33c2/cachetools-5.5.2-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/e1/76/b21bfd2c35cab2e9a4b68b1977f7488c246c8cffa31e3361ee7610e8b5af/captum-0.7.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/e6/ad/3cc14f097111b4de0040c83a525973216457bbeeb63739ef1ed275c1c021/certifi-2026.1.4-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/bf/fa/cf5bb2409a385f78750e78c8d2e24780964976acdaaed65dbd6083ae5b40/charset_normalizer-3.4.4-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/a0/ad/72361203906e2dbe9baa776c64e9246d555b516808dd0cce385f07f4cf71/chemprop-1.7.0-py3-none-any.whl
|
||||
@ -222,6 +223,7 @@ environments:
|
||||
- pypi: https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/bb/2a/10164ed1f31196a2f7f3799368a821765c62851ead0e630ab52b8e14b4d0/blinker-1.8.2-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/72/76/20fa66124dbe6be5cafeb312ece67de6b61dd91a0247d1ea13db4ebb33c2/cachetools-5.5.2-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/e1/76/b21bfd2c35cab2e9a4b68b1977f7488c246c8cffa31e3361ee7610e8b5af/captum-0.7.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/e6/ad/3cc14f097111b4de0040c83a525973216457bbeeb63739ef1ed275c1c021/certifi-2026.1.4-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/0a/4e/3926a1c11f0433791985727965263f788af00db3482d89a7545ca5ecc921/charset_normalizer-3.4.4-cp38-cp38-macosx_10_9_universal2.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/a0/ad/72361203906e2dbe9baa776c64e9246d555b516808dd0cce385f07f4cf71/chemprop-1.7.0-py3-none-any.whl
|
||||
@ -545,6 +547,49 @@ packages:
|
||||
version: 5.5.2
|
||||
sha256: d26a22bcc62eb95c3beabd9f1ee5e820d3d2704fe2967cbe350e20c8ffcd3f0a
|
||||
requires_python: '>=3.7'
|
||||
- pypi: https://files.pythonhosted.org/packages/e1/76/b21bfd2c35cab2e9a4b68b1977f7488c246c8cffa31e3361ee7610e8b5af/captum-0.7.0-py3-none-any.whl
|
||||
name: captum
|
||||
version: 0.7.0
|
||||
sha256: 2cbec9aa4b6ec325c2fdf369c1fdabb011017122e2314e2af009496d53a0757c
|
||||
requires_dist:
|
||||
- matplotlib
|
||||
- numpy
|
||||
- torch>=1.6
|
||||
- tqdm
|
||||
- flask ; extra == 'dev'
|
||||
- ipython ; extra == 'dev'
|
||||
- ipywidgets ; extra == 'dev'
|
||||
- jupyter ; extra == 'dev'
|
||||
- flask-compress ; extra == 'dev'
|
||||
- pytest ; extra == 'dev'
|
||||
- pytest-cov ; extra == 'dev'
|
||||
- parameterized ; extra == 'dev'
|
||||
- black==22.3.0 ; extra == 'dev'
|
||||
- flake8 ; extra == 'dev'
|
||||
- sphinx ; extra == 'dev'
|
||||
- sphinx-autodoc-typehints ; extra == 'dev'
|
||||
- sphinxcontrib-katex ; extra == 'dev'
|
||||
- mypy>=0.760 ; extra == 'dev'
|
||||
- usort==1.0.2 ; extra == 'dev'
|
||||
- ufmt ; extra == 'dev'
|
||||
- scikit-learn ; extra == 'dev'
|
||||
- annoy ; extra == 'dev'
|
||||
- flask ; extra == 'insights'
|
||||
- ipython ; extra == 'insights'
|
||||
- ipywidgets ; extra == 'insights'
|
||||
- jupyter ; extra == 'insights'
|
||||
- flask-compress ; extra == 'insights'
|
||||
- pytest ; extra == 'test'
|
||||
- pytest-cov ; extra == 'test'
|
||||
- parameterized ; extra == 'test'
|
||||
- flask ; extra == 'tutorials'
|
||||
- ipython ; extra == 'tutorials'
|
||||
- ipywidgets ; extra == 'tutorials'
|
||||
- jupyter ; extra == 'tutorials'
|
||||
- flask-compress ; extra == 'tutorials'
|
||||
- torchtext ; extra == 'tutorials'
|
||||
- torchvision ; extra == 'tutorials'
|
||||
requires_python: '>=3.6'
|
||||
- pypi: https://files.pythonhosted.org/packages/e6/ad/3cc14f097111b4de0040c83a525973216457bbeeb63739ef1ed275c1c021/certifi-2026.1.4-py3-none-any.whl
|
||||
name: certifi
|
||||
version: 2026.1.4
|
||||
|
||||
@ -29,3 +29,4 @@ streamlit = ">=1.40.1, <2"
|
||||
httpx = ">=0.28.1, <0.29"
|
||||
uvicorn = ">=0.33.0, <0.34"
|
||||
optuna = ">=4.5.0, <5"
|
||||
captum = ">=0.7.0, <0.8"
|
||||
|
||||
@ -19,3 +19,4 @@ streamlit>=1.40.1,<2
|
||||
httpx>=0.28.1,<0.29
|
||||
uvicorn>=0.33.0,<0.34
|
||||
optuna>=4.5.0,<5
|
||||
captum>=0.7.0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user