mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 01:27:00 +08:00
Add LNP fomular optimization
This commit is contained in:
parent
39a14e4274
commit
982e98cced
6
Makefile
6
Makefile
@ -174,6 +174,12 @@ predict: requirements
|
||||
test: requirements
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.predict test $(DEVICE_FLAG)
|
||||
|
||||
## Formulation optimization: find optimal LNP formulation for target organ
|
||||
## Usage: make optimize SMILES="CC(C)..." ORGAN=liver
|
||||
.PHONY: optimize
|
||||
optimize: requirements
|
||||
$(PYTHON_INTERPRETER) -m app.optimize --smiles "$(SMILES)" --organ $(ORGAN) $(DEVICE_FLAG)
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Self Documenting Commands #
|
||||
|
||||
@ -156,5 +156,12 @@ python -m lnp_ml.modeling.train \
|
||||
└── plots.py <- Code to create visualizations
|
||||
```
|
||||
|
||||
|
||||
### 配方筛选
|
||||
|
||||
```
|
||||
make optimize SMILES="CC(C)NCCNC(C)C" ORGAN=liver
|
||||
```
|
||||
|
||||
--------
|
||||
|
||||
|
||||
62
app/PARAM.md
Normal file
62
app/PARAM.md
Normal file
@ -0,0 +1,62 @@
|
||||
## Possible Values
|
||||
|
||||
# comp token([B, 5], the sum of the latter four ratio is always 1)
|
||||
Cationic_Lipid_to_mRNA_weight_ratio(float, Min: 0.05, Max: 0.3, Step Size: 0.01)
|
||||
Cationic_Lipid_Mol_Ratio(float, Min: 0.05, Max: 0.8, Step Size: 0.01)
|
||||
Phospholipid_Mol_Ratio(float, Min: 0, Max: 0.8, Step Size: 0.01)
|
||||
Cholesterol_Mol_Ratio(float, Min: 0, Max: 0.8, Step Size: 0.01)
|
||||
PEG_Lipid_Mol_Ratio(float, Min: 0, Max: 0.05, Step Size: 0.01)
|
||||
|
||||
# phys token([B, 12])
|
||||
Purity_Pure(one-hot for Purity, always Pure)
|
||||
Purity_Crude(one-hot for Purity, always Pure)
|
||||
Mix_type_Microfluidic(one-hot for Mix_type, always Microfluidic)
|
||||
Mix_type_Microfluidic(one-hot for Mix_type, always Microfluidic)
|
||||
Cargo_type_mRNA(one-hot for Cargo_type, always mRNA)
|
||||
Cargo_type_pDNA(one-hot for Cargo_type, always mRNA)
|
||||
Cargo_type_siRNA(one-hot for Cargo_type, always mRNA)
|
||||
Target_or_delivered_gene_FFL(one-hot for Target_or_delivered_gene, always FFL)
|
||||
Target_or_delivered_gene_Peptide_barcode(one-hot for Target_or_delivered_gene, always FFL)
|
||||
Target_or_delivered_gene_hEPO(one-hot for Target_or_delivered_gene, always FFL)
|
||||
Target_or_delivered_gene_FVII(one-hot for Target_or_delivered_gene, always FFL)
|
||||
Target_or_delivered_gene_GFP(one-hot for Target_or_delivered_gene, always FFL)
|
||||
|
||||
# help token([B, 4])
|
||||
Helper_lipid_ID_DOPE(one-hot for Helper_lipid_ID, one of {DOPE, DSPC, DOTAP})
|
||||
Helper_lipid_ID_DOTAP(one-hot for Helper_lipid_ID, one of {DOPE, DSPC, DOTAP})
|
||||
Helper_lipid_ID_DSPC(one-hot for Helper_lipid_ID, one of {DOPE, DSPC, DOTAP})
|
||||
Helper_lipid_ID_MDOA(one-hot for Helper_lipid_ID, one of {DOPE, DSPC, DOTAP})
|
||||
|
||||
# exp token([B, 32])
|
||||
Model_type_A549(one-hot for Model_type, always Mouse)
|
||||
Model_type_BDMC(one-hot for Model_type, always Mouse)
|
||||
Model_type_BMDM(one-hot for Model_type, always Mouse)
|
||||
Model_type_HBEC_ALI(one-hot for Model_type, always Mouse)
|
||||
Model_type_HEK293T(one-hot for Model_type, always Mouse)
|
||||
Model_type_HeLa(one-hot for Model_type, always Mouse)
|
||||
Model_type_IGROV1(one-hot for Model_type, always Mouse)
|
||||
Model_type_Mouse(one-hot for Model_type, always Mouse)
|
||||
Model_type_RAW264p7(one-hot for Model_type, always Mouse)
|
||||
Delivery_target_dendritic_cell(one-hot for Delivery_target, always body)
|
||||
Delivery_target_generic_cell(one-hot for Delivery_target, always body)
|
||||
Delivery_target_liver(one-hot for Delivery_target, always body)
|
||||
Delivery_target_lung(one-hot for Delivery_target, always body)
|
||||
Delivery_target_lung_epithelium(one-hot for Delivery_target, always body)
|
||||
Delivery_target_macrophage(one-hot for Delivery_target, always body)
|
||||
Delivery_target_muscle(one-hot for Delivery_target, always body)
|
||||
Delivery_target_spleen(one-hot for Delivery_target, always body)
|
||||
Delivery_target_body(one-hot for Delivery_target, always body)
|
||||
Route_of_administration_in_vitro(one-hot for Route_of_administration, one of {Intravenous, Intramuscular})
|
||||
Route_of_administration_intravenous(one-hot for Route_of_administration, one of {Intravenous, Intramuscular})
|
||||
Route_of_administration_intramuscular(one-hot for Route_of_administration, one of {Intravenous, Intramuscular})
|
||||
Route_of_administration_intratracheal(one-hot for Route_of_administration, one of {Intravenous, Intramuscular})
|
||||
Sample_organization_type_individual(one-hot for Sample_organization_type, always Individual)
|
||||
Sample_organization_type_barcoded(one-hot for Sample_organization_type, always Individual)
|
||||
Value_name_log_luminescence(one-hot for Value_name, always luminescence)
|
||||
Value_name_luminescence(one-hot for Value_name, always luminescence)
|
||||
Value_name_FFL_silencing(one-hot for Value_name, always luminescence)
|
||||
Value_name_Peptide_abundance(one-hot for Value_name, always luminescence)
|
||||
Value_name_hEPO(one-hot for Value_name, always luminescence)
|
||||
Value_name_FVII_silencing(one-hot for Value_name, always luminescence)
|
||||
Value_name_GFP_delivery(one-hot for Value_name, always luminescence)
|
||||
Value_name_Discretized_luminescence(one-hot for Value_name, always luminescence)
|
||||
2
app/__init__.py
Normal file
2
app/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""LNP 配方优化应用"""
|
||||
|
||||
532
app/optimize.py
Normal file
532
app/optimize.py
Normal file
@ -0,0 +1,532 @@
|
||||
"""
|
||||
配方优化模拟程序
|
||||
|
||||
通过迭代式 Grid Search 寻找最优 LNP 配方,最大化目标器官的 Biodistribution。
|
||||
|
||||
使用方法:
|
||||
python -m app.optimize --smiles "CC(C)..." --organ liver
|
||||
"""
|
||||
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from loguru import logger
|
||||
from tqdm import tqdm
|
||||
import typer
|
||||
|
||||
from lnp_ml.config import MODELS_DIR
|
||||
from lnp_ml.dataset import (
|
||||
LNPDataset,
|
||||
LNPDatasetConfig,
|
||||
collate_fn,
|
||||
SMILES_COL,
|
||||
COMP_COLS,
|
||||
HELP_COLS,
|
||||
TARGET_BIODIST,
|
||||
get_phys_cols,
|
||||
get_exp_cols,
|
||||
)
|
||||
from lnp_ml.modeling.predict import load_model
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
# ============ 参数配置 ============
|
||||
|
||||
# 可用的目标器官
|
||||
AVAILABLE_ORGANS = ["lymph_nodes", "heart", "liver", "spleen", "lung", "kidney", "muscle"]
|
||||
|
||||
# comp token 参数范围
|
||||
COMP_PARAM_RANGES = {
|
||||
"Cationic_Lipid_to_mRNA_weight_ratio": (0.05, 0.30),
|
||||
"Cationic_Lipid_Mol_Ratio": (0.05, 0.80),
|
||||
"Phospholipid_Mol_Ratio": (0.00, 0.80),
|
||||
"Cholesterol_Mol_Ratio": (0.00, 0.80),
|
||||
"PEG_Lipid_Mol_Ratio": (0.00, 0.05),
|
||||
}
|
||||
|
||||
# 最小 step size
|
||||
MIN_STEP_SIZE = 0.01
|
||||
|
||||
# 迭代策略:每个迭代的 step_size
|
||||
ITERATION_STEP_SIZES = [0.10, 0.02, 0.01]
|
||||
|
||||
# Helper lipid 选项
|
||||
HELPER_LIPID_OPTIONS = ["DOPE", "DSPC", "DOTAP"]
|
||||
|
||||
# Route of administration 选项
|
||||
ROUTE_OPTIONS = ["intravenous", "intramuscular"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Formulation:
|
||||
"""配方数据结构"""
|
||||
# comp token
|
||||
cationic_lipid_to_mrna_ratio: float
|
||||
cationic_lipid_mol_ratio: float
|
||||
phospholipid_mol_ratio: float
|
||||
cholesterol_mol_ratio: float
|
||||
peg_lipid_mol_ratio: float
|
||||
# 离散选项
|
||||
helper_lipid: str = "DOPE"
|
||||
route: str = "intravenous"
|
||||
# 预测结果(填充后)
|
||||
biodist_predictions: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"Cationic_Lipid_to_mRNA_weight_ratio": self.cationic_lipid_to_mrna_ratio,
|
||||
"Cationic_Lipid_Mol_Ratio": self.cationic_lipid_mol_ratio,
|
||||
"Phospholipid_Mol_Ratio": self.phospholipid_mol_ratio,
|
||||
"Cholesterol_Mol_Ratio": self.cholesterol_mol_ratio,
|
||||
"PEG_Lipid_Mol_Ratio": self.peg_lipid_mol_ratio,
|
||||
"helper_lipid": self.helper_lipid,
|
||||
"route": self.route,
|
||||
}
|
||||
|
||||
def get_biodist(self, organ: str) -> float:
|
||||
"""获取指定器官的 biodistribution 预测值"""
|
||||
col = f"Biodistribution_{organ}"
|
||||
return self.biodist_predictions.get(col, 0.0)
|
||||
|
||||
|
||||
def generate_grid_values(
|
||||
center: float,
|
||||
step_size: float,
|
||||
min_val: float,
|
||||
max_val: float,
|
||||
radius: int = 2,
|
||||
) -> List[float]:
|
||||
"""
|
||||
围绕中心点生成网格值。
|
||||
|
||||
Args:
|
||||
center: 中心值
|
||||
step_size: 步长
|
||||
min_val: 最小值
|
||||
max_val: 最大值
|
||||
radius: 扩展半径(生成 2*radius+1 个点)
|
||||
|
||||
Returns:
|
||||
网格值列表
|
||||
"""
|
||||
values = []
|
||||
for i in range(-radius, radius + 1):
|
||||
val = center + i * step_size
|
||||
if min_val <= val <= max_val:
|
||||
values.append(round(val, 4))
|
||||
return sorted(set(values))
|
||||
|
||||
|
||||
def generate_initial_grid(step_size: float) -> List[Tuple[float, float, float, float, float]]:
|
||||
"""
|
||||
生成初始搜索网格(满足 mol ratio 和为 1 的约束)。
|
||||
|
||||
Returns:
|
||||
List of (cationic_ratio, cationic_mol, phospholipid_mol, cholesterol_mol, peg_mol)
|
||||
"""
|
||||
grid = []
|
||||
|
||||
# Cationic_Lipid_to_mRNA_weight_ratio
|
||||
weight_ratios = np.arange(0.05, 0.31, step_size)
|
||||
|
||||
# PEG: 单独处理,范围很小
|
||||
peg_values = np.arange(0.00, 0.06, MIN_STEP_SIZE) # PEG 始终用 0.01 步长
|
||||
|
||||
# 其他三个 mol ratio 需要满足和为 1 - PEG
|
||||
mol_step = step_size
|
||||
|
||||
for weight_ratio in weight_ratios:
|
||||
for peg in peg_values:
|
||||
remaining = 1.0 - peg
|
||||
# 生成满足约束的组合
|
||||
for cationic_mol in np.arange(0.05, min(0.81, remaining + 0.001), mol_step):
|
||||
for phospholipid_mol in np.arange(0.00, min(0.81, remaining - cationic_mol + 0.001), mol_step):
|
||||
cholesterol_mol = remaining - cationic_mol - phospholipid_mol
|
||||
# 检查约束
|
||||
if 0.00 <= cholesterol_mol <= 0.80:
|
||||
grid.append((
|
||||
round(weight_ratio, 4),
|
||||
round(cationic_mol, 4),
|
||||
round(phospholipid_mol, 4),
|
||||
round(cholesterol_mol, 4),
|
||||
round(peg, 4),
|
||||
))
|
||||
|
||||
return grid
|
||||
|
||||
|
||||
def generate_refined_grid(
|
||||
seeds: List[Formulation],
|
||||
step_size: float,
|
||||
radius: int = 2,
|
||||
) -> List[Tuple[float, float, float, float, float]]:
|
||||
"""
|
||||
围绕种子点生成精细化网格。
|
||||
|
||||
Args:
|
||||
seeds: 种子配方列表
|
||||
step_size: 步长
|
||||
radius: 扩展半径
|
||||
|
||||
Returns:
|
||||
新的网格点列表
|
||||
"""
|
||||
grid_set = set()
|
||||
|
||||
for seed in seeds:
|
||||
# Weight ratio
|
||||
weight_ratios = generate_grid_values(
|
||||
seed.cationic_lipid_to_mrna_ratio, step_size, 0.05, 0.30, radius
|
||||
)
|
||||
|
||||
# PEG (始终用最小步长)
|
||||
peg_values = generate_grid_values(
|
||||
seed.peg_lipid_mol_ratio, MIN_STEP_SIZE, 0.00, 0.05, radius
|
||||
)
|
||||
|
||||
# Mol ratios
|
||||
cationic_mols = generate_grid_values(
|
||||
seed.cationic_lipid_mol_ratio, step_size, 0.05, 0.80, radius
|
||||
)
|
||||
phospholipid_mols = generate_grid_values(
|
||||
seed.phospholipid_mol_ratio, step_size, 0.00, 0.80, radius
|
||||
)
|
||||
|
||||
for weight_ratio in weight_ratios:
|
||||
for peg in peg_values:
|
||||
remaining = 1.0 - peg
|
||||
for cationic_mol in cationic_mols:
|
||||
for phospholipid_mol in phospholipid_mols:
|
||||
cholesterol_mol = remaining - cationic_mol - phospholipid_mol
|
||||
# 检查约束
|
||||
if (0.05 <= cationic_mol <= 0.80 and
|
||||
0.00 <= phospholipid_mol <= 0.80 and
|
||||
0.00 <= cholesterol_mol <= 0.80 and
|
||||
0.00 <= peg <= 0.05):
|
||||
grid_set.add((
|
||||
round(weight_ratio, 4),
|
||||
round(cationic_mol, 4),
|
||||
round(phospholipid_mol, 4),
|
||||
round(cholesterol_mol, 4),
|
||||
round(peg, 4),
|
||||
))
|
||||
|
||||
return list(grid_set)
|
||||
|
||||
|
||||
def create_dataframe_from_formulations(
|
||||
smiles: str,
|
||||
grid: List[Tuple[float, float, float, float, float]],
|
||||
helper_lipids: List[str],
|
||||
routes: List[str],
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
从配方网格创建 DataFrame。
|
||||
|
||||
使用固定的 phys token(Pure, Microfluidic, mRNA, FFL)和 exp token(Mouse, body, luminescence)。
|
||||
"""
|
||||
rows = []
|
||||
|
||||
for comp_values in grid:
|
||||
for helper in helper_lipids:
|
||||
for route in routes:
|
||||
row = {
|
||||
SMILES_COL: smiles,
|
||||
# comp token
|
||||
"Cationic_Lipid_to_mRNA_weight_ratio": comp_values[0],
|
||||
"Cationic_Lipid_Mol_Ratio": comp_values[1],
|
||||
"Phospholipid_Mol_Ratio": comp_values[2],
|
||||
"Cholesterol_Mol_Ratio": comp_values[3],
|
||||
"PEG_Lipid_Mol_Ratio": comp_values[4],
|
||||
# phys token (固定值)
|
||||
"Purity_Pure": 1.0,
|
||||
"Purity_Crude": 0.0,
|
||||
"Mix_type_Microfluidic": 1.0,
|
||||
"Mix_type_Pipetting": 0.0,
|
||||
"Cargo_type_mRNA": 1.0,
|
||||
"Cargo_type_pDNA": 0.0,
|
||||
"Cargo_type_siRNA": 0.0,
|
||||
"Target_or_delivered_gene_FFL": 1.0,
|
||||
"Target_or_delivered_gene_Peptide_barcode": 0.0,
|
||||
"Target_or_delivered_gene_hEPO": 0.0,
|
||||
"Target_or_delivered_gene_FVII": 0.0,
|
||||
"Target_or_delivered_gene_GFP": 0.0,
|
||||
# help token
|
||||
"Helper_lipid_ID_DOPE": 1.0 if helper == "DOPE" else 0.0,
|
||||
"Helper_lipid_ID_DOTAP": 1.0 if helper == "DOTAP" else 0.0,
|
||||
"Helper_lipid_ID_DSPC": 1.0 if helper == "DSPC" else 0.0,
|
||||
"Helper_lipid_ID_MDOA": 0.0, # 不使用
|
||||
# exp token (固定值)
|
||||
"Model_type_Mouse": 1.0,
|
||||
"Delivery_target_body": 1.0,
|
||||
f"Route_of_administration_{route}": 1.0,
|
||||
"Batch_or_individual_or_barcoded_Individual": 1.0,
|
||||
"Value_name_luminescence": 1.0,
|
||||
# 存储配方元信息
|
||||
"_helper_lipid": helper,
|
||||
"_route": route,
|
||||
}
|
||||
# 其他 exp token 默认为 0
|
||||
for col in get_exp_cols():
|
||||
if col not in row:
|
||||
row[col] = 0.0
|
||||
|
||||
rows.append(row)
|
||||
|
||||
return pd.DataFrame(rows)
|
||||
|
||||
|
||||
def predict_biodist(
|
||||
model: torch.nn.Module,
|
||||
df: pd.DataFrame,
|
||||
device: torch.device,
|
||||
batch_size: int = 256,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
使用模型预测 biodistribution。
|
||||
|
||||
Returns:
|
||||
添加了预测列的 DataFrame
|
||||
"""
|
||||
dataset = LNPDataset(df)
|
||||
dataloader = DataLoader(
|
||||
dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
|
||||
)
|
||||
|
||||
all_biodist_preds = []
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
smiles = batch["smiles"]
|
||||
tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
|
||||
|
||||
outputs = model(smiles, tabular)
|
||||
|
||||
# biodist 输出是 softmax 后的概率分布 [B, 7]
|
||||
biodist_pred = outputs["biodist"].cpu().numpy()
|
||||
all_biodist_preds.append(biodist_pred)
|
||||
|
||||
biodist_preds = np.concatenate(all_biodist_preds, axis=0)
|
||||
|
||||
# 添加到 DataFrame
|
||||
for i, col in enumerate(TARGET_BIODIST):
|
||||
df[f"pred_{col}"] = biodist_preds[:, i]
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def select_top_k(
|
||||
df: pd.DataFrame,
|
||||
organ: str,
|
||||
k: int = 20,
|
||||
) -> List[Formulation]:
|
||||
"""
|
||||
选择 top-k 配方。
|
||||
|
||||
Args:
|
||||
df: 包含预测结果的 DataFrame
|
||||
organ: 目标器官
|
||||
k: 选择数量
|
||||
|
||||
Returns:
|
||||
Top-k 配方列表
|
||||
"""
|
||||
pred_col = f"pred_Biodistribution_{organ}"
|
||||
if pred_col not in df.columns:
|
||||
raise ValueError(f"Prediction column {pred_col} not found")
|
||||
|
||||
# 排序并去重
|
||||
df_sorted = df.sort_values(pred_col, ascending=False)
|
||||
|
||||
# 创建配方对象
|
||||
formulations = []
|
||||
seen = set()
|
||||
|
||||
for _, row in df_sorted.iterrows():
|
||||
key = (
|
||||
row["Cationic_Lipid_to_mRNA_weight_ratio"],
|
||||
row["Cationic_Lipid_Mol_Ratio"],
|
||||
row["Phospholipid_Mol_Ratio"],
|
||||
row["Cholesterol_Mol_Ratio"],
|
||||
row["PEG_Lipid_Mol_Ratio"],
|
||||
row["_helper_lipid"],
|
||||
row["_route"],
|
||||
)
|
||||
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
formulation = Formulation(
|
||||
cationic_lipid_to_mrna_ratio=row["Cationic_Lipid_to_mRNA_weight_ratio"],
|
||||
cationic_lipid_mol_ratio=row["Cationic_Lipid_Mol_Ratio"],
|
||||
phospholipid_mol_ratio=row["Phospholipid_Mol_Ratio"],
|
||||
cholesterol_mol_ratio=row["Cholesterol_Mol_Ratio"],
|
||||
peg_lipid_mol_ratio=row["PEG_Lipid_Mol_Ratio"],
|
||||
helper_lipid=row["_helper_lipid"],
|
||||
route=row["_route"],
|
||||
biodist_predictions={
|
||||
col: row[f"pred_{col}"] for col in TARGET_BIODIST
|
||||
},
|
||||
)
|
||||
formulations.append(formulation)
|
||||
|
||||
if len(formulations) >= k:
|
||||
break
|
||||
|
||||
return formulations
|
||||
|
||||
|
||||
def optimize(
|
||||
smiles: str,
|
||||
organ: str,
|
||||
model: torch.nn.Module,
|
||||
device: torch.device,
|
||||
top_k: int = 20,
|
||||
batch_size: int = 256,
|
||||
) -> List[Formulation]:
|
||||
"""
|
||||
执行配方优化。
|
||||
|
||||
Args:
|
||||
smiles: SMILES 字符串
|
||||
organ: 目标器官
|
||||
model: 训练好的模型
|
||||
device: 计算设备
|
||||
top_k: 每轮保留的最优配方数
|
||||
batch_size: 预测批次大小
|
||||
|
||||
Returns:
|
||||
最终 top-k 配方列表
|
||||
"""
|
||||
logger.info(f"Starting optimization for organ: {organ}")
|
||||
logger.info(f"SMILES: {smiles}")
|
||||
|
||||
seeds = None
|
||||
|
||||
for iteration, step_size in enumerate(ITERATION_STEP_SIZES):
|
||||
logger.info(f"\n{'='*60}")
|
||||
logger.info(f"Iteration {iteration + 1}/{len(ITERATION_STEP_SIZES)}, step_size={step_size}")
|
||||
logger.info(f"{'='*60}")
|
||||
|
||||
# 生成网格
|
||||
if seeds is None:
|
||||
# 第一次迭代:生成完整初始网格
|
||||
logger.info("Generating initial grid...")
|
||||
grid = generate_initial_grid(step_size)
|
||||
else:
|
||||
# 后续迭代:围绕种子点精细化
|
||||
logger.info(f"Generating refined grid around {len(seeds)} seeds...")
|
||||
grid = generate_refined_grid(seeds, step_size, radius=2)
|
||||
|
||||
logger.info(f"Grid size: {len(grid)} comp combinations")
|
||||
|
||||
# 扩展到所有 helper lipid 和 route 组合
|
||||
total_combinations = len(grid) * len(HELPER_LIPID_OPTIONS) * len(ROUTE_OPTIONS)
|
||||
logger.info(f"Total combinations: {total_combinations}")
|
||||
|
||||
# 创建 DataFrame
|
||||
df = create_dataframe_from_formulations(
|
||||
smiles, grid, HELPER_LIPID_OPTIONS, ROUTE_OPTIONS
|
||||
)
|
||||
|
||||
# 预测
|
||||
logger.info("Running predictions...")
|
||||
df = predict_biodist(model, df, device, batch_size)
|
||||
|
||||
# 选择 top-k
|
||||
seeds = select_top_k(df, organ, top_k)
|
||||
|
||||
# 显示当前最优
|
||||
best = seeds[0]
|
||||
logger.info(f"Current best Biodistribution_{organ}: {best.get_biodist(organ):.4f}")
|
||||
logger.info(f"Best formulation: {best.to_dict()}")
|
||||
|
||||
return seeds
|
||||
|
||||
|
||||
def format_results(formulations: List[Formulation], organ: str) -> pd.DataFrame:
|
||||
"""格式化结果为 DataFrame"""
|
||||
rows = []
|
||||
for i, f in enumerate(formulations):
|
||||
row = {
|
||||
"rank": i + 1,
|
||||
f"Biodistribution_{organ}": f.get_biodist(organ),
|
||||
**f.to_dict(),
|
||||
}
|
||||
# 添加其他器官的预测
|
||||
for col in TARGET_BIODIST:
|
||||
if col != f"Biodistribution_{organ}":
|
||||
row[col] = f.biodist_predictions.get(col, 0.0)
|
||||
rows.append(row)
|
||||
return pd.DataFrame(rows)
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
smiles: str = typer.Option(..., "--smiles", "-s", help="Cationic lipid SMILES string"),
|
||||
organ: str = typer.Option(..., "--organ", "-o", help=f"Target organ: {AVAILABLE_ORGANS}"),
|
||||
model_path: Path = typer.Option(
|
||||
MODELS_DIR / "final" / "model.pt",
|
||||
"--model", "-m",
|
||||
help="Path to trained model checkpoint"
|
||||
),
|
||||
output_path: Optional[Path] = typer.Option(
|
||||
None, "--output", "-O",
|
||||
help="Output CSV path (optional)"
|
||||
),
|
||||
top_k: int = typer.Option(20, "--top-k", "-k", help="Number of top formulations to return"),
|
||||
batch_size: int = typer.Option(256, "--batch-size", "-b", help="Prediction batch size"),
|
||||
device: str = typer.Option("cuda" if torch.cuda.is_available() else "cpu", "--device", "-d", help="Device"),
|
||||
):
|
||||
"""
|
||||
配方优化程序:寻找最大化目标器官 Biodistribution 的最优 LNP 配方。
|
||||
|
||||
示例:
|
||||
python -m app.optimize --smiles "CC(C)..." --organ liver
|
||||
python -m app.optimize -s "CC(C)..." -o spleen -k 10
|
||||
"""
|
||||
# 验证器官
|
||||
if organ not in AVAILABLE_ORGANS:
|
||||
logger.error(f"Invalid organ: {organ}. Available: {AVAILABLE_ORGANS}")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# 加载模型
|
||||
logger.info(f"Loading model from {model_path}")
|
||||
device = torch.device(device)
|
||||
model = load_model(model_path, device)
|
||||
|
||||
# 执行优化
|
||||
results = optimize(
|
||||
smiles=smiles,
|
||||
organ=organ,
|
||||
model=model,
|
||||
device=device,
|
||||
top_k=top_k,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
# 格式化并显示结果
|
||||
df_results = format_results(results, organ)
|
||||
|
||||
logger.info(f"\n{'='*60}")
|
||||
logger.info(f"TOP {top_k} FORMULATIONS FOR {organ.upper()}")
|
||||
logger.info(f"{'='*60}")
|
||||
print(df_results.to_string(index=False))
|
||||
|
||||
# 保存结果
|
||||
if output_path:
|
||||
df_results.to_csv(output_path, index=False)
|
||||
logger.success(f"Results saved to {output_path}")
|
||||
|
||||
return df_results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user