""" 配方优化模拟程序 通过迭代式 Grid Search 寻找最优 LNP 配方,最大化目标器官的 Biodistribution。 使用方法: python -m app.optimize --smiles "CC(C)..." --organ liver """ import itertools import json from pathlib import Path from typing import Dict, List, Optional, Tuple from dataclasses import dataclass, field import numpy as np import pandas as pd import torch from torch.utils.data import DataLoader from loguru import logger from tqdm import tqdm import typer from lnp_ml.config import MODELS_DIR from lnp_ml.dataset import ( LNPDataset, LNPDatasetConfig, collate_fn, SMILES_COL, COMP_COLS, HELP_COLS, TARGET_BIODIST, get_phys_cols, get_exp_cols, ) from lnp_ml.modeling.predict import load_model app = typer.Typer() # ============ 参数配置 ============ # 可用的目标器官 AVAILABLE_ORGANS = ["lymph_nodes", "heart", "liver", "spleen", "lung", "kidney", "muscle"] @dataclass class CompRanges: """组分参数范围配置(mol 比例为百分数 0-100)""" # 阳离子脂质/mRNA 重量比 weight_ratio_min: float = 5.0 weight_ratio_max: float = 30.0 # 阳离子脂质 mol 比例 (%) cationic_mol_min: float = 5.0 cationic_mol_max: float = 80.0 # 磷脂 mol 比例 (%) phospholipid_mol_min: float = 0.0 phospholipid_mol_max: float = 80.0 # 胆固醇 mol 比例 (%) cholesterol_mol_min: float = 0.0 cholesterol_mol_max: float = 80.0 # PEG 脂质 mol 比例 (%) peg_mol_min: float = 0.0 peg_mol_max: float = 5.0 def to_dict(self) -> Dict: """转换为字典""" return { "weight_ratio": (self.weight_ratio_min, self.weight_ratio_max), "cationic_mol": (self.cationic_mol_min, self.cationic_mol_max), "phospholipid_mol": (self.phospholipid_mol_min, self.phospholipid_mol_max), "cholesterol_mol": (self.cholesterol_mol_min, self.cholesterol_mol_max), "peg_mol": (self.peg_mol_min, self.peg_mol_max), } def get_validation_error(self) -> Optional[str]: """ 验证范围是否合理,返回错误信息(如果有)。 Returns: 错误信息字符串,如果验证通过则返回 None """ # 检查各范围是否有效(最小值不能大于最大值) if self.weight_ratio_min > self.weight_ratio_max: return f"阳离子脂质/mRNA重量比:最小值({self.weight_ratio_min})不能大于最大值({self.weight_ratio_max})" if self.cationic_mol_min > self.cationic_mol_max: return f"阳离子脂质mol比例:最小值({self.cationic_mol_min})不能大于最大值({self.cationic_mol_max})" if self.phospholipid_mol_min > self.phospholipid_mol_max: return f"磷脂mol比例:最小值({self.phospholipid_mol_min})不能大于最大值({self.phospholipid_mol_max})" if self.cholesterol_mol_min > self.cholesterol_mol_max: return f"胆固醇mol比例:最小值({self.cholesterol_mol_min})不能大于最大值({self.cholesterol_mol_max})" if self.peg_mol_min > self.peg_mol_max: return f"PEG脂质mol比例:最小值({self.peg_mol_min})不能大于最大值({self.peg_mol_max})" # 检查 mol ratio 是否可能加起来为 1 min_sum = self.cationic_mol_min + self.phospholipid_mol_min + self.cholesterol_mol_min + self.peg_mol_min max_sum = self.cationic_mol_max + self.phospholipid_mol_max + self.cholesterol_mol_max + self.peg_mol_max if min_sum > 100.0: return f"mol比例最小值之和({min_sum:.1f}%)超过100%,无法生成有效配方" if max_sum < 100.0: return f"mol比例最大值之和({max_sum:.1f}%)不足100%,无法生成有效配方" return None def validate(self) -> bool: """验证范围是否合理(至少存在一个可行解)""" return self.get_validation_error() is None # 默认组分范围 DEFAULT_COMP_RANGES = CompRanges() # PEG 最小 step size (百分数) MIN_STEP_SIZE = 1 # 迭代策略:mol ratio 的 step_size (百分数) MOL_STEP_SIZES = [10, 2, 1] # 迭代策略:weight ratio 的 step_size(与 mol ratio 解耦) WR_STEP_SIZES = [5, 2, 1] # Helper lipid 选项(不包含 DOTAP) HELPER_LIPID_OPTIONS = ["DOPE", "DSPC"] # Route of administration 选项 ROUTE_OPTIONS = ["intravenous", "intramuscular"] # delivery 统计量(由 preprocess_internal.py 生成) # 包含: mean/std(z-score 逆变换)、qd_min/qd_max(评分归一化) _DELIVERY_STATS_PATH = Path(__file__).resolve().parent / "delivery_zscore_stats.json" if _DELIVERY_STATS_PATH.exists(): with open(_DELIVERY_STATS_PATH) as _f: DELIVERY_ZSCORE_STATS: Dict[str, Dict[str, float]] = json.load(_f) logger.info(f"Loaded delivery stats from {_DELIVERY_STATS_PATH}") else: DELIVERY_ZSCORE_STATS = {} logger.warning(f"delivery_zscore_stats.json not found at {_DELIVERY_STATS_PATH}, " "run 'make preprocess' to generate it") # quantified_delivery 归一化常量(从统计量中提取 qd_min/qd_max,用于评分归一化到 [0,1]) DELIVERY_NORM: Dict[str, Dict[str, float]] = {} for _route, _stats in DELIVERY_ZSCORE_STATS.items(): if "qd_min" in _stats and "qd_max" in _stats: DELIVERY_NORM[_route] = {"min": _stats["qd_min"], "max": _stats["qd_max"]} if not DELIVERY_NORM: logger.warning("DELIVERY_NORM is empty — scoring normalization for delivery will be disabled") @dataclass class ScoringWeights: """ 评分权重配置。 总分 = biodist_score + delivery_score + size_score + ee_score + pdi_score + toxic_score 各项计算方式参见 SCORE.md。 默认值:仅按目标器官 biodistribution 排序(向后兼容)。 """ # 回归任务权重 biodist_weight: float = 1.0 # score = biodist_value * weight delivery_weight: float = 0.0 # score = normalize(delivery, route) * weight size_weight: float = 0.0 # score = (1 if 80<=size<=150 else 0) * weight # 分类任务:per-class 权重(预测为该类时,得分 = 对应权重) ee_class_weights: List[float] = field(default_factory=lambda: [0.0, 0.0, 0.0]) # EE class 0, 1, 2 pdi_class_weights: List[float] = field(default_factory=lambda: [0.0, 0.0, 0.0, 0.0]) # PDI class 0, 1, 2, 3 toxic_class_weights: List[float] = field(default_factory=lambda: [0.0, 0.0]) # Toxic class 0, 1 # 默认评分权重(仅按 biodist 排序) DEFAULT_SCORING_WEIGHTS = ScoringWeights() def compute_formulation_score( f: 'Formulation', organ: str, weights: ScoringWeights, ) -> float: """ 计算单个 Formulation 的综合评分。 Args: f: 配方对象 organ: 目标器官 weights: 评分权重 Returns: 综合评分 """ score = 0.0 # 1. biodistribution score += f.get_biodist(organ) * weights.biodist_weight # 2. quantified_delivery(按给药途径归一化到 [0, 1]) if f.quantified_delivery is not None and weights.delivery_weight != 0: norm = DELIVERY_NORM.get(f.route, DELIVERY_NORM["intravenous"]) d_range = norm["max"] - norm["min"] if d_range > 0: delivery_normalized = (f.quantified_delivery - norm["min"]) / d_range delivery_normalized = max(0.0, min(1.0, delivery_normalized)) else: delivery_normalized = 0.0 score += delivery_normalized * weights.delivery_weight # 3. size(60-150nm 为理想范围) if f.size is not None and weights.size_weight != 0: if 60 <= f.size <= 150: score += 1.0 * weights.size_weight # 4. EE 分类 if f.ee_class is not None and 0 <= f.ee_class < len(weights.ee_class_weights): score += weights.ee_class_weights[f.ee_class] # 5. PDI 分类 if f.pdi_class is not None and 0 <= f.pdi_class < len(weights.pdi_class_weights): score += weights.pdi_class_weights[f.pdi_class] # 6. 毒性分类 if f.toxic_class is not None and 0 <= f.toxic_class < len(weights.toxic_class_weights): score += weights.toxic_class_weights[f.toxic_class] return score def compute_df_score( df: pd.DataFrame, organ: str, weights: ScoringWeights, ) -> pd.Series: """ 为 DataFrame 中的所有行计算综合评分。 Args: df: 包含预测结果的 DataFrame organ: 目标器官 weights: 评分权重 Returns: 评分 Series """ score = pd.Series(0.0, index=df.index) # 1. biodistribution pred_col = f"pred_Biodistribution_{organ}" if pred_col in df.columns: score += df[pred_col] * weights.biodist_weight # 2. quantified_delivery(按给药途径归一化) if weights.delivery_weight != 0 and "pred_delivery" in df.columns: for route_name, norm in DELIVERY_NORM.items(): mask = df["_route"] == route_name if mask.any(): d_range = norm["max"] - norm["min"] if d_range > 0: delivery_normalized = (df.loc[mask, "pred_delivery"] - norm["min"]) / d_range delivery_normalized = delivery_normalized.clip(0.0, 1.0) score.loc[mask] += delivery_normalized * weights.delivery_weight # 3. size if weights.size_weight != 0 and "pred_size" in df.columns: size_ok = (df["pred_size"] >= 60) & (df["pred_size"] <= 150) score += size_ok.astype(float) * weights.size_weight # 4. EE 分类 if "pred_ee_class" in df.columns: for cls, w in enumerate(weights.ee_class_weights): if w != 0: score += (df["pred_ee_class"] == cls).astype(float) * w # 5. PDI 分类 if "pred_pdi_class" in df.columns: for cls, w in enumerate(weights.pdi_class_weights): if w != 0: score += (df["pred_pdi_class"] == cls).astype(float) * w # 6. 毒性分类 if "pred_toxic_class" in df.columns: for cls, w in enumerate(weights.toxic_class_weights): if w != 0: score += (df["pred_toxic_class"] == cls).astype(float) * w return score @dataclass class Formulation: """配方数据结构""" # comp token cationic_lipid_to_mrna_ratio: float cationic_lipid_mol_ratio: float phospholipid_mol_ratio: float cholesterol_mol_ratio: float peg_lipid_mol_ratio: float # 离散选项 helper_lipid: str = "DOPE" route: str = "intravenous" # 预测结果(填充后) biodist_predictions: Dict[str, float] = field(default_factory=dict) # 额外预测值 quantified_delivery: Optional[float] = None unnormalized_delivery: Optional[float] = None # 反推的原始递送值(z-score 逆变换) size: Optional[float] = None pdi_class: Optional[int] = None # PDI 分类 (0-3) ee_class: Optional[int] = None # EE 分类 (0-2) toxic_class: Optional[int] = None # 毒性分类 (0: 无毒, 1: 有毒) def to_dict(self) -> Dict: """转换为字典""" return { "Cationic_Lipid_to_mRNA_weight_ratio": self.cationic_lipid_to_mrna_ratio, "Cationic_Lipid_Mol_Ratio": self.cationic_lipid_mol_ratio, "Phospholipid_Mol_Ratio": self.phospholipid_mol_ratio, "Cholesterol_Mol_Ratio": self.cholesterol_mol_ratio, "PEG_Lipid_Mol_Ratio": self.peg_lipid_mol_ratio, "helper_lipid": self.helper_lipid, "route": self.route, } def get_biodist(self, organ: str) -> float: """获取指定器官的 biodistribution 预测值""" col = f"Biodistribution_{organ}" return self.biodist_predictions.get(col, 0.0) def unique_key(self) -> tuple: """生成唯一标识键,用于去重""" return ( round(self.cationic_lipid_to_mrna_ratio, 4), round(self.cationic_lipid_mol_ratio, 4), round(self.phospholipid_mol_ratio, 4), round(self.cholesterol_mol_ratio, 4), round(self.peg_lipid_mol_ratio, 4), self.helper_lipid, self.route, ) def generate_grid_values( center: float, step_size: float, min_val: float, max_val: float, radius: int = 2, ) -> List[float]: """ 围绕中心点生成网格值。 Args: center: 中心值 step_size: 步长 min_val: 最小值 max_val: 最大值 radius: 扩展半径(生成 2*radius+1 个点) Returns: 网格值列表 """ values = [] for i in range(-radius, radius + 1): val = center + i * step_size if min_val <= val <= max_val: values.append(round(val, 4)) return sorted(set(values)) def generate_initial_grid( mol_step: float, wr_step: float, comp_ranges: CompRanges = None, ) -> List[Tuple[float, float, float, float, float]]: """ 生成初始搜索网格(满足 mol ratio 和为 100% 的约束)。 Args: mol_step: mol ratio 搜索步长 (百分数) wr_step: weight ratio 搜索步长 comp_ranges: 组分范围配置(默认使用 DEFAULT_COMP_RANGES) Returns: List of (weight_ratio, cationic_mol, phospholipid_mol, cholesterol_mol, peg_mol) """ if comp_ranges is None: comp_ranges = DEFAULT_COMP_RANGES grid = [] weight_ratios = np.arange( comp_ranges.weight_ratio_min, comp_ranges.weight_ratio_max + 0.001, wr_step ) peg_values = np.arange( comp_ranges.peg_mol_min, comp_ranges.peg_mol_max + 0.001, MIN_STEP_SIZE ) for weight_ratio in weight_ratios: for peg in peg_values: remaining = 100.0 - peg cationic_max = min(comp_ranges.cationic_mol_max, remaining) + 0.001 for cationic_mol in np.arange(comp_ranges.cationic_mol_min, cationic_max, mol_step): phospholipid_max = min(comp_ranges.phospholipid_mol_max, remaining - cationic_mol) + 0.001 for phospholipid_mol in np.arange(comp_ranges.phospholipid_mol_min, phospholipid_max, mol_step): cholesterol_mol = remaining - cationic_mol - phospholipid_mol if (comp_ranges.cholesterol_mol_min <= cholesterol_mol <= comp_ranges.cholesterol_mol_max): grid.append(( round(weight_ratio, 4), round(cationic_mol, 4), round(phospholipid_mol, 4), round(cholesterol_mol, 4), round(peg, 4), )) return grid def generate_refined_grid( seeds: List[Formulation], mol_step: float, wr_step: float, radius: int = 2, comp_ranges: CompRanges = None, ) -> List[Tuple[float, float, float, float, float]]: """ 围绕种子点生成精细化网格。 Args: seeds: 种子配方列表 mol_step: mol ratio 步长 (百分数) wr_step: weight ratio 步长 radius: 扩展半径 comp_ranges: 组分范围配置(默认使用 DEFAULT_COMP_RANGES) Returns: 新的网格点列表 """ if comp_ranges is None: comp_ranges = DEFAULT_COMP_RANGES grid_set = set() for seed in seeds: weight_ratios = generate_grid_values( seed.cationic_lipid_to_mrna_ratio, wr_step, comp_ranges.weight_ratio_min, comp_ranges.weight_ratio_max, radius ) peg_values = generate_grid_values( seed.peg_lipid_mol_ratio, MIN_STEP_SIZE, comp_ranges.peg_mol_min, comp_ranges.peg_mol_max, radius ) cationic_mols = generate_grid_values( seed.cationic_lipid_mol_ratio, mol_step, comp_ranges.cationic_mol_min, comp_ranges.cationic_mol_max, radius ) phospholipid_mols = generate_grid_values( seed.phospholipid_mol_ratio, mol_step, comp_ranges.phospholipid_mol_min, comp_ranges.phospholipid_mol_max, radius ) for weight_ratio in weight_ratios: for peg in peg_values: remaining = 100.0 - peg for cationic_mol in cationic_mols: for phospholipid_mol in phospholipid_mols: cholesterol_mol = remaining - cationic_mol - phospholipid_mol if (comp_ranges.cationic_mol_min <= cationic_mol <= comp_ranges.cationic_mol_max and comp_ranges.phospholipid_mol_min <= phospholipid_mol <= comp_ranges.phospholipid_mol_max and comp_ranges.cholesterol_mol_min <= cholesterol_mol <= comp_ranges.cholesterol_mol_max and comp_ranges.peg_mol_min <= peg <= comp_ranges.peg_mol_max): grid_set.add(( round(weight_ratio, 4), round(cationic_mol, 4), round(phospholipid_mol, 4), round(cholesterol_mol, 4), round(peg, 4), )) return list(grid_set) def create_dataframe_from_formulations( smiles: str, grid: List[Tuple[float, float, float, float, float]], helper_lipids: List[str], routes: List[str], ) -> pd.DataFrame: """ 从配方网格创建 DataFrame。 使用固定的 phys token(Pure, Microfluidic, mRNA, FFL)和 exp token(Mouse, body, luminescence)。 """ rows = [] for comp_values in grid: for helper in helper_lipids: for route in routes: row = { SMILES_COL: smiles, # comp token "Cationic_Lipid_to_mRNA_weight_ratio": comp_values[0], "Cationic_Lipid_Mol_Ratio": comp_values[1], "Phospholipid_Mol_Ratio": comp_values[2], "Cholesterol_Mol_Ratio": comp_values[3], "PEG_Lipid_Mol_Ratio": comp_values[4], # phys token (固定值) "Purity_Pure": 1.0, "Purity_Crude": 0.0, "Mix_type_Microfluidic": 1.0, "Mix_type_Pipetting": 0.0, "Cargo_type_mRNA": 1.0, "Cargo_type_pDNA": 0.0, "Cargo_type_siRNA": 0.0, "Target_or_delivered_gene_FFL": 1.0, "Target_or_delivered_gene_Peptide_barcode": 0.0, "Target_or_delivered_gene_hEPO": 0.0, "Target_or_delivered_gene_FVII": 0.0, "Target_or_delivered_gene_GFP": 0.0, # help token "Helper_lipid_ID_DOPE": 1.0 if helper == "DOPE" else 0.0, "Helper_lipid_ID_DOTAP": 1.0 if helper == "DOTAP" else 0.0, "Helper_lipid_ID_DSPC": 1.0 if helper == "DSPC" else 0.0, "Helper_lipid_ID_MDOA": 0.0, # 不使用 # exp token (固定值) "Model_type_Mouse": 1.0, "Delivery_target_body": 1.0, f"Route_of_administration_{route}": 1.0, "Batch_or_individual_or_barcoded_Individual": 1.0, "Value_name_luminescence": 1.0, # 存储配方元信息 "_helper_lipid": helper, "_route": route, } # 其他 exp token 默认为 0 for col in get_exp_cols(): if col not in row: row[col] = 0.0 rows.append(row) return pd.DataFrame(rows) def predict_all( model: torch.nn.Module, df: pd.DataFrame, device: torch.device, batch_size: int = 256, ) -> pd.DataFrame: """ 使用模型预测所有输出(biodistribution、size、delivery、PDI、EE)。 Returns: 添加了预测列的 DataFrame """ dataset = LNPDataset(df) dataloader = DataLoader( dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn ) all_biodist_preds = [] all_size_preds = [] all_delivery_preds = [] all_pdi_preds = [] all_ee_preds = [] all_toxic_preds = [] with torch.no_grad(): for batch in dataloader: smiles = batch["smiles"] tabular = {k: v.to(device) for k, v in batch["tabular"].items()} outputs = model(smiles, tabular) # biodist 输出是 softmax 后的概率分布 [B, 7] biodist_pred = outputs["biodist"].cpu().numpy() all_biodist_preds.append(biodist_pred) # size 和 delivery 是回归值 all_size_preds.append(outputs["size"].squeeze(-1).cpu().numpy()) all_delivery_preds.append(outputs["delivery"].squeeze(-1).cpu().numpy()) # PDI、EE 和 toxic 是分类,取 argmax all_pdi_preds.append(outputs["pdi"].argmax(dim=-1).cpu().numpy()) all_ee_preds.append(outputs["ee"].argmax(dim=-1).cpu().numpy()) all_toxic_preds.append(outputs["toxic"].argmax(dim=-1).cpu().numpy()) biodist_preds = np.concatenate(all_biodist_preds, axis=0) size_preds = np.concatenate(all_size_preds, axis=0) delivery_preds = np.concatenate(all_delivery_preds, axis=0) pdi_preds = np.concatenate(all_pdi_preds, axis=0) ee_preds = np.concatenate(all_ee_preds, axis=0) toxic_preds = np.concatenate(all_toxic_preds, axis=0) # 添加到 DataFrame for i, col in enumerate(TARGET_BIODIST): df[f"pred_{col}"] = biodist_preds[:, i] # size 模型输出为 log(size),转换回真实粒径 (nm) df["pred_size"] = np.exp(size_preds) df["pred_delivery"] = delivery_preds df["pred_pdi_class"] = pdi_preds df["pred_ee_class"] = ee_preds df["pred_toxic_class"] = toxic_preds # 反推 unnormalized_delivery: value = z-score * std + mean df["pred_unnorm_delivery"] = np.nan if DELIVERY_ZSCORE_STATS: for route_name, stats in DELIVERY_ZSCORE_STATS.items(): mask = df["_route"] == route_name if mask.any(): df.loc[mask, "pred_unnorm_delivery"] = ( delivery_preds[mask.values] * stats["std"] + stats["mean"] ) return df # 保持向后兼容 def predict_biodist( model: torch.nn.Module, df: pd.DataFrame, device: torch.device, batch_size: int = 256, ) -> pd.DataFrame: """向后兼容的别名""" return predict_all(model, df, device, batch_size) def select_top_k( df: pd.DataFrame, organ: str, k: int = 20, scoring_weights: Optional[ScoringWeights] = None, ) -> List[Formulation]: """ 选择 top-k 配方。 Args: df: 包含预测结果的 DataFrame organ: 目标器官 k: 选择数量 scoring_weights: 评分权重(默认仅按 biodist 排序) Returns: Top-k 配方列表 """ if scoring_weights is None: scoring_weights = DEFAULT_SCORING_WEIGHTS # 计算综合评分并排序 df = df.copy() df["_composite_score"] = compute_df_score(df, organ, scoring_weights) df_sorted = df.sort_values("_composite_score", ascending=False) # 创建配方对象 formulations = [] seen = set() for _, row in df_sorted.iterrows(): key = ( row["Cationic_Lipid_to_mRNA_weight_ratio"], row["Cationic_Lipid_Mol_Ratio"], row["Phospholipid_Mol_Ratio"], row["Cholesterol_Mol_Ratio"], row["PEG_Lipid_Mol_Ratio"], row["_helper_lipid"], row["_route"], ) if key not in seen: seen.add(key) unnorm_val = row.get("pred_unnorm_delivery") unnorm_delivery = float(unnorm_val) if pd.notna(unnorm_val) else None formulation = Formulation( cationic_lipid_to_mrna_ratio=row["Cationic_Lipid_to_mRNA_weight_ratio"], cationic_lipid_mol_ratio=row["Cationic_Lipid_Mol_Ratio"], phospholipid_mol_ratio=row["Phospholipid_Mol_Ratio"], cholesterol_mol_ratio=row["Cholesterol_Mol_Ratio"], peg_lipid_mol_ratio=row["PEG_Lipid_Mol_Ratio"], helper_lipid=row["_helper_lipid"], route=row["_route"], biodist_predictions={ col: row[f"pred_{col}"] for col in TARGET_BIODIST }, # 额外预测值 quantified_delivery=row.get("pred_delivery"), unnormalized_delivery=unnorm_delivery, size=row.get("pred_size"), pdi_class=int(row.get("pred_pdi_class")) if row.get("pred_pdi_class") is not None else None, ee_class=int(row.get("pred_ee_class")) if row.get("pred_ee_class") is not None else None, toxic_class=int(row.get("pred_toxic_class")) if row.get("pred_toxic_class") is not None else None, ) formulations.append(formulation) if len(formulations) >= k: break return formulations def generate_single_seed_grid( seed: Formulation, mol_step: float, wr_step: float, radius: int = 2, comp_ranges: CompRanges = None, ) -> List[Tuple[float, float, float, float, float]]: """ 为单个种子点生成邻域网格。 Args: seed: 种子配方 mol_step: mol ratio 步长 (百分数) wr_step: weight ratio 步长 radius: 扩展半径 comp_ranges: 组分范围配置(默认使用 DEFAULT_COMP_RANGES) Returns: 网格点列表 """ if comp_ranges is None: comp_ranges = DEFAULT_COMP_RANGES grid_set = set() weight_ratios = generate_grid_values( seed.cationic_lipid_to_mrna_ratio, wr_step, comp_ranges.weight_ratio_min, comp_ranges.weight_ratio_max, radius ) peg_values = generate_grid_values( seed.peg_lipid_mol_ratio, MIN_STEP_SIZE, comp_ranges.peg_mol_min, comp_ranges.peg_mol_max, radius ) cationic_mols = generate_grid_values( seed.cationic_lipid_mol_ratio, mol_step, comp_ranges.cationic_mol_min, comp_ranges.cationic_mol_max, radius ) phospholipid_mols = generate_grid_values( seed.phospholipid_mol_ratio, mol_step, comp_ranges.phospholipid_mol_min, comp_ranges.phospholipid_mol_max, radius ) for weight_ratio in weight_ratios: for peg in peg_values: remaining = 100.0 - peg for cationic_mol in cationic_mols: for phospholipid_mol in phospholipid_mols: cholesterol_mol = remaining - cationic_mol - phospholipid_mol if (comp_ranges.cationic_mol_min <= cationic_mol <= comp_ranges.cationic_mol_max and comp_ranges.phospholipid_mol_min <= phospholipid_mol <= comp_ranges.phospholipid_mol_max and comp_ranges.cholesterol_mol_min <= cholesterol_mol <= comp_ranges.cholesterol_mol_max and comp_ranges.peg_mol_min <= peg <= comp_ranges.peg_mol_max): grid_set.add(( round(weight_ratio, 4), round(cationic_mol, 4), round(phospholipid_mol, 4), round(cholesterol_mol, 4), round(peg, 4), )) return list(grid_set) def optimize( smiles: str, organ: str, model: torch.nn.Module, device: torch.device, top_k: int = 20, num_seeds: Optional[int] = None, top_per_seed: int = 1, step_sizes: Optional[List[float]] = None, wr_step_sizes: Optional[List[float]] = None, comp_ranges: Optional[CompRanges] = None, routes: Optional[List[str]] = None, scoring_weights: Optional[ScoringWeights] = None, batch_size: int = 256, ) -> List[Formulation]: """ 执行配方优化(层级搜索策略)。 采用层级搜索策略: 1. 第一次迭代:全局稀疏搜索,选择 top num_seeds 个分散的种子点 2. 后续迭代:对每个种子点分别在其邻域内搜索,各自保留 top_per_seed 个局部最优 3. 这样可以保持搜索的多样性,避免结果集中在单一区域 Args: smiles: SMILES 字符串 organ: 目标器官 model: 训练好的模型 device: 计算设备 top_k: 最终返回的最优配方数 num_seeds: 第一次迭代后保留的种子点数量(默认为 top_k * 5) top_per_seed: 每个种子点的邻域搜索后保留的局部最优点数量 step_sizes: mol ratio 每轮迭代的步长列表 (百分数,默认 [10, 2, 1]) wr_step_sizes: weight ratio 每轮迭代的步长列表 (默认 [5, 2, 1]) comp_ranges: 组分范围配置(默认使用 DEFAULT_COMP_RANGES) routes: 给药途径列表(默认使用 ROUTE_OPTIONS) scoring_weights: 评分权重配置(默认仅按 biodist 排序) batch_size: 预测批次大小 Returns: 最终 top-k 配方列表 """ if num_seeds is None: num_seeds = top_k * 5 if step_sizes is None: step_sizes = MOL_STEP_SIZES if wr_step_sizes is None: wr_step_sizes = WR_STEP_SIZES # 两组步长长度必须一致 if len(wr_step_sizes) != len(step_sizes): raise ValueError( f"step_sizes ({len(step_sizes)}) 和 wr_step_sizes ({len(wr_step_sizes)}) 长度不一致" ) if comp_ranges is None: comp_ranges = DEFAULT_COMP_RANGES if routes is None: routes = ROUTE_OPTIONS if scoring_weights is None: scoring_weights = DEFAULT_SCORING_WEIGHTS def _score(f: Formulation) -> float: return compute_formulation_score(f, organ, scoring_weights) logger.info(f"Starting optimization for organ: {organ}") logger.info(f"SMILES: {smiles}") logger.info(f"Strategy: num_seeds={num_seeds}, top_per_seed={top_per_seed}, top_k={top_k}") logger.info(f"Mol step sizes: {step_sizes}, WR step sizes: {wr_step_sizes}") logger.info(f"Routes: {routes}") logger.info(f"Scoring weights: biodist={scoring_weights.biodist_weight}, delivery={scoring_weights.delivery_weight}, size={scoring_weights.size_weight}") logger.info(f"Comp ranges: {comp_ranges.to_dict()}") seeds = None for iteration, (mol_step, wr_step) in enumerate(zip(step_sizes, wr_step_sizes)): logger.info(f"\n{'='*60}") logger.info(f"Iteration {iteration + 1}/{len(step_sizes)}, mol_step={mol_step}, wr_step={wr_step}") logger.info(f"{'='*60}") if seeds is None: # ==================== 第一次迭代:全局稀疏搜索 ==================== logger.info("Generating initial grid (global sparse search)...") grid = generate_initial_grid(mol_step, wr_step, comp_ranges) logger.info(f"Grid size: {len(grid)} comp combinations") # 扩展到所有 helper lipid 和 route 组合 total_combinations = len(grid) * len(HELPER_LIPID_OPTIONS) * len(routes) logger.info(f"Total combinations: {total_combinations}") # 创建 DataFrame df = create_dataframe_from_formulations( smiles, grid, HELPER_LIPID_OPTIONS, routes ) # 预测 logger.info("Running predictions...") df = predict_biodist(model, df, device, batch_size) # 选择 top num_seeds 个种子点 seeds = select_top_k(df, organ, num_seeds, scoring_weights) logger.info(f"Selected {len(seeds)} seeds for next iteration") else: # ==================== 后续迭代:层级局部搜索 ==================== # 对每个种子点分别搜索,各自保留局部最优 logger.info(f"Hierarchical local search around {len(seeds)} seeds...") all_local_best = [] for seed_idx, seed in enumerate(seeds): local_grid = generate_single_seed_grid(seed, mol_step, wr_step, radius=2, comp_ranges=comp_ranges) if len(local_grid) == 0: # 如果没有新的网格点,保留原种子 all_local_best.append(seed) continue # 创建 DataFrame df = create_dataframe_from_formulations( smiles, local_grid, [seed.helper_lipid], [seed.route] ) # 预测 df = predict_biodist(model, df, device, batch_size) # 选择该种子邻域内的 top top_per_seed 个局部最优 local_top = select_top_k(df, organ, top_per_seed, scoring_weights) all_local_best.extend(local_top) if seed_idx == 0 or (seed_idx + 1) % 5 == 0: logger.info(f" Seed {seed_idx + 1}/{len(seeds)}: local grid size={len(local_grid)}, " f"local best score={_score(local_top[0]):.4f}") # 更新种子为所有局部最优点(去重) seen_keys = set() unique_local_best = [] # 先按综合评分排序,确保保留最优的 all_local_best_sorted = sorted(all_local_best, key=_score, reverse=True) for f in all_local_best_sorted: key = f.unique_key() if key not in seen_keys: seen_keys.add(key) unique_local_best.append(f) seeds = unique_local_best logger.info(f"Collected {len(seeds)} unique local best formulations (from {len(all_local_best)} candidates)") # 显示当前最优 best = max(seeds, key=_score) logger.info(f"Current best score: {_score(best):.4f} (biodist_{organ}={best.get_biodist(organ):.4f})") logger.info(f"Best formulation: {best.to_dict()}") # 最终去重、按综合评分排序并返回 top_k seeds_sorted = sorted(seeds, key=_score, reverse=True) # 去重:保留每个唯一配方中得分最高的(已排序,所以第一个出现的就是最高的) seen_keys = set() unique_results = [] for f in seeds_sorted: key = f.unique_key() if key not in seen_keys: seen_keys.add(key) unique_results.append(f) logger.info(f"Final results: {len(unique_results)} unique formulations (from {len(seeds)} candidates)") return unique_results[:top_k] def format_results(formulations: List[Formulation], organ: str) -> pd.DataFrame: """格式化结果为 DataFrame""" rows = [] for i, f in enumerate(formulations): row = { "rank": i + 1, f"Biodistribution_{organ}": f.get_biodist(organ), **f.to_dict(), } # 添加其他器官的预测 for col in TARGET_BIODIST: if col != f"Biodistribution_{organ}": row[col] = f.biodist_predictions.get(col, 0.0) rows.append(row) return pd.DataFrame(rows) @app.command() def main( smiles: str = typer.Option(..., "--smiles", "-s", help="Cationic lipid SMILES string"), organ: str = typer.Option(..., "--organ", "-o", help=f"Target organ: {AVAILABLE_ORGANS}"), model_path: Path = typer.Option( MODELS_DIR / "final" / "model.pt", "--model", "-m", help="Path to trained model checkpoint" ), output_path: Optional[Path] = typer.Option( None, "--output", "-O", help="Output CSV path (optional)" ), top_k: int = typer.Option(20, "--top-k", "-k", help="Number of top formulations to return"), num_seeds: Optional[int] = typer.Option(None, "--num-seeds", "-n", help="Number of seed points from first iteration (default: top_k * 5)"), top_per_seed: int = typer.Option(1, "--top-per-seed", "-t", help="Number of local best to keep per seed"), step_sizes: Optional[str] = typer.Option(None, "--step-sizes", "-S", help="Mol ratio step sizes, comma-separated (e.g., '10,2,1')"), wr_step_sizes: Optional[str] = typer.Option(None, "--wr-step-sizes", help="Weight ratio step sizes, comma-separated (e.g., '5,2,1')"), batch_size: int = typer.Option(256, "--batch-size", "-b", help="Prediction batch size"), device: str = typer.Option("cuda" if torch.cuda.is_available() else "cpu", "--device", "-d", help="Device"), ): """ 配方优化程序:寻找最大化目标器官 Biodistribution 的最优 LNP 配方。 采用层级搜索策略: 1. 第一次迭代:全局稀疏搜索,选择 top num_seeds 个分散的种子点 2. 后续迭代:对每个种子点分别在其邻域内搜索,各自保留 top_per_seed 个局部最优 3. 这样可以保持搜索的多样性,避免结果集中在单一区域 示例: python -m app.optimize --smiles "CC(C)..." --organ liver python -m app.optimize -s "CC(C)..." -o spleen -k 10 -n 30 -t 2 python -m app.optimize -s "CC(C)..." -o liver -S "10,2,1" --wr-step-sizes "5,2,1" """ # 验证器官 if organ not in AVAILABLE_ORGANS: logger.error(f"Invalid organ: {organ}. Available: {AVAILABLE_ORGANS}") raise typer.Exit(1) # 解析步长 parsed_step_sizes = None if step_sizes: try: parsed_step_sizes = [float(s.strip()) for s in step_sizes.split(",")] except ValueError: logger.error(f"Invalid step sizes format: {step_sizes}") raise typer.Exit(1) parsed_wr_step_sizes = None if wr_step_sizes: try: parsed_wr_step_sizes = [float(s.strip()) for s in wr_step_sizes.split(",")] except ValueError: logger.error(f"Invalid wr step sizes format: {wr_step_sizes}") raise typer.Exit(1) # 加载模型 logger.info(f"Loading model from {model_path}") device = torch.device(device) model = load_model(model_path, device) results = optimize( smiles=smiles, organ=organ, model=model, device=device, top_k=top_k, num_seeds=num_seeds, top_per_seed=top_per_seed, step_sizes=parsed_step_sizes, wr_step_sizes=parsed_wr_step_sizes, batch_size=batch_size, ) # 格式化并显示结果 df_results = format_results(results, organ) logger.info(f"\n{'='*60}") logger.info(f"TOP {top_k} FORMULATIONS FOR {organ.upper()}") logger.info(f"{'='*60}") print(df_results.to_string(index=False)) # 保存结果 if output_path: df_results.to_csv(output_path, index=False) logger.success(f"Results saved to {output_path}") return df_results if __name__ == "__main__": app()