mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 09:36:32 +08:00
1017 lines
37 KiB
Python
1017 lines
37 KiB
Python
"""
|
||
配方优化模拟程序
|
||
|
||
通过迭代式 Grid Search 寻找最优 LNP 配方,最大化目标器官的 Biodistribution。
|
||
|
||
使用方法:
|
||
python -m app.optimize --smiles "CC(C)..." --organ liver
|
||
"""
|
||
|
||
import itertools
|
||
from pathlib import Path
|
||
from typing import Dict, List, Optional, Tuple
|
||
from dataclasses import dataclass, field
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import torch
|
||
from torch.utils.data import DataLoader
|
||
from loguru import logger
|
||
from tqdm import tqdm
|
||
import typer
|
||
|
||
from lnp_ml.config import MODELS_DIR
|
||
from lnp_ml.dataset import (
|
||
LNPDataset,
|
||
LNPDatasetConfig,
|
||
collate_fn,
|
||
SMILES_COL,
|
||
COMP_COLS,
|
||
HELP_COLS,
|
||
TARGET_BIODIST,
|
||
get_phys_cols,
|
||
get_exp_cols,
|
||
)
|
||
from lnp_ml.modeling.predict import load_model
|
||
|
||
# Typer CLI application entry point.
app = typer.Typer()

# ============ Parameter configuration ============

# Organs the optimizer can target (must match the model's biodistribution heads).
AVAILABLE_ORGANS = ["lymph_nodes", "heart", "liver", "spleen", "lung", "kidney", "muscle"]
|
||
|
||
|
||
@dataclass
class CompRanges:
    """Search ranges for each LNP composition parameter.

    The ``*_mol_*`` fields are mole fractions; ``weight_ratio`` is the
    cationic-lipid-to-mRNA weight ratio.
    """

    # Cationic lipid / mRNA weight ratio.
    weight_ratio_min: float = 0.05
    weight_ratio_max: float = 0.30
    # Cationic lipid mole fraction.
    cationic_mol_min: float = 0.05
    cationic_mol_max: float = 0.80
    # Phospholipid mole fraction.
    phospholipid_mol_min: float = 0.00
    phospholipid_mol_max: float = 0.80
    # Cholesterol mole fraction.
    cholesterol_mol_min: float = 0.00
    cholesterol_mol_max: float = 0.80
    # PEG lipid mole fraction.
    peg_mol_min: float = 0.00
    peg_mol_max: float = 0.05

    def to_dict(self) -> Dict:
        """Return the ranges as ``{parameter_name: (min, max)}``."""
        return {
            "weight_ratio": (self.weight_ratio_min, self.weight_ratio_max),
            "cationic_mol": (self.cationic_mol_min, self.cationic_mol_max),
            "phospholipid_mol": (self.phospholipid_mol_min, self.phospholipid_mol_max),
            "cholesterol_mol": (self.cholesterol_mol_min, self.cholesterol_mol_max),
            "peg_mol": (self.peg_mol_min, self.peg_mol_max),
        }

    def get_validation_error(self) -> Optional[str]:
        """
        Validate that the ranges admit at least one feasible formulation.

        Returns:
            An error message string (user-facing, Chinese) describing the
            first violated constraint, or None when validation passes.
        """
        # Each (label, min, max) triple must satisfy min <= max.
        bound_checks = [
            ("阳离子脂质/mRNA重量比", self.weight_ratio_min, self.weight_ratio_max),
            ("阳离子脂质mol比例", self.cationic_mol_min, self.cationic_mol_max),
            ("磷脂mol比例", self.phospholipid_mol_min, self.phospholipid_mol_max),
            ("胆固醇mol比例", self.cholesterol_mol_min, self.cholesterol_mol_max),
            ("PEG脂质mol比例", self.peg_mol_min, self.peg_mol_max),
        ]
        for label, lo, hi in bound_checks:
            if lo > hi:
                return f"{label}:最小值({lo})不能大于最大值({hi})"

        # The four mole fractions must be able to sum to exactly 1.
        min_sum = self.cationic_mol_min + self.phospholipid_mol_min + self.cholesterol_mol_min + self.peg_mol_min
        max_sum = self.cationic_mol_max + self.phospholipid_mol_max + self.cholesterol_mol_max + self.peg_mol_max

        if min_sum > 1.0:
            return f"mol比例最小值之和({min_sum:.2f})超过100%,无法生成有效配方"
        if max_sum < 1.0:
            return f"mol比例最大值之和({max_sum:.2f})不足100%,无法生成有效配方"

        return None

    def validate(self) -> bool:
        """True when the configured ranges contain at least one feasible solution."""
        return self.get_validation_error() is None
|
||
|
||
|
||
# Default component ranges shared by all grid generators.
DEFAULT_COMP_RANGES = CompRanges()

# Finest step size used anywhere in the search (also used for the PEG axis).
MIN_STEP_SIZE = 0.01

# Iteration schedule: step size used by each successive search round.
ITERATION_STEP_SIZES = [0.10, 0.02, 0.01]

# Helper lipid options (DOTAP intentionally excluded).
HELPER_LIPID_OPTIONS = ["DOPE", "DSPC"]

# Route-of-administration options.
ROUTE_OPTIONS = ["intravenous", "intramuscular"]

# Normalization constants for quantified_delivery, per administration route.
# NOTE(review): values look empirically derived from training data — confirm source.
DELIVERY_NORM = {
    "intravenous": {"min": -0.798559291, "max": 4.497814051056962},
    "intramuscular": {"min": -0.794912427, "max": 10.220042980012716},
}
|
||
|
||
|
||
@dataclass
class ScoringWeights:
    """
    Scoring-weight configuration.

    total = biodist_score + delivery_score + size_score + ee_score + pdi_score + toxic_score
    See SCORE.md for how each term is computed.
    Defaults: rank by target-organ biodistribution only (backward compatible).
    """
    # Regression-task weights.
    biodist_weight: float = 1.0  # score = biodist_value * weight
    delivery_weight: float = 0.0  # score = normalize(delivery, route) * weight
    size_weight: float = 0.0  # score = (1 if 60<=size<=150 else 0) * weight (matches compute_formulation_score)
    # Classification tasks: per-class weights (predicted class c contributes weights[c]).
    ee_class_weights: List[float] = field(default_factory=lambda: [0.0, 0.0, 0.0])  # EE class 0, 1, 2
    pdi_class_weights: List[float] = field(default_factory=lambda: [0.0, 0.0, 0.0, 0.0])  # PDI class 0, 1, 2, 3
    toxic_class_weights: List[float] = field(default_factory=lambda: [0.0, 0.0])  # Toxic class 0, 1


# Default scoring weights (rank by biodist only).
DEFAULT_SCORING_WEIGHTS = ScoringWeights()
|
||
|
||
|
||
def compute_formulation_score(
    f: 'Formulation',
    organ: str,
    weights: ScoringWeights,
) -> float:
    """
    Compute the composite score of a single Formulation.

    Args:
        f: Formulation carrying model predictions.
        organ: Target organ.
        weights: Scoring weights.

    Returns:
        Composite score (sum of the weighted terms below).
    """
    # 1. Biodistribution for the target organ.
    total = f.get_biodist(organ) * weights.biodist_weight

    # 2. quantified_delivery, normalized to [0, 1] per administration route.
    if weights.delivery_weight != 0 and f.quantified_delivery is not None:
        norm = DELIVERY_NORM.get(f.route, DELIVERY_NORM["intravenous"])
        span = norm["max"] - norm["min"]
        if span > 0:
            normalized = (f.quantified_delivery - norm["min"]) / span
            normalized = min(1.0, max(0.0, normalized))
        else:
            normalized = 0.0
        total += normalized * weights.delivery_weight

    # 3. Particle size: 60-150 nm is treated as the ideal window.
    if weights.size_weight != 0 and f.size is not None:
        if 60 <= f.size <= 150:
            total += 1.0 * weights.size_weight

    # 4-6. Classification heads (EE, PDI, toxicity): the predicted class
    # contributes its per-class weight when the class index is in range.
    for predicted_class, class_weights in (
        (f.ee_class, weights.ee_class_weights),
        (f.pdi_class, weights.pdi_class_weights),
        (f.toxic_class, weights.toxic_class_weights),
    ):
        if predicted_class is not None and 0 <= predicted_class < len(class_weights):
            total += class_weights[predicted_class]

    return total
|
||
|
||
|
||
def compute_df_score(
    df: pd.DataFrame,
    organ: str,
    weights: ScoringWeights,
) -> pd.Series:
    """
    Vectorized composite score for every row of a prediction DataFrame.

    Args:
        df: DataFrame containing ``pred_*`` prediction columns.
        organ: Target organ.
        weights: Scoring weights.

    Returns:
        Score Series aligned with ``df.index``.
    """
    total = pd.Series(0.0, index=df.index)

    # 1. Biodistribution of the target organ.
    biodist_col = f"pred_Biodistribution_{organ}"
    if biodist_col in df.columns:
        total += df[biodist_col] * weights.biodist_weight

    # 2. quantified_delivery, normalized to [0, 1] per administration route.
    if weights.delivery_weight != 0 and "pred_delivery" in df.columns:
        for route_name, norm in DELIVERY_NORM.items():
            mask = df["_route"] == route_name
            if not mask.any():
                continue
            span = norm["max"] - norm["min"]
            if span > 0:
                normalized = ((df.loc[mask, "pred_delivery"] - norm["min"]) / span).clip(0.0, 1.0)
                total.loc[mask] += normalized * weights.delivery_weight

    # 3. Particle size within the 60-150 nm window.
    if weights.size_weight != 0 and "pred_size" in df.columns:
        in_window = df["pred_size"].between(60, 150)
        total += in_window.astype(float) * weights.size_weight

    # 4-6. Classification heads (EE, PDI, toxicity): rows predicted as
    # class c gain that class's weight.
    for pred_col, class_weights in (
        ("pred_ee_class", weights.ee_class_weights),
        ("pred_pdi_class", weights.pdi_class_weights),
        ("pred_toxic_class", weights.toxic_class_weights),
    ):
        if pred_col in df.columns:
            for cls, w in enumerate(class_weights):
                if w != 0:
                    total += (df[pred_col] == cls).astype(float) * w

    return total
|
||
|
||
|
||
@dataclass
class Formulation:
    """A candidate LNP formulation together with its model predictions."""

    # Composition tokens.
    cationic_lipid_to_mrna_ratio: float
    cationic_lipid_mol_ratio: float
    phospholipid_mol_ratio: float
    cholesterol_mol_ratio: float
    peg_lipid_mol_ratio: float
    # Discrete options.
    helper_lipid: str = "DOPE"
    route: str = "intravenous"
    # Predictions (filled in after inference).
    biodist_predictions: Dict[str, float] = field(default_factory=dict)
    # Extra predicted values.
    quantified_delivery: Optional[float] = None
    size: Optional[float] = None
    pdi_class: Optional[int] = None  # PDI class (0-3)
    ee_class: Optional[int] = None  # EE class (0-2)
    toxic_class: Optional[int] = None  # toxicity class (0: non-toxic, 1: toxic)

    def to_dict(self) -> Dict:
        """Return the composition and discrete options as a plain dict."""
        return {
            "Cationic_Lipid_to_mRNA_weight_ratio": self.cationic_lipid_to_mrna_ratio,
            "Cationic_Lipid_Mol_Ratio": self.cationic_lipid_mol_ratio,
            "Phospholipid_Mol_Ratio": self.phospholipid_mol_ratio,
            "Cholesterol_Mol_Ratio": self.cholesterol_mol_ratio,
            "PEG_Lipid_Mol_Ratio": self.peg_lipid_mol_ratio,
            "helper_lipid": self.helper_lipid,
            "route": self.route,
        }

    def get_biodist(self, organ: str) -> float:
        """Predicted biodistribution for *organ*; 0.0 when no prediction is stored."""
        return self.biodist_predictions.get(f"Biodistribution_{organ}", 0.0)

    def unique_key(self) -> tuple:
        """Deduplication key: composition rounded to 4 decimals plus discrete options."""
        composition = (
            self.cationic_lipid_to_mrna_ratio,
            self.cationic_lipid_mol_ratio,
            self.phospholipid_mol_ratio,
            self.cholesterol_mol_ratio,
            self.peg_lipid_mol_ratio,
        )
        rounded = tuple(round(value, 4) for value in composition)
        return rounded + (self.helper_lipid, self.route)
|
||
|
||
|
||
def generate_grid_values(
    center: float,
    step_size: float,
    min_val: float,
    max_val: float,
    radius: int = 2,
) -> List[float]:
    """
    Build grid values centred on a point.

    Args:
        center: Centre value.
        step_size: Spacing between neighbouring points.
        min_val: Lower bound (inclusive).
        max_val: Upper bound (inclusive).
        radius: Expansion radius (at most 2*radius+1 points survive clipping).

    Returns:
        Sorted, de-duplicated list of in-range values rounded to 4 decimals.
    """
    candidates = (center + offset * step_size for offset in range(-radius, radius + 1))
    in_range = {round(value, 4) for value in candidates if min_val <= value <= max_val}
    return sorted(in_range)
|
||
|
||
|
||
def generate_initial_grid(
    step_size: float,
    comp_ranges: Optional["CompRanges"] = None,
    peg_step: Optional[float] = None,
) -> List[Tuple[float, float, float, float, float]]:
    """
    Generate the initial search grid (mole ratios constrained to sum to 1).

    Args:
        step_size: Step size for the weight-ratio and mole-ratio axes.
        comp_ranges: Component range configuration (defaults to DEFAULT_COMP_RANGES).
        peg_step: Step size for the PEG axis (defaults to MIN_STEP_SIZE; the
            PEG range is tiny, so it always gets the finest resolution).

    Returns:
        List of (cationic_ratio, cationic_mol, phospholipid_mol,
        cholesterol_mol, peg_mol) tuples, each value rounded to 4 decimals.
    """
    if comp_ranges is None:
        comp_ranges = DEFAULT_COMP_RANGES
    if peg_step is None:
        peg_step = MIN_STEP_SIZE

    grid = []

    # Cationic_Lipid_to_mRNA_weight_ratio axis; +0.001 makes the upper bound inclusive.
    weight_ratios = np.arange(
        comp_ranges.weight_ratio_min,
        comp_ranges.weight_ratio_max + 0.001,
        step_size,
    )

    # PEG axis: handled separately with the finest step because its range is small.
    peg_values = np.arange(
        comp_ranges.peg_mol_min,
        comp_ranges.peg_mol_max + 0.001,
        peg_step,
    )

    # The remaining three mole ratios must sum to 1 - PEG.
    mol_step = step_size

    for weight_ratio in weight_ratios:
        for peg in peg_values:
            remaining = 1.0 - peg
            # Enumerate feasible (cationic, phospholipid) pairs; cholesterol is implied.
            cationic_max = min(comp_ranges.cationic_mol_max, remaining) + 0.001
            for cationic_mol in np.arange(comp_ranges.cationic_mol_min, cationic_max, mol_step):
                phospholipid_max = min(comp_ranges.phospholipid_mol_max, remaining - cationic_mol) + 0.001
                for phospholipid_mol in np.arange(comp_ranges.phospholipid_mol_min, phospholipid_max, mol_step):
                    # Round BEFORE the range check so np.arange float noise
                    # (e.g. -1e-16) cannot reject a boundary-exact cholesterol value.
                    cholesterol_mol = round(float(remaining - cationic_mol - phospholipid_mol), 4)
                    if comp_ranges.cholesterol_mol_min <= cholesterol_mol <= comp_ranges.cholesterol_mol_max:
                        grid.append((
                            round(float(weight_ratio), 4),
                            round(float(cationic_mol), 4),
                            round(float(phospholipid_mol), 4),
                            cholesterol_mol,
                            round(float(peg), 4),
                        ))

    return grid
|
||
|
||
|
||
def generate_refined_grid(
    seeds: List[Formulation],
    step_size: float,
    radius: int = 2,
    comp_ranges: Optional["CompRanges"] = None,
) -> List[Tuple[float, float, float, float, float]]:
    """
    Build a refined grid around a set of seed formulations.

    Delegates to generate_single_seed_grid for each seed and returns the
    de-duplicated union. The previous implementation duplicated that
    function's neighbourhood/constraint logic verbatim; delegating keeps
    the two expansion code paths from drifting apart.

    Args:
        seeds: Seed formulations to expand around.
        step_size: Step size for the weight-ratio and mole-ratio axes
            (PEG always uses MIN_STEP_SIZE inside the delegate).
        radius: Expansion radius per axis.
        comp_ranges: Component range configuration (defaults to DEFAULT_COMP_RANGES).

    Returns:
        De-duplicated list of (weight_ratio, cationic_mol, phospholipid_mol,
        cholesterol_mol, peg_mol) grid points.
    """
    grid_set = set()
    for seed in seeds:
        grid_set.update(generate_single_seed_grid(seed, step_size, radius, comp_ranges))
    return list(grid_set)
|
||
|
||
|
||
def create_dataframe_from_formulations(
    smiles: str,
    grid: List[Tuple[float, float, float, float, float]],
    helper_lipids: List[str],
    routes: List[str],
) -> pd.DataFrame:
    """
    Build a model-input DataFrame from a composition grid.

    One row is produced per (composition, helper lipid, route) combination.
    Fixed phys tokens (Pure, Microfluidic, mRNA, FFL) and exp tokens
    (Mouse, body, luminescence) are used for every row.

    Args:
        smiles: Cationic lipid SMILES string, copied into every row.
        grid: (weight_ratio, cationic_mol, phospholipid_mol, cholesterol_mol,
            peg_mol) tuples.
        helper_lipids: Helper lipid names to cross with the grid.
        routes: Administration routes to cross with the grid.

    Returns:
        DataFrame with one-hot model feature columns plus the bookkeeping
        columns ``_helper_lipid`` and ``_route``.
    """
    rows = []

    for comp_values in grid:
        for helper in helper_lipids:
            for route in routes:
                row = {
                    SMILES_COL: smiles,
                    # comp token
                    "Cationic_Lipid_to_mRNA_weight_ratio": comp_values[0],
                    "Cationic_Lipid_Mol_Ratio": comp_values[1],
                    "Phospholipid_Mol_Ratio": comp_values[2],
                    "Cholesterol_Mol_Ratio": comp_values[3],
                    "PEG_Lipid_Mol_Ratio": comp_values[4],
                    # phys token (fixed values)
                    "Purity_Pure": 1.0,
                    "Purity_Crude": 0.0,
                    "Mix_type_Microfluidic": 1.0,
                    "Mix_type_Pipetting": 0.0,
                    "Cargo_type_mRNA": 1.0,
                    "Cargo_type_pDNA": 0.0,
                    "Cargo_type_siRNA": 0.0,
                    "Target_or_delivered_gene_FFL": 1.0,
                    "Target_or_delivered_gene_Peptide_barcode": 0.0,
                    "Target_or_delivered_gene_hEPO": 0.0,
                    "Target_or_delivered_gene_FVII": 0.0,
                    "Target_or_delivered_gene_GFP": 0.0,
                    # help token (one-hot on the chosen helper lipid)
                    "Helper_lipid_ID_DOPE": 1.0 if helper == "DOPE" else 0.0,
                    "Helper_lipid_ID_DOTAP": 1.0 if helper == "DOTAP" else 0.0,
                    "Helper_lipid_ID_DSPC": 1.0 if helper == "DSPC" else 0.0,
                    "Helper_lipid_ID_MDOA": 0.0,  # never used by the optimizer
                    # exp token (fixed values)
                    "Model_type_Mouse": 1.0,
                    "Delivery_target_body": 1.0,
                    f"Route_of_administration_{route}": 1.0,
                    "Batch_or_individual_or_barcoded_Individual": 1.0,
                    "Value_name_luminescence": 1.0,
                    # bookkeeping metadata for later scoring/deduplication
                    "_helper_lipid": helper,
                    "_route": route,
                }
                # Any exp-token column not set above defaults to 0.
                for col in get_exp_cols():
                    if col not in row:
                        row[col] = 0.0

                rows.append(row)

    return pd.DataFrame(rows)
|
||
|
||
|
||
def predict_all(
    model: torch.nn.Module,
    df: pd.DataFrame,
    device: torch.device,
    batch_size: int = 256,
) -> pd.DataFrame:
    """
    Run the model over every row and attach all prediction columns
    (biodistribution, size, delivery, PDI, EE, toxicity).

    Note: *df* is modified in place (``pred_*`` columns are added) and
    also returned for convenience.

    Args:
        model: Trained multi-head model; called as ``model(smiles, tabular)``.
        df: Feature DataFrame as produced by create_dataframe_from_formulations.
        device: Device the tabular tensors are moved to.
        batch_size: Inference batch size.

    Returns:
        The same DataFrame with ``pred_*`` columns added.
    """
    dataset = LNPDataset(df)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
    )

    all_biodist_preds = []
    all_size_preds = []
    all_delivery_preds = []
    all_pdi_preds = []
    all_ee_preds = []
    all_toxic_preds = []

    # Inference only — no gradients needed.
    with torch.no_grad():
        for batch in dataloader:
            smiles = batch["smiles"]
            tabular = {k: v.to(device) for k, v in batch["tabular"].items()}

            outputs = model(smiles, tabular)

            # biodist head output is a softmax probability distribution [B, 7]
            # — one column per TARGET_BIODIST entry, assigned positionally below.
            biodist_pred = outputs["biodist"].cpu().numpy()
            all_biodist_preds.append(biodist_pred)

            # size and delivery heads are regression values.
            all_size_preds.append(outputs["size"].squeeze(-1).cpu().numpy())
            all_delivery_preds.append(outputs["delivery"].squeeze(-1).cpu().numpy())

            # PDI, EE and toxicity heads are classifiers; take the argmax class.
            all_pdi_preds.append(outputs["pdi"].argmax(dim=-1).cpu().numpy())
            all_ee_preds.append(outputs["ee"].argmax(dim=-1).cpu().numpy())
            all_toxic_preds.append(outputs["toxic"].argmax(dim=-1).cpu().numpy())

    biodist_preds = np.concatenate(all_biodist_preds, axis=0)
    size_preds = np.concatenate(all_size_preds, axis=0)
    delivery_preds = np.concatenate(all_delivery_preds, axis=0)
    pdi_preds = np.concatenate(all_pdi_preds, axis=0)
    ee_preds = np.concatenate(all_ee_preds, axis=0)
    toxic_preds = np.concatenate(all_toxic_preds, axis=0)

    # Attach predictions; DataLoader(shuffle=False) keeps row order aligned with df.
    for i, col in enumerate(TARGET_BIODIST):
        df[f"pred_{col}"] = biodist_preds[:, i]

    # The size head predicts log(size); convert back to real diameter (nm).
    df["pred_size"] = np.exp(size_preds)
    df["pred_delivery"] = delivery_preds
    df["pred_pdi_class"] = pdi_preds
    df["pred_ee_class"] = ee_preds
    df["pred_toxic_class"] = toxic_preds

    return df
|
||
|
||
|
||
# Kept for backward compatibility with callers written before predict_all existed.
def predict_biodist(
    model: torch.nn.Module,
    df: pd.DataFrame,
    device: torch.device,
    batch_size: int = 256,
) -> pd.DataFrame:
    """Deprecated alias for predict_all; forwards all arguments unchanged."""
    return predict_all(model=model, df=df, device=device, batch_size=batch_size)
|
||
|
||
|
||
def select_top_k(
    df: pd.DataFrame,
    organ: str,
    k: int = 20,
    scoring_weights: Optional[ScoringWeights] = None,
) -> List[Formulation]:
    """
    Select the top-k formulations by composite score.

    Args:
        df: DataFrame containing ``pred_*`` prediction columns.
        organ: Target organ.
        k: Number of formulations to return.
        scoring_weights: Scoring weights (defaults to biodist-only ranking).

    Returns:
        Up to k unique Formulation objects, best score first.
    """
    if scoring_weights is None:
        scoring_weights = DEFAULT_SCORING_WEIGHTS

    # Score every row and sort best-first; work on a copy so the caller's
    # DataFrame is not polluted with the helper column.
    df = df.copy()
    df["_composite_score"] = compute_df_score(df, organ, scoring_weights)
    df_sorted = df.sort_values("_composite_score", ascending=False)

    # Materialize Formulation objects, de-duplicating on the full
    # (composition, helper lipid, route) tuple.
    formulations = []
    seen = set()

    for _, row in df_sorted.iterrows():
        key = (
            row["Cationic_Lipid_to_mRNA_weight_ratio"],
            row["Cationic_Lipid_Mol_Ratio"],
            row["Phospholipid_Mol_Ratio"],
            row["Cholesterol_Mol_Ratio"],
            row["PEG_Lipid_Mol_Ratio"],
            row["_helper_lipid"],
            row["_route"],
        )

        if key not in seen:
            seen.add(key)
            formulation = Formulation(
                cationic_lipid_to_mrna_ratio=row["Cationic_Lipid_to_mRNA_weight_ratio"],
                cationic_lipid_mol_ratio=row["Cationic_Lipid_Mol_Ratio"],
                phospholipid_mol_ratio=row["Phospholipid_Mol_Ratio"],
                cholesterol_mol_ratio=row["Cholesterol_Mol_Ratio"],
                peg_lipid_mol_ratio=row["PEG_Lipid_Mol_Ratio"],
                helper_lipid=row["_helper_lipid"],
                route=row["_route"],
                biodist_predictions={
                    col: row[f"pred_{col}"] for col in TARGET_BIODIST
                },
                # Extra predicted values (None-safe: columns may be absent).
                quantified_delivery=row.get("pred_delivery"),
                size=row.get("pred_size"),
                pdi_class=int(row.get("pred_pdi_class")) if row.get("pred_pdi_class") is not None else None,
                ee_class=int(row.get("pred_ee_class")) if row.get("pred_ee_class") is not None else None,
                toxic_class=int(row.get("pred_toxic_class")) if row.get("pred_toxic_class") is not None else None,
            )
            formulations.append(formulation)

        if len(formulations) >= k:
            break

    return formulations
|
||
|
||
|
||
def generate_single_seed_grid(
    seed: Formulation,
    step_size: float,
    radius: int = 2,
    comp_ranges: CompRanges = None,
) -> List[Tuple[float, float, float, float, float]]:
    """
    Build the neighbourhood grid of a single seed formulation.

    Args:
        seed: Seed formulation to expand around.
        step_size: Step size for the weight-ratio and mole-ratio axes.
        radius: Expansion radius per axis.
        comp_ranges: Component range configuration (defaults to DEFAULT_COMP_RANGES).

    Returns:
        De-duplicated list of (weight_ratio, cationic_mol, phospholipid_mol,
        cholesterol_mol, peg_mol) grid points satisfying all range constraints.
    """
    if comp_ranges is None:
        comp_ranges = DEFAULT_COMP_RANGES

    # Per-axis candidate values centred on the seed. PEG always uses the
    # finest step because its allowed range is tiny.
    weight_axis = generate_grid_values(
        seed.cationic_lipid_to_mrna_ratio, step_size,
        comp_ranges.weight_ratio_min, comp_ranges.weight_ratio_max, radius,
    )
    peg_axis = generate_grid_values(
        seed.peg_lipid_mol_ratio, MIN_STEP_SIZE,
        comp_ranges.peg_mol_min, comp_ranges.peg_mol_max, radius,
    )
    cationic_axis = generate_grid_values(
        seed.cationic_lipid_mol_ratio, step_size,
        comp_ranges.cationic_mol_min, comp_ranges.cationic_mol_max, radius,
    )
    phospholipid_axis = generate_grid_values(
        seed.phospholipid_mol_ratio, step_size,
        comp_ranges.phospholipid_mol_min, comp_ranges.phospholipid_mol_max, radius,
    )

    points = set()
    for weight_ratio, peg, cationic, phospholipid in itertools.product(
        weight_axis, peg_axis, cationic_axis, phospholipid_axis
    ):
        # Cholesterol takes whatever mole fraction is left after PEG,
        # cationic lipid and phospholipid.
        remaining = 1.0 - peg
        cholesterol = remaining - cationic - phospholipid
        # Constraint checks as guard clauses.
        if not (comp_ranges.cationic_mol_min <= cationic <= comp_ranges.cationic_mol_max):
            continue
        if not (comp_ranges.phospholipid_mol_min <= phospholipid <= comp_ranges.phospholipid_mol_max):
            continue
        if not (comp_ranges.cholesterol_mol_min <= cholesterol <= comp_ranges.cholesterol_mol_max):
            continue
        if not (comp_ranges.peg_mol_min <= peg <= comp_ranges.peg_mol_max):
            continue
        points.add((
            round(weight_ratio, 4),
            round(cationic, 4),
            round(phospholipid, 4),
            round(cholesterol, 4),
            round(peg, 4),
        ))

    return list(points)
|
||
|
||
|
||
def optimize(
    smiles: str,
    organ: str,
    model: torch.nn.Module,
    device: torch.device,
    top_k: int = 20,
    num_seeds: Optional[int] = None,
    top_per_seed: int = 1,
    step_sizes: Optional[List[float]] = None,
    comp_ranges: Optional[CompRanges] = None,
    routes: Optional[List[str]] = None,
    scoring_weights: Optional[ScoringWeights] = None,
    batch_size: int = 256,
) -> List[Formulation]:
    """
    Run the formulation optimization (hierarchical search strategy).

    Strategy:
    1. First iteration: global sparse search; keep the top num_seeds scattered seeds.
    2. Later iterations: search each seed's neighbourhood independently,
       keeping top_per_seed local optima per seed.
    3. This preserves search diversity and prevents the results from
       collapsing into a single region.

    Args:
        smiles: SMILES string of the cationic lipid.
        organ: Target organ.
        model: Trained model.
        device: Compute device.
        top_k: Number of best formulations to return.
        num_seeds: Seeds kept after the first iteration (default: top_k * 5).
        top_per_seed: Local optima kept per seed neighbourhood.
        step_sizes: Step size per iteration (default: [0.10, 0.02, 0.01]).
        comp_ranges: Component range configuration (default: DEFAULT_COMP_RANGES).
        routes: Administration routes (default: ROUTE_OPTIONS).
        scoring_weights: Scoring weights (default: biodist-only ranking).
        batch_size: Prediction batch size.

    Returns:
        Final top-k formulation list, best composite score first.
    """
    # Default num_seeds is top_k * 5.
    if num_seeds is None:
        num_seeds = top_k * 5

    # Default step-size schedule.
    if step_sizes is None:
        step_sizes = ITERATION_STEP_SIZES

    # Default component ranges.
    if comp_ranges is None:
        comp_ranges = DEFAULT_COMP_RANGES

    # Default administration routes.
    if routes is None:
        routes = ROUTE_OPTIONS

    # Default scoring weights.
    if scoring_weights is None:
        scoring_weights = DEFAULT_SCORING_WEIGHTS

    # Scoring function used to rank Formulation objects.
    def _score(f: Formulation) -> float:
        return compute_formulation_score(f, organ, scoring_weights)

    logger.info(f"Starting optimization for organ: {organ}")
    logger.info(f"SMILES: {smiles}")
    logger.info(f"Strategy: num_seeds={num_seeds}, top_per_seed={top_per_seed}, top_k={top_k}")
    logger.info(f"Step sizes: {step_sizes}")
    logger.info(f"Routes: {routes}")
    logger.info(f"Scoring weights: biodist={scoring_weights.biodist_weight}, delivery={scoring_weights.delivery_weight}, size={scoring_weights.size_weight}")
    logger.info(f"Comp ranges: {comp_ranges.to_dict()}")

    seeds = None

    for iteration, step_size in enumerate(step_sizes):
        logger.info(f"\n{'='*60}")
        logger.info(f"Iteration {iteration + 1}/{len(step_sizes)}, step_size={step_size}")
        logger.info(f"{'='*60}")

        if seeds is None:
            # ==================== First iteration: global sparse search ====================
            logger.info("Generating initial grid (global sparse search)...")
            grid = generate_initial_grid(step_size, comp_ranges)

            logger.info(f"Grid size: {len(grid)} comp combinations")

            # Expand across all helper lipid and route combinations.
            total_combinations = len(grid) * len(HELPER_LIPID_OPTIONS) * len(routes)
            logger.info(f"Total combinations: {total_combinations}")

            # Build the model-input DataFrame.
            df = create_dataframe_from_formulations(
                smiles, grid, HELPER_LIPID_OPTIONS, routes
            )

            # Run predictions.
            logger.info("Running predictions...")
            df = predict_biodist(model, df, device, batch_size)

            # Keep the top num_seeds seed points.
            seeds = select_top_k(df, organ, num_seeds, scoring_weights)

            logger.info(f"Selected {len(seeds)} seeds for next iteration")

        else:
            # ==================== Later iterations: hierarchical local search ====================
            # Search each seed's neighbourhood separately, keeping local optima per seed.
            logger.info(f"Hierarchical local search around {len(seeds)} seeds...")

            all_local_best = []

            for seed_idx, seed in enumerate(seeds):
                # Build the neighbourhood grid for this seed.
                local_grid = generate_single_seed_grid(seed, step_size, radius=2, comp_ranges=comp_ranges)

                if len(local_grid) == 0:
                    # No new grid points: keep the original seed.
                    all_local_best.append(seed)
                    continue

                # Build the model-input DataFrame (seed's own helper lipid and route only).
                df = create_dataframe_from_formulations(
                    smiles, local_grid, [seed.helper_lipid], [seed.route]
                )

                # Run predictions.
                df = predict_biodist(model, df, device, batch_size)

                # Keep the top top_per_seed local optima of this neighbourhood.
                local_top = select_top_k(df, organ, top_per_seed, scoring_weights)
                all_local_best.extend(local_top)

                # Log progress for the first seed and every fifth seed.
                if seed_idx == 0 or (seed_idx + 1) % 5 == 0:
                    logger.info(f"  Seed {seed_idx + 1}/{len(seeds)}: local grid size={len(local_grid)}, "
                                f"local best score={_score(local_top[0]):.4f}")

            # Replace the seeds with all local optima, de-duplicated.
            seen_keys = set()
            unique_local_best = []
            # Sort by composite score first so the best duplicate wins.
            all_local_best_sorted = sorted(all_local_best, key=_score, reverse=True)
            for f in all_local_best_sorted:
                key = f.unique_key()
                if key not in seen_keys:
                    seen_keys.add(key)
                    unique_local_best.append(f)

            seeds = unique_local_best
            logger.info(f"Collected {len(seeds)} unique local best formulations (from {len(all_local_best)} candidates)")

        # Report the current best candidate.
        best = max(seeds, key=_score)
        logger.info(f"Current best score: {_score(best):.4f} (biodist_{organ}={best.get_biodist(organ):.4f})")
        logger.info(f"Best formulation: {best.to_dict()}")

    # Final pass: de-duplicate, sort by composite score, return top_k.
    seeds_sorted = sorted(seeds, key=_score, reverse=True)

    # De-duplicate: keep the highest-scoring instance of each unique formulation
    # (already sorted, so the first occurrence is the best).
    seen_keys = set()
    unique_results = []
    for f in seeds_sorted:
        key = f.unique_key()
        if key not in seen_keys:
            seen_keys.add(key)
            unique_results.append(f)

    logger.info(f"Final results: {len(unique_results)} unique formulations (from {len(seeds)} candidates)")

    return unique_results[:top_k]
|
||
|
||
|
||
def format_results(formulations: List[Formulation], organ: str) -> pd.DataFrame:
    """Flatten ranked formulations into a result DataFrame (one row per rank)."""
    records = []
    for rank, formulation in enumerate(formulations, start=1):
        record = {
            "rank": rank,
            f"Biodistribution_{organ}": formulation.get_biodist(organ),
            **formulation.to_dict(),
        }
        # Append predictions for the non-target organs as well.
        for col in TARGET_BIODIST:
            if col != f"Biodistribution_{organ}":
                record[col] = formulation.biodist_predictions.get(col, 0.0)
        records.append(record)
    return pd.DataFrame(records)
|
||
|
||
|
||
@app.command()
def main(
    smiles: str = typer.Option(..., "--smiles", "-s", help="Cationic lipid SMILES string"),
    organ: str = typer.Option(..., "--organ", "-o", help=f"Target organ: {AVAILABLE_ORGANS}"),
    model_path: Path = typer.Option(
        MODELS_DIR / "final" / "model.pt",
        "--model", "-m",
        help="Path to trained model checkpoint"
    ),
    output_path: Optional[Path] = typer.Option(
        None, "--output", "-O",
        help="Output CSV path (optional)"
    ),
    top_k: int = typer.Option(20, "--top-k", "-k", help="Number of top formulations to return"),
    num_seeds: Optional[int] = typer.Option(None, "--num-seeds", "-n", help="Number of seed points from first iteration (default: top_k * 5)"),
    top_per_seed: int = typer.Option(1, "--top-per-seed", "-t", help="Number of local best to keep per seed"),
    step_sizes: Optional[str] = typer.Option(None, "--step-sizes", "-S", help="Comma-separated step sizes (e.g., '0.10,0.02,0.01')"),
    batch_size: int = typer.Option(256, "--batch-size", "-b", help="Prediction batch size"),
    device: str = typer.Option("cuda" if torch.cuda.is_available() else "cpu", "--device", "-d", help="Device"),
):
    """
    Formulation optimizer: find the LNP formulations that maximize the
    target organ's biodistribution.

    Hierarchical search strategy:
    1. First iteration: global sparse search; keep the top num_seeds scattered seeds.
    2. Later iterations: search each seed's neighbourhood independently,
       keeping top_per_seed local optima per seed.
    3. This preserves search diversity and prevents results from collapsing
       into a single region.

    Examples:
        python -m app.optimize --smiles "CC(C)..." --organ liver
        python -m app.optimize -s "CC(C)..." -o spleen -k 10 -n 30 -t 2
        python -m app.optimize -s "CC(C)..." -o liver -S "0.10,0.05,0.02"
    """
    # Validate the requested organ.
    if organ not in AVAILABLE_ORGANS:
        logger.error(f"Invalid organ: {organ}. Available: {AVAILABLE_ORGANS}")
        raise typer.Exit(1)

    # Parse the comma-separated step-size schedule, if given.
    parsed_step_sizes = None
    if step_sizes:
        try:
            parsed_step_sizes = [float(s.strip()) for s in step_sizes.split(",")]
        except ValueError:
            logger.error(f"Invalid step sizes format: {step_sizes}")
            raise typer.Exit(1)

    # Load the model checkpoint.
    logger.info(f"Loading model from {model_path}")
    device = torch.device(device)
    model = load_model(model_path, device)

    # Run the optimization (hierarchical search strategy).
    results = optimize(
        smiles=smiles,
        organ=organ,
        model=model,
        device=device,
        top_k=top_k,
        num_seeds=num_seeds,
        top_per_seed=top_per_seed,
        step_sizes=parsed_step_sizes,
        batch_size=batch_size,
    )

    # Format and display the results.
    df_results = format_results(results, organ)

    logger.info(f"\n{'='*60}")
    logger.info(f"TOP {top_k} FORMULATIONS FOR {organ.upper()}")
    logger.info(f"{'='*60}")
    print(df_results.to_string(index=False))

    # Persist to CSV when an output path was supplied.
    if output_path:
        df_results.to_csv(output_path, index=False)
        logger.success(f"Results saved to {output_path}")

    return df_results
|
||
|
||
|
||
if __name__ == "__main__":
|
||
app()
|
||
|