lnp_ml/app/optimize.py

1054 lines
39 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
配方优化模拟程序
通过迭代式 Grid Search 寻找最优 LNP 配方,最大化目标器官的 Biodistribution。
使用方法:
python -m app.optimize --smiles "CC(C)..." --organ liver
"""
import itertools
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from loguru import logger
from tqdm import tqdm
import typer
from lnp_ml.config import MODELS_DIR
from lnp_ml.dataset import (
LNPDataset,
LNPDatasetConfig,
collate_fn,
SMILES_COL,
COMP_COLS,
HELP_COLS,
TARGET_BIODIST,
get_phys_cols,
get_exp_cols,
)
from lnp_ml.modeling.predict import load_model
app = typer.Typer()
# ============ Parameter configuration ============
# Target organs the optimizer can rank by
AVAILABLE_ORGANS = ["lymph_nodes", "heart", "liver", "spleen", "lung", "kidney", "muscle"]
@dataclass
class CompRanges:
    """Component parameter range configuration (mol ratios are percentages, 0-100)."""
    # Cationic lipid / mRNA weight ratio
    weight_ratio_min: float = 5.0
    weight_ratio_max: float = 30.0
    # Cationic lipid mol ratio (%)
    cationic_mol_min: float = 5.0
    cationic_mol_max: float = 80.0
    # Phospholipid mol ratio (%)
    phospholipid_mol_min: float = 0.0
    phospholipid_mol_max: float = 80.0
    # Cholesterol mol ratio (%)
    cholesterol_mol_min: float = 0.0
    cholesterol_mol_max: float = 80.0
    # PEG lipid mol ratio (%)
    peg_mol_min: float = 0.0
    peg_mol_max: float = 5.0

    def to_dict(self) -> Dict:
        """Return the ranges as a name -> (min, max) mapping."""
        return {
            "weight_ratio": (self.weight_ratio_min, self.weight_ratio_max),
            "cationic_mol": (self.cationic_mol_min, self.cationic_mol_max),
            "phospholipid_mol": (self.phospholipid_mol_min, self.phospholipid_mol_max),
            "cholesterol_mol": (self.cholesterol_mol_min, self.cholesterol_mol_max),
            "peg_mol": (self.peg_mol_min, self.peg_mol_max),
        }

    def get_validation_error(self) -> Optional[str]:
        """
        Check that the configured ranges are feasible.

        Returns:
            An error message string, or None when validation passes.
        """
        # Each (label, lo, hi) pair must satisfy lo <= hi.
        bounds = [
            ("阳离子脂质/mRNA重量比", self.weight_ratio_min, self.weight_ratio_max),
            ("阳离子脂质mol比例", self.cationic_mol_min, self.cationic_mol_max),
            ("磷脂mol比例", self.phospholipid_mol_min, self.phospholipid_mol_max),
            ("胆固醇mol比例", self.cholesterol_mol_min, self.cholesterol_mol_max),
            ("PEG脂质mol比例", self.peg_mol_min, self.peg_mol_max),
        ]
        for label, lo, hi in bounds:
            if lo > hi:
                return f"{label}最小值({lo})不能大于最大值({hi})"
        # The four mol ratios must be able to sum to exactly 100%.
        min_sum = self.cationic_mol_min + self.phospholipid_mol_min + self.cholesterol_mol_min + self.peg_mol_min
        max_sum = self.cationic_mol_max + self.phospholipid_mol_max + self.cholesterol_mol_max + self.peg_mol_max
        if min_sum > 100.0:
            return f"mol比例最小值之和({min_sum:.1f}%)超过100%,无法生成有效配方"
        if max_sum < 100.0:
            return f"mol比例最大值之和({max_sum:.1f}%)不足100%,无法生成有效配方"
        return None

    def validate(self) -> bool:
        """Return True when at least one feasible formulation exists."""
        return self.get_validation_error() is None
# Default component ranges
DEFAULT_COMP_RANGES = CompRanges()
# Minimum PEG step size (percentage points)
MIN_STEP_SIZE = 1
# Iteration schedule: step sizes for the mol ratios (percentage points)
MOL_STEP_SIZES = [10, 2, 1]
# Iteration schedule: step sizes for the weight ratio (decoupled from the mol ratios)
WR_STEP_SIZES = [5, 2, 1]
# Helper lipid options (DOTAP excluded)
HELPER_LIPID_OPTIONS = ["DOPE", "DSPC"]
# Route of administration options
ROUTE_OPTIONS = ["intravenous", "intramuscular"]
# Delivery statistics (generated by preprocess_internal.py).
# Contains: mean/std (for inverse z-score transform) and qd_min/qd_max (score normalization).
_DELIVERY_STATS_PATH = Path(__file__).resolve().parent / "delivery_zscore_stats.json"
if _DELIVERY_STATS_PATH.exists():
    with open(_DELIVERY_STATS_PATH) as _f:
        DELIVERY_ZSCORE_STATS: Dict[str, Dict[str, float]] = json.load(_f)
    logger.info(f"Loaded delivery stats from {_DELIVERY_STATS_PATH}")
else:
    # Missing stats file is tolerated: downstream delivery scoring degrades gracefully.
    DELIVERY_ZSCORE_STATS = {}
    logger.warning(f"delivery_zscore_stats.json not found at {_DELIVERY_STATS_PATH}, "
                   "run 'make preprocess' to generate it")
# quantified_delivery normalization constants (qd_min/qd_max extracted from the stats,
# used to normalize delivery scores into [0, 1])
DELIVERY_NORM: Dict[str, Dict[str, float]] = {}
for _route, _stats in DELIVERY_ZSCORE_STATS.items():
    if "qd_min" in _stats and "qd_max" in _stats:
        DELIVERY_NORM[_route] = {"min": _stats["qd_min"], "max": _stats["qd_max"]}
if not DELIVERY_NORM:
    logger.warning("DELIVERY_NORM is empty — scoring normalization for delivery will be disabled")
@dataclass
class ScoringWeights:
    """
    Scoring weight configuration.

    total = biodist_score + delivery_score + size_score + ee_score + pdi_score + toxic_score
    See SCORE.md for how each term is computed.

    Defaults: rank by target-organ biodistribution only (backward compatible).
    """
    # Regression-task weights
    biodist_weight: float = 1.0  # score = biodist_value * weight
    delivery_weight: float = 0.0  # score = normalize(delivery, route) * weight
    size_weight: float = 0.0  # score = (1 if 60<=size<=150 else 0) * weight (range used by the scoring functions)
    # Classification tasks: per-class weights (predicted class c earns weight[c])
    ee_class_weights: List[float] = field(default_factory=lambda: [0.0, 0.0, 0.0])  # EE class 0, 1, 2
    pdi_class_weights: List[float] = field(default_factory=lambda: [0.0, 0.0, 0.0, 0.0])  # PDI class 0, 1, 2, 3
    toxic_class_weights: List[float] = field(default_factory=lambda: [0.0, 0.0])  # Toxic class 0, 1
# Default scoring weights (rank by biodist only)
DEFAULT_SCORING_WEIGHTS = ScoringWeights()
def compute_formulation_score(
    f: 'Formulation',
    organ: str,
    weights: ScoringWeights,
) -> float:
    """
    Compute the composite score for a single Formulation.

    Args:
        f: Formulation to score (predictions already populated).
        organ: Target organ.
        weights: Scoring weights.

    Returns:
        Composite score (sum of the weighted sub-scores; see SCORE.md).
    """
    score = 0.0
    # 1. Biodistribution of the target organ.
    score += f.get_biodist(organ) * weights.biodist_weight
    # 2. quantified_delivery, normalized to [0, 1] per administration route.
    if f.quantified_delivery is not None and weights.delivery_weight != 0:
        # Fall back to the intravenous stats for unknown routes; skip the term
        # entirely when no stats were loaded (DELIVERY_NORM may legitimately be
        # empty — the original code raised KeyError in that case).
        norm = DELIVERY_NORM.get(f.route) or DELIVERY_NORM.get("intravenous")
        if norm is not None:
            d_range = norm["max"] - norm["min"]
            if d_range > 0:
                delivery_normalized = (f.quantified_delivery - norm["min"]) / d_range
                delivery_normalized = max(0.0, min(1.0, delivery_normalized))
            else:
                delivery_normalized = 0.0
            score += delivery_normalized * weights.delivery_weight
    # 3. Size: binary reward inside the 60-150 nm ideal window.
    if f.size is not None and weights.size_weight != 0:
        if 60 <= f.size <= 150:
            score += 1.0 * weights.size_weight
    # 4. EE classification: score is the per-class weight of the predicted class.
    if f.ee_class is not None and 0 <= f.ee_class < len(weights.ee_class_weights):
        score += weights.ee_class_weights[f.ee_class]
    # 5. PDI classification.
    if f.pdi_class is not None and 0 <= f.pdi_class < len(weights.pdi_class_weights):
        score += weights.pdi_class_weights[f.pdi_class]
    # 6. Toxicity classification.
    if f.toxic_class is not None and 0 <= f.toxic_class < len(weights.toxic_class_weights):
        score += weights.toxic_class_weights[f.toxic_class]
    return score
def compute_df_score(
    df: pd.DataFrame,
    organ: str,
    weights: ScoringWeights,
) -> pd.Series:
    """
    Compute the composite score for every row of a prediction DataFrame.

    Args:
        df: DataFrame holding pred_* columns (and "_route" for delivery scoring).
        organ: Target organ.
        weights: Scoring weights.

    Returns:
        Series of composite scores aligned with df's index.
    """
    total = pd.Series(0.0, index=df.index)
    # 1. Biodistribution of the target organ.
    biodist_col = f"pred_Biodistribution_{organ}"
    if biodist_col in df.columns:
        total += df[biodist_col] * weights.biodist_weight
    # 2. quantified_delivery, normalized to [0, 1] per administration route.
    if weights.delivery_weight != 0 and "pred_delivery" in df.columns:
        for route_name, norm in DELIVERY_NORM.items():
            route_mask = df["_route"] == route_name
            if not route_mask.any():
                continue
            span = norm["max"] - norm["min"]
            if span > 0:
                normed = ((df.loc[route_mask, "pred_delivery"] - norm["min"]) / span).clip(0.0, 1.0)
                total.loc[route_mask] += normed * weights.delivery_weight
    # 3. Size: binary reward inside the 60-150 nm window (inclusive).
    if weights.size_weight != 0 and "pred_size" in df.columns:
        in_window = df["pred_size"].between(60, 150)
        total += in_window.astype(float) * weights.size_weight
    # 4-6. Classification heads: add the per-class weight where that class was predicted.
    class_terms = [
        ("pred_ee_class", weights.ee_class_weights),
        ("pred_pdi_class", weights.pdi_class_weights),
        ("pred_toxic_class", weights.toxic_class_weights),
    ]
    for col, class_weights in class_terms:
        if col not in df.columns:
            continue
        for cls, w in enumerate(class_weights):
            if w != 0:
                total += (df[col] == cls).astype(float) * w
    return total
@dataclass
class Formulation:
    """A single LNP formulation plus, once inference has run, its predictions."""
    # comp token
    cationic_lipid_to_mrna_ratio: float
    cationic_lipid_mol_ratio: float
    phospholipid_mol_ratio: float
    cholesterol_mol_ratio: float
    peg_lipid_mol_ratio: float
    # Discrete choices
    helper_lipid: str = "DOPE"
    route: str = "intravenous"
    # Prediction results (populated after inference)
    biodist_predictions: Dict[str, float] = field(default_factory=dict)
    # Extra predicted values
    quantified_delivery: Optional[float] = None
    unnormalized_delivery: Optional[float] = None  # raw delivery recovered via inverse z-score
    size: Optional[float] = None
    pdi_class: Optional[int] = None  # PDI class (0-3)
    ee_class: Optional[int] = None  # EE class (0-2)
    toxic_class: Optional[int] = None  # toxicity class (0: non-toxic, 1: toxic)

    def to_dict(self) -> Dict:
        """Serialize the formulation parameters to a plain dict."""
        return {
            "Cationic_Lipid_to_mRNA_weight_ratio": self.cationic_lipid_to_mrna_ratio,
            "Cationic_Lipid_Mol_Ratio": self.cationic_lipid_mol_ratio,
            "Phospholipid_Mol_Ratio": self.phospholipid_mol_ratio,
            "Cholesterol_Mol_Ratio": self.cholesterol_mol_ratio,
            "PEG_Lipid_Mol_Ratio": self.peg_lipid_mol_ratio,
            "helper_lipid": self.helper_lipid,
            "route": self.route,
        }

    def get_biodist(self, organ: str) -> float:
        """Return the predicted biodistribution for *organ* (0.0 when absent)."""
        return self.biodist_predictions.get(f"Biodistribution_{organ}", 0.0)

    def unique_key(self) -> tuple:
        """Deduplication key: ratios rounded to 4 decimals plus the discrete choices."""
        ratios = (
            self.cationic_lipid_to_mrna_ratio,
            self.cationic_lipid_mol_ratio,
            self.phospholipid_mol_ratio,
            self.cholesterol_mol_ratio,
            self.peg_lipid_mol_ratio,
        )
        return tuple(round(v, 4) for v in ratios) + (self.helper_lipid, self.route)
def generate_grid_values(
    center: float,
    step_size: float,
    min_val: float,
    max_val: float,
    radius: int = 2,
) -> List[float]:
    """
    Build a 1-D grid of values around a center point.

    Args:
        center: Center value.
        step_size: Spacing between neighboring points.
        min_val: Lower bound (inclusive).
        max_val: Upper bound (inclusive).
        radius: Expansion radius (up to 2*radius+1 candidate points).

    Returns:
        Sorted, deduplicated list of in-range values rounded to 4 decimals.
    """
    # Range-check the raw value first, then round the survivors (matching the
    # inclusive-bound semantics before any rounding is applied).
    kept = {
        round(center + k * step_size, 4)
        for k in range(-radius, radius + 1)
        if min_val <= center + k * step_size <= max_val
    }
    return sorted(kept)
def generate_initial_grid(
    mol_step: float,
    wr_step: float,
    comp_ranges: Optional[CompRanges] = None,
) -> List[Tuple[float, float, float, float, float]]:
    """
    Generate the initial search grid (mol ratios constrained to sum to 100%).

    Cholesterol is not enumerated directly: it is computed as the remainder
    100 - peg - cationic - phospholipid and then range-checked, which
    guarantees the sum constraint by construction.

    Args:
        mol_step: Search step for mol ratios (percentage points).
        wr_step: Search step for the weight ratio.
        comp_ranges: Component range configuration (defaults to DEFAULT_COMP_RANGES).
            Annotation fixed to Optional[CompRanges] (it was an implicit Optional).

    Returns:
        List of (weight_ratio, cationic_mol, phospholipid_mol, cholesterol_mol, peg_mol).
    """
    if comp_ranges is None:
        comp_ranges = DEFAULT_COMP_RANGES
    grid = []
    # +0.001 keeps the upper bound inclusive despite float accumulation in arange.
    weight_ratios = np.arange(
        comp_ranges.weight_ratio_min,
        comp_ranges.weight_ratio_max + 0.001,
        wr_step
    )
    # PEG is always enumerated at the finest step size.
    peg_values = np.arange(
        comp_ranges.peg_mol_min,
        comp_ranges.peg_mol_max + 0.001,
        MIN_STEP_SIZE
    )
    for weight_ratio in weight_ratios:
        for peg in peg_values:
            remaining = 100.0 - peg
            cationic_max = min(comp_ranges.cationic_mol_max, remaining) + 0.001
            for cationic_mol in np.arange(comp_ranges.cationic_mol_min, cationic_max, mol_step):
                phospholipid_max = min(comp_ranges.phospholipid_mol_max, remaining - cationic_mol) + 0.001
                for phospholipid_mol in np.arange(comp_ranges.phospholipid_mol_min, phospholipid_max, mol_step):
                    # Cholesterol absorbs the remainder so the mol ratios sum to 100%.
                    cholesterol_mol = remaining - cationic_mol - phospholipid_mol
                    if (comp_ranges.cholesterol_mol_min <= cholesterol_mol <= comp_ranges.cholesterol_mol_max):
                        grid.append((
                            round(weight_ratio, 4),
                            round(cationic_mol, 4),
                            round(phospholipid_mol, 4),
                            round(cholesterol_mol, 4),
                            round(peg, 4),
                        ))
    return grid
def generate_refined_grid(
    seeds: List[Formulation],
    mol_step: float,
    wr_step: float,
    radius: int = 2,
    comp_ranges: Optional[CompRanges] = None,
) -> List[Tuple[float, float, float, float, float]]:
    """
    Generate a refined grid around a set of seed points.

    Delegates the per-seed neighborhood construction to
    generate_single_seed_grid (its logic was previously duplicated here
    verbatim) and unions the results across seeds.

    Args:
        seeds: Seed formulations.
        mol_step: Step for mol ratios (percentage points).
        wr_step: Step for the weight ratio.
        radius: Expansion radius.
        comp_ranges: Component range configuration (defaults to DEFAULT_COMP_RANGES).
            Annotation fixed to Optional[CompRanges] (it was an implicit Optional).

    Returns:
        Deduplicated list of new grid points.
    """
    if comp_ranges is None:
        comp_ranges = DEFAULT_COMP_RANGES
    grid_set = set()
    for seed in seeds:
        # generate_single_seed_grid applies the same bounds/constraints this
        # function previously re-implemented inline.
        grid_set.update(
            generate_single_seed_grid(seed, mol_step, wr_step, radius, comp_ranges)
        )
    return list(grid_set)
def create_dataframe_from_formulations(
    smiles: str,
    grid: List[Tuple[float, float, float, float, float]],
    helper_lipids: List[str],
    routes: List[str],
) -> pd.DataFrame:
    """
    Build a model-input DataFrame from a grid of formulations.

    Uses fixed phys tokens (Pure, Microfluidic, mRNA, FFL) and fixed
    exp tokens (Mouse, body, luminescence).
    """
    # phys tokens are identical for every generated row.
    phys_tokens = {
        "Purity_Pure": 1.0,
        "Purity_Crude": 0.0,
        "Mix_type_Microfluidic": 1.0,
        "Mix_type_Pipetting": 0.0,
        "Cargo_type_mRNA": 1.0,
        "Cargo_type_pDNA": 0.0,
        "Cargo_type_siRNA": 0.0,
        "Target_or_delivered_gene_FFL": 1.0,
        "Target_or_delivered_gene_Peptide_barcode": 0.0,
        "Target_or_delivered_gene_hEPO": 0.0,
        "Target_or_delivered_gene_FVII": 0.0,
        "Target_or_delivered_gene_GFP": 0.0,
    }
    rows = []
    combos = itertools.product(grid, helper_lipids, routes)
    for (wr, cationic, phospholipid, cholesterol, peg), helper, route in combos:
        row = {
            SMILES_COL: smiles,
            # comp token
            "Cationic_Lipid_to_mRNA_weight_ratio": wr,
            "Cationic_Lipid_Mol_Ratio": cationic,
            "Phospholipid_Mol_Ratio": phospholipid,
            "Cholesterol_Mol_Ratio": cholesterol,
            "PEG_Lipid_Mol_Ratio": peg,
            # phys token (fixed values)
            **phys_tokens,
            # help token: one-hot over helper lipids (MDOA is never used)
            "Helper_lipid_ID_DOPE": float(helper == "DOPE"),
            "Helper_lipid_ID_DOTAP": float(helper == "DOTAP"),
            "Helper_lipid_ID_DSPC": float(helper == "DSPC"),
            "Helper_lipid_ID_MDOA": 0.0,
            # exp token (fixed values)
            "Model_type_Mouse": 1.0,
            "Delivery_target_body": 1.0,
            f"Route_of_administration_{route}": 1.0,
            "Batch_or_individual_or_barcoded_Individual": 1.0,
            "Value_name_luminescence": 1.0,
            # formulation metadata carried alongside the features
            "_helper_lipid": helper,
            "_route": route,
        }
        # Any remaining exp token defaults to 0.
        for col in get_exp_cols():
            row.setdefault(col, 0.0)
        rows.append(row)
    return pd.DataFrame(rows)
def predict_all(
    model: torch.nn.Module,
    df: pd.DataFrame,
    device: torch.device,
    batch_size: int = 256,
) -> pd.DataFrame:
    """
    Predict every model output (biodistribution, size, delivery, PDI, EE, toxicity).

    NOTE: mutates the input DataFrame in place (adds pred_* columns) and also
    returns it.

    Args:
        model: Trained multi-head model.
        df: Input feature rows; must contain a "_route" column.
        device: Device to run inference on.
        batch_size: DataLoader batch size.

    Returns:
        The same DataFrame with pred_* columns added.
    """
    dataset = LNPDataset(df)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
    )
    # Per-batch accumulators, one list per model head.
    all_biodist_preds = []
    all_size_preds = []
    all_delivery_preds = []
    all_pdi_preds = []
    all_ee_preds = []
    all_toxic_preds = []
    with torch.no_grad():
        for batch in dataloader:
            smiles = batch["smiles"]
            tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
            outputs = model(smiles, tabular)
            # biodist output is a post-softmax probability distribution [B, 7]
            biodist_pred = outputs["biodist"].cpu().numpy()
            all_biodist_preds.append(biodist_pred)
            # size and delivery are regression values
            all_size_preds.append(outputs["size"].squeeze(-1).cpu().numpy())
            all_delivery_preds.append(outputs["delivery"].squeeze(-1).cpu().numpy())
            # PDI, EE and toxicity are classification heads; take the argmax class
            all_pdi_preds.append(outputs["pdi"].argmax(dim=-1).cpu().numpy())
            all_ee_preds.append(outputs["ee"].argmax(dim=-1).cpu().numpy())
            all_toxic_preds.append(outputs["toxic"].argmax(dim=-1).cpu().numpy())
    biodist_preds = np.concatenate(all_biodist_preds, axis=0)
    size_preds = np.concatenate(all_size_preds, axis=0)
    delivery_preds = np.concatenate(all_delivery_preds, axis=0)
    pdi_preds = np.concatenate(all_pdi_preds, axis=0)
    ee_preds = np.concatenate(all_ee_preds, axis=0)
    toxic_preds = np.concatenate(all_toxic_preds, axis=0)
    # Attach predictions to the DataFrame
    for i, col in enumerate(TARGET_BIODIST):
        df[f"pred_{col}"] = biodist_preds[:, i]
    # The size head outputs log(size); convert back to a real diameter (nm)
    df["pred_size"] = np.exp(size_preds)
    df["pred_delivery"] = delivery_preds
    df["pred_pdi_class"] = pdi_preds
    df["pred_ee_class"] = ee_preds
    df["pred_toxic_class"] = toxic_preds
    # Recover the unnormalized delivery per route: value = z-score * std + mean.
    # Rows whose route has no stats keep NaN.
    df["pred_unnorm_delivery"] = np.nan
    if DELIVERY_ZSCORE_STATS:
        for route_name, stats in DELIVERY_ZSCORE_STATS.items():
            mask = df["_route"] == route_name
            if mask.any():
                df.loc[mask, "pred_unnorm_delivery"] = (
                    delivery_preds[mask.values] * stats["std"] + stats["mean"]
                )
    return df
# Kept for backward compatibility
def predict_biodist(
    model: torch.nn.Module,
    df: pd.DataFrame,
    device: torch.device,
    batch_size: int = 256,
) -> pd.DataFrame:
    """Backward-compatible alias; delegates directly to predict_all."""
    return predict_all(model=model, df=df, device=device, batch_size=batch_size)
def _row_to_formulation(row) -> Formulation:
    """Convert one scored prediction row into a Formulation object."""
    unnorm_val = row.get("pred_unnorm_delivery")
    pdi_val = row.get("pred_pdi_class")
    ee_val = row.get("pred_ee_class")
    toxic_val = row.get("pred_toxic_class")
    return Formulation(
        cationic_lipid_to_mrna_ratio=row["Cationic_Lipid_to_mRNA_weight_ratio"],
        cationic_lipid_mol_ratio=row["Cationic_Lipid_Mol_Ratio"],
        phospholipid_mol_ratio=row["Phospholipid_Mol_Ratio"],
        cholesterol_mol_ratio=row["Cholesterol_Mol_Ratio"],
        peg_lipid_mol_ratio=row["PEG_Lipid_Mol_Ratio"],
        helper_lipid=row["_helper_lipid"],
        route=row["_route"],
        biodist_predictions={col: row[f"pred_{col}"] for col in TARGET_BIODIST},
        # Extra predicted values
        quantified_delivery=row.get("pred_delivery"),
        unnormalized_delivery=float(unnorm_val) if pd.notna(unnorm_val) else None,
        size=row.get("pred_size"),
        pdi_class=int(pdi_val) if pdi_val is not None else None,
        ee_class=int(ee_val) if ee_val is not None else None,
        toxic_class=int(toxic_val) if toxic_val is not None else None,
    )
def select_top_k(
    df: pd.DataFrame,
    organ: str,
    k: int = 20,
    scoring_weights: Optional[ScoringWeights] = None,
) -> List[Formulation]:
    """
    Pick the k best-scoring unique formulations from a prediction DataFrame.

    Args:
        df: DataFrame holding prediction columns.
        organ: Target organ.
        k: Number of formulations to keep.
        scoring_weights: Scoring weights (defaults to biodist-only ranking).

    Returns:
        Up to k Formulation objects, best score first, deduplicated on the
        formulation parameters.
    """
    weights = DEFAULT_SCORING_WEIGHTS if scoring_weights is None else scoring_weights
    # Score on a copy so the caller's DataFrame is left untouched.
    ranked = df.copy()
    ranked["_composite_score"] = compute_df_score(ranked, organ, weights)
    ranked = ranked.sort_values("_composite_score", ascending=False)
    picked: List[Formulation] = []
    seen_keys = set()
    for _, row in ranked.iterrows():
        dedup_key = (
            row["Cationic_Lipid_to_mRNA_weight_ratio"],
            row["Cationic_Lipid_Mol_Ratio"],
            row["Phospholipid_Mol_Ratio"],
            row["Cholesterol_Mol_Ratio"],
            row["PEG_Lipid_Mol_Ratio"],
            row["_helper_lipid"],
            row["_route"],
        )
        if dedup_key in seen_keys:
            continue
        seen_keys.add(dedup_key)
        picked.append(_row_to_formulation(row))
        if len(picked) >= k:
            break
    return picked
def generate_single_seed_grid(
    seed: Formulation,
    mol_step: float,
    wr_step: float,
    radius: int = 2,
    comp_ranges: Optional[CompRanges] = None,
) -> List[Tuple[float, float, float, float, float]]:
    """
    Generate the neighborhood grid around a single seed point.

    Args:
        seed: Seed formulation to expand around.
        mol_step: Step for mol ratios (percentage points).
        wr_step: Step for the weight ratio.
        radius: Expansion radius (up to 2*radius+1 points per axis).
        comp_ranges: Component range configuration (defaults to DEFAULT_COMP_RANGES).

    Returns:
        Deduplicated list of (weight_ratio, cationic_mol, phospholipid_mol,
        cholesterol_mol, peg_mol) tuples whose mol ratios sum to 100%.
    """
    if comp_ranges is None:
        comp_ranges = DEFAULT_COMP_RANGES
    grid_set = set()
    weight_ratios = generate_grid_values(
        seed.cationic_lipid_to_mrna_ratio, wr_step,
        comp_ranges.weight_ratio_min, comp_ranges.weight_ratio_max, radius
    )
    # PEG is always explored at the finest step size.
    peg_values = generate_grid_values(
        seed.peg_lipid_mol_ratio, MIN_STEP_SIZE,
        comp_ranges.peg_mol_min, comp_ranges.peg_mol_max, radius
    )
    cationic_mols = generate_grid_values(
        seed.cationic_lipid_mol_ratio, mol_step,
        comp_ranges.cationic_mol_min, comp_ranges.cationic_mol_max, radius
    )
    phospholipid_mols = generate_grid_values(
        seed.phospholipid_mol_ratio, mol_step,
        comp_ranges.phospholipid_mol_min, comp_ranges.phospholipid_mol_max, radius
    )
    for weight_ratio in weight_ratios:
        for peg in peg_values:
            remaining = 100.0 - peg
            for cationic_mol in cationic_mols:
                for phospholipid_mol in phospholipid_mols:
                    # Cholesterol absorbs the remainder so the mol ratios sum to 100%.
                    cholesterol_mol = remaining - cationic_mol - phospholipid_mol
                    if (comp_ranges.cationic_mol_min <= cationic_mol <= comp_ranges.cationic_mol_max and
                        comp_ranges.phospholipid_mol_min <= phospholipid_mol <= comp_ranges.phospholipid_mol_max and
                        comp_ranges.cholesterol_mol_min <= cholesterol_mol <= comp_ranges.cholesterol_mol_max and
                        comp_ranges.peg_mol_min <= peg <= comp_ranges.peg_mol_max):
                        grid_set.add((
                            round(weight_ratio, 4),
                            round(cationic_mol, 4),
                            round(phospholipid_mol, 4),
                            round(cholesterol_mol, 4),
                            round(peg, 4),
                        ))
    return list(grid_set)
def optimize(
    smiles: str,
    organ: str,
    model: torch.nn.Module,
    device: torch.device,
    top_k: int = 20,
    num_seeds: Optional[int] = None,
    top_per_seed: int = 1,
    step_sizes: Optional[List[float]] = None,
    wr_step_sizes: Optional[List[float]] = None,
    comp_ranges: Optional[CompRanges] = None,
    routes: Optional[List[str]] = None,
    scoring_weights: Optional[ScoringWeights] = None,
    batch_size: int = 256,
) -> List[Formulation]:
    """
    Run formulation optimization (hierarchical search strategy).

    Strategy:
    1. First iteration: global sparse search; keep the top num_seeds scattered seeds.
    2. Later iterations: search each seed's neighborhood separately, keeping
       top_per_seed local optima per seed.
    3. This preserves search diversity and avoids collapsing onto one region.

    Args:
        smiles: SMILES string of the cationic lipid.
        organ: Target organ.
        model: Trained model.
        device: Compute device.
        top_k: Number of best formulations to return.
        num_seeds: Seeds kept after the first iteration (default: top_k * 5).
        top_per_seed: Local optima kept per seed neighborhood.
        step_sizes: Mol-ratio step per iteration (percentage points, default [10, 2, 1]).
        wr_step_sizes: Weight-ratio step per iteration (default [5, 2, 1]).
        comp_ranges: Component range configuration (defaults to DEFAULT_COMP_RANGES).
        routes: Administration routes (defaults to ROUTE_OPTIONS).
        scoring_weights: Scoring weight configuration (default: biodist-only ranking).
        batch_size: Prediction batch size.

    Raises:
        ValueError: when step_sizes and wr_step_sizes have different lengths.

    Returns:
        Final top-k formulation list, best score first.
    """
    if num_seeds is None:
        num_seeds = top_k * 5
    if step_sizes is None:
        step_sizes = MOL_STEP_SIZES
    if wr_step_sizes is None:
        wr_step_sizes = WR_STEP_SIZES
    # The two step schedules must be the same length (one entry per iteration)
    if len(wr_step_sizes) != len(step_sizes):
        raise ValueError(
            f"step_sizes ({len(step_sizes)}) 和 wr_step_sizes ({len(wr_step_sizes)}) 长度不一致"
        )
    if comp_ranges is None:
        comp_ranges = DEFAULT_COMP_RANGES
    if routes is None:
        routes = ROUTE_OPTIONS
    if scoring_weights is None:
        scoring_weights = DEFAULT_SCORING_WEIGHTS
    def _score(f: Formulation) -> float:
        # Composite score used for all ranking in this run.
        return compute_formulation_score(f, organ, scoring_weights)
    logger.info(f"Starting optimization for organ: {organ}")
    logger.info(f"SMILES: {smiles}")
    logger.info(f"Strategy: num_seeds={num_seeds}, top_per_seed={top_per_seed}, top_k={top_k}")
    logger.info(f"Mol step sizes: {step_sizes}, WR step sizes: {wr_step_sizes}")
    logger.info(f"Routes: {routes}")
    logger.info(f"Scoring weights: biodist={scoring_weights.biodist_weight}, delivery={scoring_weights.delivery_weight}, size={scoring_weights.size_weight}")
    logger.info(f"Comp ranges: {comp_ranges.to_dict()}")
    seeds = None
    for iteration, (mol_step, wr_step) in enumerate(zip(step_sizes, wr_step_sizes)):
        logger.info(f"\n{'='*60}")
        logger.info(f"Iteration {iteration + 1}/{len(step_sizes)}, mol_step={mol_step}, wr_step={wr_step}")
        logger.info(f"{'='*60}")
        if seeds is None:
            # ==================== First iteration: global sparse search ====================
            logger.info("Generating initial grid (global sparse search)...")
            grid = generate_initial_grid(mol_step, wr_step, comp_ranges)
            logger.info(f"Grid size: {len(grid)} comp combinations")
            # Expand over all helper lipid and route combinations
            total_combinations = len(grid) * len(HELPER_LIPID_OPTIONS) * len(routes)
            logger.info(f"Total combinations: {total_combinations}")
            # Build the model-input DataFrame
            df = create_dataframe_from_formulations(
                smiles, grid, HELPER_LIPID_OPTIONS, routes
            )
            # Run inference
            logger.info("Running predictions...")
            df = predict_biodist(model, df, device, batch_size)
            # Keep the top num_seeds seed points
            seeds = select_top_k(df, organ, num_seeds, scoring_weights)
            logger.info(f"Selected {len(seeds)} seeds for next iteration")
        else:
            # ==================== Later iterations: hierarchical local search ====================
            # Search each seed separately, keeping its own local optima
            logger.info(f"Hierarchical local search around {len(seeds)} seeds...")
            all_local_best = []
            for seed_idx, seed in enumerate(seeds):
                local_grid = generate_single_seed_grid(seed, mol_step, wr_step, radius=2, comp_ranges=comp_ranges)
                if len(local_grid) == 0:
                    # No new grid points: keep the original seed as-is
                    all_local_best.append(seed)
                    continue
                # Build the model-input DataFrame (seed's own helper lipid and route)
                df = create_dataframe_from_formulations(
                    smiles, local_grid, [seed.helper_lipid], [seed.route]
                )
                # Run inference
                df = predict_biodist(model, df, device, batch_size)
                # Keep the top_per_seed local optima within this neighborhood
                local_top = select_top_k(df, organ, top_per_seed, scoring_weights)
                all_local_best.extend(local_top)
                if seed_idx == 0 or (seed_idx + 1) % 5 == 0:
                    logger.info(f" Seed {seed_idx + 1}/{len(seeds)}: local grid size={len(local_grid)}, "
                                f"local best score={_score(local_top[0]):.4f}")
            # Next seeds become the deduplicated union of all local optima
            seen_keys = set()
            unique_local_best = []
            # Sort by composite score first so the best duplicate survives
            all_local_best_sorted = sorted(all_local_best, key=_score, reverse=True)
            for f in all_local_best_sorted:
                key = f.unique_key()
                if key not in seen_keys:
                    seen_keys.add(key)
                    unique_local_best.append(f)
            seeds = unique_local_best
            logger.info(f"Collected {len(seeds)} unique local best formulations (from {len(all_local_best)} candidates)")
        # Report the current best after each iteration
        best = max(seeds, key=_score)
        logger.info(f"Current best score: {_score(best):.4f} (biodist_{organ}={best.get_biodist(organ):.4f})")
        logger.info(f"Best formulation: {best.to_dict()}")
    # Final pass: dedupe, sort by composite score, return top_k
    seeds_sorted = sorted(seeds, key=_score, reverse=True)
    # Dedupe, keeping the highest-scoring copy (already sorted, so first wins)
    seen_keys = set()
    unique_results = []
    for f in seeds_sorted:
        key = f.unique_key()
        if key not in seen_keys:
            seen_keys.add(key)
            unique_results.append(f)
    logger.info(f"Final results: {len(unique_results)} unique formulations (from {len(seeds)} candidates)")
    return unique_results[:top_k]
def format_results(formulations: List[Formulation], organ: str) -> pd.DataFrame:
    """Assemble the ranked formulations into a result DataFrame."""
    target_col = f"Biodistribution_{organ}"
    records = []
    for rank, formulation in enumerate(formulations, start=1):
        record = {
            "rank": rank,
            target_col: formulation.get_biodist(organ),
            **formulation.to_dict(),
        }
        # Also include the predictions for every non-target organ.
        for col in TARGET_BIODIST:
            if col != target_col:
                record[col] = formulation.biodist_predictions.get(col, 0.0)
        records.append(record)
    return pd.DataFrame(records)
@app.command()
def main(
    smiles: str = typer.Option(..., "--smiles", "-s", help="Cationic lipid SMILES string"),
    organ: str = typer.Option(..., "--organ", "-o", help=f"Target organ: {AVAILABLE_ORGANS}"),
    model_path: Path = typer.Option(
        MODELS_DIR / "final" / "model.pt",
        "--model", "-m",
        help="Path to trained model checkpoint"
    ),
    output_path: Optional[Path] = typer.Option(
        None, "--output", "-O",
        help="Output CSV path (optional)"
    ),
    top_k: int = typer.Option(20, "--top-k", "-k", help="Number of top formulations to return"),
    num_seeds: Optional[int] = typer.Option(None, "--num-seeds", "-n", help="Number of seed points from first iteration (default: top_k * 5)"),
    top_per_seed: int = typer.Option(1, "--top-per-seed", "-t", help="Number of local best to keep per seed"),
    step_sizes: Optional[str] = typer.Option(None, "--step-sizes", "-S", help="Mol ratio step sizes, comma-separated (e.g., '10,2,1')"),
    wr_step_sizes: Optional[str] = typer.Option(None, "--wr-step-sizes", help="Weight ratio step sizes, comma-separated (e.g., '5,2,1')"),
    batch_size: int = typer.Option(256, "--batch-size", "-b", help="Prediction batch size"),
    device: str = typer.Option("cuda" if torch.cuda.is_available() else "cpu", "--device", "-d", help="Device"),
):
    """
    Formulation optimizer: find LNP formulations that maximize the target
    organ's biodistribution.

    Hierarchical search strategy:
    1. First iteration: global sparse search, keep the top num_seeds scattered seeds.
    2. Later iterations: search each seed's neighborhood, keep top_per_seed local optima each.
    3. This preserves search diversity and avoids collapsing onto a single region.

    Examples:
    python -m app.optimize --smiles "CC(C)..." --organ liver
    python -m app.optimize -s "CC(C)..." -o spleen -k 10 -n 30 -t 2
    python -m app.optimize -s "CC(C)..." -o liver -S "10,2,1" --wr-step-sizes "5,2,1"
    """
    # Validate the target organ
    if organ not in AVAILABLE_ORGANS:
        logger.error(f"Invalid organ: {organ}. Available: {AVAILABLE_ORGANS}")
        raise typer.Exit(1)
    # Parse the comma-separated step-size schedules
    parsed_step_sizes = None
    if step_sizes:
        try:
            parsed_step_sizes = [float(s.strip()) for s in step_sizes.split(",")]
        except ValueError:
            logger.error(f"Invalid step sizes format: {step_sizes}")
            raise typer.Exit(1)
    parsed_wr_step_sizes = None
    if wr_step_sizes:
        try:
            parsed_wr_step_sizes = [float(s.strip()) for s in wr_step_sizes.split(",")]
        except ValueError:
            logger.error(f"Invalid wr step sizes format: {wr_step_sizes}")
            raise typer.Exit(1)
    # Load the model
    logger.info(f"Loading model from {model_path}")
    device = torch.device(device)
    model = load_model(model_path, device)
    results = optimize(
        smiles=smiles,
        organ=organ,
        model=model,
        device=device,
        top_k=top_k,
        num_seeds=num_seeds,
        top_per_seed=top_per_seed,
        step_sizes=parsed_step_sizes,
        wr_step_sizes=parsed_wr_step_sizes,
        batch_size=batch_size,
    )
    # Format and display the results
    df_results = format_results(results, organ)
    logger.info(f"\n{'='*60}")
    logger.info(f"TOP {top_k} FORMULATIONS FOR {organ.upper()}")
    logger.info(f"{'='*60}")
    print(df_results.to_string(index=False))
    # Persist the results when an output path was given
    if output_path:
        df_results.to_csv(output_path, index=False)
        logger.success(f"Results saved to {output_path}")
    return df_results
if __name__ == "__main__":
app()