From 25938df4ce4038b367623a5b4b9cf2141aae0f96 Mon Sep 17 00:00:00 2001 From: RYDE-WORK Date: Tue, 3 Mar 2026 10:41:16 +0800 Subject: [PATCH] Add feature importance --- Makefile | 20 + lnp_ml/interpretability/__init__.py | 0 lnp_ml/interpretability/token_importance.py | 438 ++++++++++++++++++++ lnp_ml/modeling/layers/fusion.py | 44 +- lnp_ml/modeling/models.py | 32 ++ pixi.lock | 45 ++ pixi.toml | 1 + requirements.txt | 1 + 8 files changed, 565 insertions(+), 16 deletions(-) create mode 100644 lnp_ml/interpretability/__init__.py create mode 100644 lnp_ml/interpretability/token_importance.py diff --git a/Makefile b/Makefile index 373e155..edd9f9e 100644 --- a/Makefile +++ b/Makefile @@ -122,6 +122,26 @@ train: requirements $(DEVICE_FLAG) $(MPNN_FLAG) $(SEED_FLAG) $(INIT_PRETRAIN_FLAG) \ $(N_TRIALS_FLAG) $(EPOCHS_PER_TRIAL_FLAG) $(MIN_STRATUM_FLAG) $(OUTPUT_DIR_FLAG) $(USE_SWA_FLAG) +################################################################################# +# INTERPRETABILITY # +################################################################################# +# 参数: +# TASK 目标任务 (delivery, size, pdi, ee, biodist, toxic; 默认: delivery) +# METHOD 方法 (ig, ablation, attention, all; 默认: all) +# DATA 数据路径 (默认: data/processed/train.parquet) +# MODEL 模型路径 (默认: models/model.pt) + +TASK_FLAG = $(if $(TASK),--task $(TASK),) +METHOD_FLAG = $(if $(METHOD),--method $(METHOD),) +DATA_FLAG = $(if $(DATA),--data-path $(DATA),) +MODEL_FLAG = $(if $(MODEL),--model-path $(MODEL),) + +## Compute token-level feature importance +.PHONY: feature_importance +feature_importance: requirements + $(PYTHON_INTERPRETER) -m lnp_ml.interpretability.token_importance \ + $(TASK_FLAG) $(METHOD_FLAG) $(DATA_FLAG) $(MODEL_FLAG) $(DEVICE_FLAG) + ################################################################################# # SERVING & DEPLOYMENT # ################################################################################# diff --git a/lnp_ml/interpretability/__init__.py b/lnp_ml/interpretability/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lnp_ml/interpretability/token_importance.py b/lnp_ml/interpretability/token_importance.py new file mode 100644 index 0000000..6a71963 --- /dev/null +++ b/lnp_ml/interpretability/token_importance.py @@ -0,0 +1,438 @@ +""" +Token-level feature importance for the LNP multi-task model. + +Three complementary methods (inspired by AGILE / Captum): + 1. Integrated Gradients — gradient-based attribution via Captum + 2. Token Ablation — zero-out each token and measure prediction change + 3. Fusion Attention Weights — extract learned attention pooling weights + +Usage: + python -m lnp_ml.interpretability.token_importance \ + --model-path models/model.pt \ + --data-path data/processed/train.parquet \ + --task delivery \ + --method all +""" + +from __future__ import annotations + +import argparse +import os +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from loguru import logger +from tqdm import tqdm + +from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR, REPORTS_DIR +from lnp_ml.dataset import LNPDataset, collate_fn +from lnp_ml.modeling.predict import load_model +from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN + + +TASKS = ["delivery", "size", "pdi", "ee", "biodist", "toxic"] + + +def get_token_names(model: Union[LNPModel, LNPModelWithoutMPNN]) -> List[str]: + if model.use_mpnn: + return ["mpnn", "morgan", "maccs", "desc", "comp", "phys", "help", "exp"] + return ["morgan", "maccs", "desc", "comp", "phys", "help", "exp"] + + +# ────────────────────────────────────────────────────────────────────── +# Helper: pre-compute projected tokens for the whole dataset +# ────────────────────────────────────────────────────────────────────── + +@torch.no_grad() +def precompute_projected_tokens( + model: nn.Module, + loader: DataLoader, + device: torch.device, +) -> torch.Tensor: + """ + Run encoder + TokenProjector on all samples and stack results. + + Returns: + all_tokens: [N, n_tokens, d_model] + """ + model.eval() + chunks = [] + for batch in tqdm(loader, desc="Encoding tokens"): + smiles = batch["smiles"] + tabular = {k: v.to(device) for k, v in batch["tabular"].items()} + stacked = model._encode_and_project(smiles, tabular) + chunks.append(stacked.cpu()) + return torch.cat(chunks, dim=0) + + +# ────────────────────────────────────────────────────────────────────── +# Method 1: Integrated Gradients (Captum) +# ────────────────────────────────────────────────────────────────────── + +class _ProjectedWrapper(nn.Module): + """Wraps model.forward_from_projected so Captum can attribute w.r.t. stacked tokens.""" + + def __init__(self, model: nn.Module, task: str) -> None: + super().__init__() + self.model = model + self.task = task + + def forward(self, stacked: torch.Tensor) -> torch.Tensor: + """ + Args: + stacked: [B, n_tokens, d_model] + Returns: + [B] scalar predictions (sum over output dims for non-scalar heads) + """ + out = self.model.forward_from_projected(stacked, task=self.task) + if out.dim() > 1 and out.size(-1) > 1: + return out.sum(dim=-1) + return out.squeeze(-1) + + +def integrated_gradients_importance( + model: nn.Module, + all_tokens: torch.Tensor, + device: torch.device, + task: str = "delivery", + batch_size: int = 64, + n_steps: int = 50, +) -> np.ndarray: + """ + Compute token-level Integrated Gradients. + + Returns: + importance: [n_tokens] averaged absolute attribution per token position + """ + from captum.attr import IntegratedGradients + + wrapper = _ProjectedWrapper(model, task).to(device) + wrapper.eval() + ig = IntegratedGradients(wrapper) + + N = all_tokens.size(0) + all_attrs: List[torch.Tensor] = [] + + for start in tqdm(range(0, N, batch_size), desc=f"IG ({task})"): + end = min(start + batch_size, N) + inp = all_tokens[start:end].to(device).requires_grad_(True) + baseline = torch.zeros_like(inp) + + attr = ig.attribute(inp, baselines=baseline, n_steps=n_steps) + # attr: [B, n_tokens, d_model] → per-token L2 norm → [B, n_tokens] + token_attr = attr.detach().cpu().norm(dim=-1) + all_attrs.append(token_attr) + + attrs = torch.cat(all_attrs, dim=0) # [N, n_tokens] + importance = attrs.mean(dim=0).numpy() # [n_tokens] + return importance + + +# ────────────────────────────────────────────────────────────────────── +# Method 2: Token Ablation (zero-out) +# ────────────────────────────────────────────────────────────────────── + +@torch.no_grad() +def token_ablation_importance( + model: nn.Module, + all_tokens: torch.Tensor, + device: torch.device, + task: str = "delivery", + batch_size: int = 64, +) -> np.ndarray: + """ + For each token position, replace it with zeros and measure + the average absolute prediction change. + + Returns: + importance: [n_tokens] + """ + model.eval() + n_tokens = all_tokens.size(1) + N = all_tokens.size(0) + + # original predictions + orig_preds = _batch_predict(model, all_tokens, device, task, batch_size) + + importance = np.zeros(n_tokens) + for t in range(n_tokens): + ablated = all_tokens.clone() + ablated[:, t, :] = 0.0 + abl_preds = _batch_predict(model, ablated, device, task, batch_size) + importance[t] = np.abs(orig_preds - abl_preds).mean() + + return importance + + +def _batch_predict( + model: nn.Module, + all_tokens: torch.Tensor, + device: torch.device, + task: str, + batch_size: int, +) -> np.ndarray: + """Run forward_from_projected in batches, return [N] predictions.""" + model.eval() + preds = [] + N = all_tokens.size(0) + for start in range(0, N, batch_size): + end = min(start + batch_size, N) + inp = all_tokens[start:end].to(device) + out = model.forward_from_projected(inp, task=task) + if out.dim() > 1 and out.size(-1) > 1: + out = out.sum(dim=-1) + else: + out = out.squeeze(-1) + preds.append(out.cpu().numpy()) + return np.concatenate(preds, axis=0) + + +# ────────────────────────────────────────────────────────────────────── +# Method 3: Fusion Attention Weights +# ────────────────────────────────────────────────────────────────────── + +@torch.no_grad() +def fusion_attention_importance( + model: nn.Module, + all_tokens: torch.Tensor, + device: torch.device, + batch_size: int = 64, +) -> Optional[np.ndarray]: + """ + Extract FusionLayer attention weights (only works if strategy="attention"). + + Returns: + importance: [n_tokens] averaged attention weights, or None if not applicable. + """ + model.eval() + if model.fusion.strategy != "attention": + logger.warning("Fusion strategy is not 'attention'; skipping attention weight extraction.") + return None + + N = all_tokens.size(0) + all_weights: List[torch.Tensor] = [] + + for start in range(0, N, batch_size): + end = min(start + batch_size, N) + inp = all_tokens[start:end].to(device) + attended = model.cross_attention(inp) + _, weights = model.fusion(attended, return_attn_weights=True) + all_weights.append(weights.cpu()) + + weights = torch.cat(all_weights, dim=0) # [N, n_tokens] + return weights.mean(dim=0).numpy() + + +# ────────────────────────────────────────────────────────────────────── +# Bonus: TokenProjector gate values (static) +# ────────────────────────────────────────────────────────────────────── + +def gate_values(model: nn.Module) -> Dict[str, float]: + """Read sigmoid(weight) from TokenProjector for each token.""" + gates = {} + for key in model.token_projector.keys: + w = model.token_projector.weights[key].detach().cpu() + gates[key] = float(torch.sigmoid(w).item()) + return gates + + +# ────────────────────────────────────────────────────────────────────── +# Visualization +# ────────────────────────────────────────────────────────────────────── + +def normalize(arr: np.ndarray) -> np.ndarray: + s = arr.sum() + return arr / s if s > 0 else arr + + +def plot_token_importance( + results: Dict[str, np.ndarray], + token_names: List[str], + task: str, + out_dir: Path, +) -> Path: + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + n_methods = len(results) + fig, axes = plt.subplots(1, n_methods, figsize=(6 * n_methods, max(4, len(token_names) * 0.6))) + if n_methods == 1: + axes = [axes] + + channel_a_color = "#5a448e" + channel_b_color = "#e07b39" + + for ax, (method_name, importance) in zip(axes, results.items()): + normed = normalize(importance) + order = np.argsort(normed) + + names_sorted = [token_names[i] for i in order] + vals_sorted = normed[order] + + n_tokens = len(token_names) + split_idx = 4 if n_tokens == 8 else 3 + channel_a_set = set(token_names[:split_idx]) + colors = [channel_a_color if n in channel_a_set else channel_b_color for n in names_sorted] + + ax.barh(range(len(names_sorted)), vals_sorted, color=colors) + ax.set_yticks(range(len(names_sorted))) + ax.set_yticklabels(names_sorted, fontsize=11) + ax.set_xlabel("Normalized Importance", fontsize=11) + ax.set_title(f"{method_name}", fontsize=13) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + fig.suptitle(f"Token Importance — task: {task}", fontsize=14, y=1.02) + fig.tight_layout() + + out_dir.mkdir(parents=True, exist_ok=True) + fig_path = out_dir / f"token_importance_{task}.png" + fig.savefig(fig_path, dpi=200, bbox_inches="tight") + plt.close(fig) + logger.info(f"Figure saved to {fig_path}") + return fig_path + + +def save_csv( + results: Dict[str, np.ndarray], + token_names: List[str], + task: str, + out_dir: Path, + gate_vals: Optional[Dict[str, float]] = None, +) -> Path: + out_dir.mkdir(parents=True, exist_ok=True) + df = pd.DataFrame({"token": token_names}) + for method_name, importance in results.items(): + normed = normalize(importance) + df[f"{method_name}_raw"] = importance + df[f"{method_name}_normalized"] = normed + + if gate_vals is not None: + df["gate_sigmoid"] = [gate_vals.get(t, float("nan")) for t in token_names] + + df = df.sort_values( + by=[c for c in df.columns if c.endswith("_normalized")][-1], + ascending=False, + ) + csv_path = out_dir / f"token_importance_{task}.csv" + df.to_csv(csv_path, index=False) + logger.info(f"CSV saved to {csv_path}") + return csv_path + + +# ────────────────────────────────────────────────────────────────────── +# Main +# ────────────────────────────────────────────────────────────────────── + +def main() -> None: + parser = argparse.ArgumentParser(description="Token-level Feature Importance") + parser.add_argument("--model-path", type=str, default=str(MODELS_DIR / "model.pt"), + help="Path to trained model checkpoint") + parser.add_argument("--data-path", type=str, default=str(PROCESSED_DATA_DIR / "train.parquet"), + help="Path to data (parquet) for computing importance") + parser.add_argument("--task", type=str, default="delivery", choices=TASKS, + help="Target task for importance computation") + parser.add_argument("--method", type=str, default="ig", + choices=["ig", "ablation", "attention", "all"], + help="Which method(s) to run") + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--n-steps", type=int, default=50, + help="Number of interpolation steps for Integrated Gradients") + parser.add_argument("--device", type=str, + default="cuda" if torch.cuda.is_available() else "cpu") + parser.add_argument("--output-dir", type=str, + default=str(REPORTS_DIR / "feature_importance")) + args = parser.parse_args() + + device = torch.device(args.device) + out_dir = Path(args.output_dir) + logger.info(f"Device: {device}") + logger.info(f"Model: {args.model_path}") + logger.info(f"Data: {args.data_path}") + logger.info(f"Task: {args.task}") + + # ── Load model ── + model = load_model(Path(args.model_path), device) + token_names = get_token_names(model) + logger.info(f"Tokens ({len(token_names)}): {token_names}") + + # ── Load data ── + df = pd.read_parquet(args.data_path) + dataset = LNPDataset(df) + loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn) + logger.info(f"Samples: {len(dataset)}") + + # ── Pre-compute projected tokens ── + all_tokens = precompute_projected_tokens(model, loader, device) + logger.info(f"Projected tokens shape: {all_tokens.shape}") + + # ── Gate values (always) ── + gv = gate_values(model) + logger.info(f"TokenProjector gate values: {gv}") + + # ── Run selected methods ── + methods = ( + ["ig", "ablation", "attention"] if args.method == "all" + else [args.method] + ) + + results: Dict[str, np.ndarray] = {} + + for method in methods: + logger.info(f"\n{'='*60}") + logger.info(f"Computing: {method}") + logger.info(f"{'='*60}") + + if method == "ig": + imp = integrated_gradients_importance( + model, all_tokens, device, + task=args.task, batch_size=args.batch_size, n_steps=args.n_steps, + ) + results["Integrated Gradients"] = imp + + elif method == "ablation": + imp = token_ablation_importance( + model, all_tokens, device, + task=args.task, batch_size=args.batch_size, + ) + results["Token Ablation"] = imp + + elif method == "attention": + imp = fusion_attention_importance( + model, all_tokens, device, + batch_size=args.batch_size, + ) + if imp is not None: + results["Fusion Attention"] = imp + + # ── Print summary ── + logger.info(f"\n{'='*60}") + logger.info(f"Token Importance Summary (task={args.task})") + logger.info(f"{'='*60}") + for method_name, importance in results.items(): + normed = normalize(importance) + order = np.argsort(-normed) + logger.info(f"\n {method_name}:") + for rank, idx in enumerate(order, 1): + logger.info(f" {rank:>2d}. {token_names[idx]:<10s} {normed[idx]:.4f}") + + logger.info(f"\n Gate values (sigmoid):") + for name, val in sorted(gv.items(), key=lambda x: -x[1]): + logger.info(f" {name:<10s} {val:.4f}") + + # ── Save results ── + if results: + plot_token_importance(results, token_names, args.task, out_dir) + save_csv(results, token_names, args.task, out_dir, gate_vals=gv) + + logger.info("\nDone!") + + +if __name__ == "__main__": + main() diff --git a/lnp_ml/modeling/layers/fusion.py b/lnp_ml/modeling/layers/fusion.py index dff583a..82c7636 100644 --- a/lnp_ml/modeling/layers/fusion.py +++ b/lnp_ml/modeling/layers/fusion.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from typing import Dict, Literal, Union +from typing import Dict, Literal, Tuple, Union PoolingStrategy = Literal["concat", "avg", "max", "attention"] @@ -48,52 +48,64 @@ class FusionLayer(nn.Module): self.attn_query = nn.Parameter(torch.randn(1, 1, d_model)) self.attn_proj = nn.Linear(d_model, d_model) - def forward(self, x: Union[Dict[str, torch.Tensor], torch.Tensor]) -> torch.Tensor: + def forward( + self, + x: Union[Dict[str, torch.Tensor], torch.Tensor], + return_attn_weights: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Args: x: Dict[str, Tensor] 每个 [B, d_model],或已 stack 的 [B, n_tokens, d_model] + return_attn_weights: 若为 True 且策略为 attention,额外返回 attn_weights [B, n_tokens] Returns: - [B, fusion_dim] + return_attn_weights=False: [B, fusion_dim] + return_attn_weights=True: ([B, fusion_dim], [B, n_tokens]) """ - # 如果输入是 dict,先 stack if isinstance(x, dict): - x = torch.stack(list(x.values()), dim=1) # [B, n_tokens, d_model] + x = torch.stack(list(x.values()), dim=1) if self.strategy == "concat": - return x.flatten(start_dim=1) # [B, n_tokens * d_model] + out = x.flatten(start_dim=1) + return (out, None) if return_attn_weights else out elif self.strategy == "avg": - return x.mean(dim=1) # [B, d_model] + out = x.mean(dim=1) + return (out, None) if return_attn_weights else out elif self.strategy == "max": - return x.max(dim=1).values # [B, d_model] + out = x.max(dim=1).values + return (out, None) if return_attn_weights else out elif self.strategy == "attention": - return self._attention_pooling(x) + return self._attention_pooling(x, return_attn_weights) else: raise ValueError(f"Unknown strategy: {self.strategy}") - def _attention_pooling(self, x: torch.Tensor) -> torch.Tensor: + def _attention_pooling( + self, x: torch.Tensor, return_attn_weights: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Attention pooling: 用可学习 query 对 tokens 做加权求和 Args: x: [B, n_tokens, d_model] + return_attn_weights: 是否返回权重 Returns: - [B, d_model] + return_attn_weights=False: [B, d_model] + return_attn_weights=True: ([B, d_model], [B, n_tokens]) """ B = x.size(0) - # query: [1, 1, d_model] -> [B, 1, d_model] query = self.attn_query.expand(B, -1, -1) - # Attention scores: [B, 1, n_tokens] - keys = self.attn_proj(x) # [B, n_tokens, d_model] + keys = self.attn_proj(x) scores = torch.bmm(query, keys.transpose(1, 2)) / (self.d_model ** 0.5) attn_weights = F.softmax(scores, dim=-1) # [B, 1, n_tokens] - # Weighted sum: [B, 1, d_model] -> [B, d_model] - out = torch.bmm(attn_weights, x).squeeze(1) + out = torch.bmm(attn_weights, x).squeeze(1) # [B, d_model] + + if return_attn_weights: + return out, attn_weights.squeeze(1) # [B, n_tokens] return out diff --git a/lnp_ml/modeling/models.py b/lnp_ml/modeling/models.py index 8b64333..893bd62 100644 --- a/lnp_ml/modeling/models.py +++ b/lnp_ml/modeling/models.py @@ -169,6 +169,38 @@ class LNPModel(nn.Module): stacked = torch.stack([projected[k] for k in token_order], dim=1) return stacked + def forward_from_projected( + self, + stacked: torch.Tensor, + task: Optional[str] = None, + ) -> torch.Tensor: + """ + 从已投影的 stacked tokens 开始 forward,用于 Captum 归因。 + + Args: + stacked: [B, n_tokens, d_model] TokenProjector 输出后 stack 的张量 + task: 指定单任务名 ("size", "pdi", "ee", "delivery", "biodist", "toxic")。 + 若为 None,返回 delivery head 的标量输出。 + + Returns: + [B, 1] 或 [B, num_classes] 对应任务的预测输出 + """ + attended = self.cross_attention(stacked) + fused = self.fusion(attended) + + if task is None: + task = "delivery" + + task_heads = { + "size": self.head.size_head, + "pdi": self.head.pdi_head, + "ee": self.head.ee_head, + "delivery": self.head.delivery_head, + "biodist": self.head.biodist_head, + "toxic": self.head.toxic_head, + } + return task_heads[task](fused) + def forward_backbone( self, smiles: List[str], diff --git a/pixi.lock b/pixi.lock index 8280a28..07bfb98 100644 --- a/pixi.lock +++ b/pixi.lock @@ -62,6 +62,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/bb/2a/10164ed1f31196a2f7f3799368a821765c62851ead0e630ab52b8e14b4d0/blinker-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/72/76/20fa66124dbe6be5cafeb312ece67de6b61dd91a0247d1ea13db4ebb33c2/cachetools-5.5.2-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e1/76/b21bfd2c35cab2e9a4b68b1977f7488c246c8cffa31e3361ee7610e8b5af/captum-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e6/ad/3cc14f097111b4de0040c83a525973216457bbeeb63739ef1ed275c1c021/certifi-2026.1.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/bf/fa/cf5bb2409a385f78750e78c8d2e24780964976acdaaed65dbd6083ae5b40/charset_normalizer-3.4.4-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/a0/ad/72361203906e2dbe9baa776c64e9246d555b516808dd0cce385f07f4cf71/chemprop-1.7.0-py3-none-any.whl @@ -222,6 +223,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/bb/2a/10164ed1f31196a2f7f3799368a821765c62851ead0e630ab52b8e14b4d0/blinker-1.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/72/76/20fa66124dbe6be5cafeb312ece67de6b61dd91a0247d1ea13db4ebb33c2/cachetools-5.5.2-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e1/76/b21bfd2c35cab2e9a4b68b1977f7488c246c8cffa31e3361ee7610e8b5af/captum-0.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e6/ad/3cc14f097111b4de0040c83a525973216457bbeeb63739ef1ed275c1c021/certifi-2026.1.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0a/4e/3926a1c11f0433791985727965263f788af00db3482d89a7545ca5ecc921/charset_normalizer-3.4.4-cp38-cp38-macosx_10_9_universal2.whl - pypi: https://files.pythonhosted.org/packages/a0/ad/72361203906e2dbe9baa776c64e9246d555b516808dd0cce385f07f4cf71/chemprop-1.7.0-py3-none-any.whl @@ -545,6 +547,49 @@ packages: version: 5.5.2 sha256: d26a22bcc62eb95c3beabd9f1ee5e820d3d2704fe2967cbe350e20c8ffcd3f0a requires_python: '>=3.7' +- pypi: https://files.pythonhosted.org/packages/e1/76/b21bfd2c35cab2e9a4b68b1977f7488c246c8cffa31e3361ee7610e8b5af/captum-0.7.0-py3-none-any.whl + name: captum + version: 0.7.0 + sha256: 2cbec9aa4b6ec325c2fdf369c1fdabb011017122e2314e2af009496d53a0757c + requires_dist: + - matplotlib + - numpy + - torch>=1.6 + - tqdm + - flask ; extra == 'dev' + - ipython ; extra == 'dev' + - ipywidgets ; extra == 'dev' + - jupyter ; extra == 'dev' + - flask-compress ; extra == 'dev' + - pytest ; extra == 'dev' + - pytest-cov ; extra == 'dev' + - parameterized ; extra == 'dev' + - black==22.3.0 ; extra == 'dev' + - flake8 ; extra == 'dev' + - sphinx ; extra == 'dev' + - sphinx-autodoc-typehints ; extra == 'dev' + - sphinxcontrib-katex ; extra == 'dev' + - mypy>=0.760 ; extra == 'dev' + - usort==1.0.2 ; extra == 'dev' + - ufmt ; extra == 'dev' + - scikit-learn ; extra == 'dev' + - annoy ; extra == 'dev' + - flask ; extra == 'insights' + - ipython ; extra == 'insights' + - ipywidgets ; extra == 'insights' + - jupyter ; extra == 'insights' + - flask-compress ; extra == 'insights' + - pytest ; extra == 'test' + - pytest-cov ; extra == 'test' + - parameterized ; extra == 'test' + - flask ; extra == 'tutorials' + - ipython ; extra == 'tutorials' + - ipywidgets ; extra == 'tutorials' + - jupyter ; extra == 'tutorials' + - flask-compress ; extra == 'tutorials' + - torchtext ; extra == 'tutorials' + - torchvision ; extra == 'tutorials' + requires_python: '>=3.6' - pypi: https://files.pythonhosted.org/packages/e6/ad/3cc14f097111b4de0040c83a525973216457bbeeb63739ef1ed275c1c021/certifi-2026.1.4-py3-none-any.whl name: certifi version: 2026.1.4 diff --git a/pixi.toml b/pixi.toml index e42a3cf..8b01740 100644 --- a/pixi.toml +++ b/pixi.toml @@ -29,3 +29,4 @@ streamlit = ">=1.40.1, <2" httpx = ">=0.28.1, <0.29" uvicorn = ">=0.33.0, <0.34" optuna = ">=4.5.0, <5" +captum = ">=0.7.0, <0.8" diff --git a/requirements.txt b/requirements.txt index ee8553a..c73d401 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,3 +19,4 @@ streamlit>=1.40.1,<2 httpx>=0.28.1,<0.29 uvicorn>=0.33.0,<0.34 optuna>=4.5.0,<5 +captum>=0.7.0