# Mirror of https://github.com/RYDE-WORK/lnp_ml.git (synced 2026-03-22 01:56:54 +08:00).
# 362 lines, 13 KiB, Python.
"""
FastAPI formulation-optimization API.

Start the service:
    uvicorn app.api:app --host 0.0.0.0 --port 8000 --reload
"""
|
||
|
||
import os
|
||
from pathlib import Path
|
||
from typing import List, Dict, Optional
|
||
from contextlib import asynccontextmanager
|
||
|
||
import torch
|
||
from fastapi import FastAPI, HTTPException
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from pydantic import BaseModel, Field
|
||
from loguru import logger
|
||
|
||
from lnp_ml.config import MODELS_DIR
|
||
from lnp_ml.modeling.predict import load_model
|
||
from app.optimize import (
|
||
optimize,
|
||
format_results,
|
||
AVAILABLE_ORGANS,
|
||
TARGET_BIODIST,
|
||
CompRanges,
|
||
ScoringWeights,
|
||
)
|
||
|
||
|
||
# ============ Pydantic Models ============
|
||
|
||
class CompRangesRequest(BaseModel):
    """Composition-range settings supplied by the client.

    Each component has a ``*_min`` / ``*_max`` pair. Pydantic enforces the
    per-field limits declared through ``ge`` / ``le``; this model does not
    itself check that a minimum is below its maximum.
    """

    # Cationic lipid / mRNA weight ratio bounds.
    weight_ratio_min: float = Field(
        default=0.05, ge=0.01, le=0.50, description="阳离子脂质/mRNA 重量比最小值"
    )
    weight_ratio_max: float = Field(
        default=0.30, ge=0.01, le=0.50, description="阳离子脂质/mRNA 重量比最大值"
    )

    # Cationic lipid mol-ratio bounds.
    cationic_mol_min: float = Field(
        default=0.05, ge=0.00, le=1.00, description="阳离子脂质 mol 比例最小值"
    )
    cationic_mol_max: float = Field(
        default=0.80, ge=0.00, le=1.00, description="阳离子脂质 mol 比例最大值"
    )

    # Phospholipid mol-ratio bounds.
    phospholipid_mol_min: float = Field(
        default=0.00, ge=0.00, le=1.00, description="磷脂 mol 比例最小值"
    )
    phospholipid_mol_max: float = Field(
        default=0.80, ge=0.00, le=1.00, description="磷脂 mol 比例最大值"
    )

    # Cholesterol mol-ratio bounds.
    cholesterol_mol_min: float = Field(
        default=0.00, ge=0.00, le=1.00, description="胆固醇 mol 比例最小值"
    )
    cholesterol_mol_max: float = Field(
        default=0.80, ge=0.00, le=1.00, description="胆固醇 mol 比例最大值"
    )

    # PEG lipid mol-ratio bounds.
    peg_mol_min: float = Field(
        default=0.00, ge=0.00, le=0.20, description="PEG 脂质 mol 比例最小值"
    )
    peg_mol_max: float = Field(
        default=0.05, ge=0.00, le=0.20, description="PEG 脂质 mol 比例最大值"
    )

    def to_comp_ranges(self) -> CompRanges:
        """Convert this request model into a ``CompRanges`` object."""
        bound_names = (
            "weight_ratio_min",
            "weight_ratio_max",
            "cationic_mol_min",
            "cationic_mol_max",
            "phospholipid_mol_min",
            "phospholipid_mol_max",
            "cholesterol_mol_min",
            "cholesterol_mol_max",
            "peg_mol_min",
            "peg_mol_max",
        )
        # Field names match the CompRanges keyword arguments one-to-one.
        return CompRanges(**{name: getattr(self, name) for name in bound_names})
|
||
|
||
|
||
class ScoringWeightsRequest(BaseModel):
    """Scoring-weight settings supplied by the client.

    The scalar weights must be non-negative (``ge=0.0``). The per-class
    weight lists carry one entry per class, in class-index order.
    """

    # Weight of the target-organ biodistribution term.
    biodist_weight: float = Field(default=1.0, ge=0.0, description="目标器官分布权重")
    # Weight of the quantified-delivery term.
    delivery_weight: float = Field(default=0.0, ge=0.0, description="量化递送权重")
    # Weight of the particle-size term.
    size_weight: float = Field(default=0.0, ge=0.0, description="粒径权重 (80-150nm)")
    # Per-class weights for the EE, PDI and toxicity classifications.
    ee_class_weights: List[float] = Field(
        default=[0.0, 0.0, 0.0],
        description="EE 分类权重 [class0, class1, class2]",
    )
    pdi_class_weights: List[float] = Field(
        default=[0.0, 0.0, 0.0, 0.0],
        description="PDI 分类权重 [class0, class1, class2, class3]",
    )
    toxic_class_weights: List[float] = Field(
        default=[0.0, 0.0],
        description="毒性分类权重 [无毒, 有毒]",
    )

    def to_scoring_weights(self) -> ScoringWeights:
        """Convert this request model into a ``ScoringWeights`` object."""
        weight_names = (
            "biodist_weight",
            "delivery_weight",
            "size_weight",
            "ee_class_weights",
            "pdi_class_weights",
            "toxic_class_weights",
        )
        # Field names match the ScoringWeights keyword arguments one-to-one.
        return ScoringWeights(**{name: getattr(self, name) for name in weight_names})
|
||
|
||
|
||
class OptimizeRequest(BaseModel):
    """Body of a ``POST /optimize`` request.

    Only ``smiles`` and ``organ`` are required. Optional fields left as
    ``None`` are forwarded unchanged, so the optimizer applies its own
    defaults for them.
    """

    smiles: str = Field(..., description="Cationic lipid SMILES string")
    organ: str = Field(..., description="Target organ for optimization")
    top_k: int = Field(
        default=20,
        ge=1,
        le=100,
        description="Number of top formulations to return",
    )
    num_seeds: Optional[int] = Field(
        default=None,
        ge=1,
        le=500,
        description="Number of seed points from first iteration (default: top_k * 5)",
    )
    top_per_seed: int = Field(
        default=1,
        ge=1,
        le=10,
        description="Number of local best to keep per seed in refinement",
    )
    step_sizes: Optional[List[float]] = Field(
        default=None,
        description="Step sizes for each iteration (default: [0.10, 0.02, 0.01])",
    )
    comp_ranges: Optional[CompRangesRequest] = Field(
        default=None,
        description="组分范围配置(默认使用标准范围)",
    )
    routes: Optional[List[str]] = Field(
        default=None,
        description="给药途径列表 (default: ['intravenous', 'intramuscular'])",
    )
    scoring_weights: Optional[ScoringWeightsRequest] = Field(
        default=None,
        description="评分权重配置(默认仅按 biodist 排序)",
    )

    class Config:
        # Example payload attached to the generated JSON schema.
        json_schema_extra = {
            "example": {
                "smiles": "CC(C)NCCNC(C)C",
                "organ": "liver",
                "top_k": 20,
                "num_seeds": None,
                "top_per_seed": 1,
                "step_sizes": None,
                "comp_ranges": None,
                "routes": None,
                "scoring_weights": None,
            }
        }
|
||
|
||
|
||
class FormulationResult(BaseModel):
    """One ranked formulation returned by the optimizer."""

    rank: int
    target_biodist: float
    # Composite score.
    composite_score: Optional[float] = None

    # Formulation composition.
    cationic_lipid_to_mrna_ratio: float
    cationic_lipid_mol_ratio: float
    phospholipid_mol_ratio: float
    cholesterol_mol_ratio: float
    peg_lipid_mol_ratio: float
    helper_lipid: str
    route: str

    all_biodist: Dict[str, float]

    # Additional predicted values.
    quantified_delivery: Optional[float] = None
    size: Optional[float] = None
    # PDI class (0: <0.2, 1: 0.2-0.3, 2: 0.3-0.4, 3: >0.4)
    pdi_class: Optional[int] = None
    # EE class (0: <80%, 1: 80-90%, 2: >90%)
    ee_class: Optional[int] = None
    # Toxicity class (0: non-toxic, 1: toxic)
    toxic_class: Optional[int] = None
|
||
|
||
|
||
class OptimizeResponse(BaseModel):
    """Body of a ``POST /optimize`` response."""

    # SMILES string and target organ echoed back from the request.
    smiles: str
    target_organ: str
    # Ranked formulations, best first (rank 1 is the first entry).
    formulations: List[FormulationResult]
    # Human-readable summary of the run.
    message: str
|
||
|
||
|
||
class HealthResponse(BaseModel):
    """Body of the health-check response."""

    # "healthy" when a model is loaded, otherwise "model_not_loaded".
    status: str
    model_loaded: bool
    # String form of the torch device in use.
    device: str
    available_organs: List[str]
|
||
|
||
|
||
# ============ Global State ============
|
||
|
||
class ModelState:
    """Holder for the loaded model and where it lives.

    All attributes start as ``None`` and are filled in by ``lifespan``
    during application startup.
    """

    # Object returned by ``load_model``.
    model = None
    # ``torch.device`` the model was loaded onto.
    device = None
    # Filesystem path the model was loaded from.
    model_path = None


# Single module-level instance shared by the lifespan handler and endpoints.
state = ModelState()
|
||
|
||
|
||
# ============ Lifespan ============
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the application lifecycle: load the model on startup, free it on shutdown.

    Startup:
        * Picks a torch device: CUDA if available, else MPS if available,
          else CPU. The ``DEVICE`` environment variable overrides the choice.
        * Loads the model from the path in the ``MODEL_PATH`` environment
          variable (default ``MODELS_DIR / "final" / "model.pt"``) and stores
          model, device and path on the module-level ``state``.

    Shutdown:
        * Drops the model reference and, when CUDA is available, empties the
          CUDA cache. This runs in a ``finally`` block so it still happens
          when an exception or cancellation is thrown in at the ``yield``.

    Raises:
        Exception: whatever ``load_model`` raises is logged and re-raised so
            the server does not start without a model.
    """
    logger.info("Starting API server...")

    # Choose the device.
    if torch.cuda.is_available():
        device_str = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device_str = "mps"
    else:
        device_str = "cpu"

    # The environment variable takes precedence over auto-detection.
    device_str = os.environ.get("DEVICE", device_str)
    state.device = torch.device(device_str)
    logger.info(f"Using device: {state.device}")

    # Load the model.
    model_path = Path(os.environ.get("MODEL_PATH", MODELS_DIR / "final" / "model.pt"))
    state.model_path = model_path

    logger.info(f"Loading model from {model_path}...")
    try:
        state.model = load_model(model_path, state.device)
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
    else:
        logger.success("Model loaded successfully!")

    try:
        yield
    finally:
        # Always release the model, even if the application exits abnormally.
        logger.info("Shutting down API server...")
        state.model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
||
|
||
|
||
# ============ FastAPI App ============
|
||
|
||
# Application instance; model loading and unloading is handled by ``lifespan``.
app = FastAPI(
    title="LNP 配方优化 API",
    description="基于深度学习的 LNP 纳米颗粒配方优化服务",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS configuration.
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable and default to "*" (any origin).
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # A wildcard origin must not be combined with credentials: that would let
    # any website issue credentialed requests. Credentials are therefore only
    # enabled when the origins are listed explicitly.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
|
||
|
||
# ============ Endpoints ============
|
||
|
||
@app.get("/", response_model=HealthResponse)
async def health_check():
    """Health check: report model status, device and available organs."""
    model_ready = state.model is not None
    if model_ready:
        status_text = "healthy"
    else:
        status_text = "model_not_loaded"

    return HealthResponse(
        status=status_text,
        model_loaded=model_ready,
        device=str(state.device),
        available_organs=AVAILABLE_ORGANS,
    )
|
||
|
||
|
||
@app.get("/organs", response_model=List[str])
async def get_available_organs():
    """Return the list of target organs that can be optimized for."""
    organs = AVAILABLE_ORGANS
    return organs
|
||
|
||
|
||
# Administration routes accepted by the /optimize endpoint.
_VALID_ROUTES = ["intravenous", "intramuscular"]


def _validate_routes(routes: Optional[List[str]]) -> None:
    """Raise a 400 ``HTTPException`` for an unknown route or an empty list.

    ``None`` is accepted and left for ``optimize`` to interpret.
    """
    if routes is None:
        return
    for route in routes:
        if route not in _VALID_ROUTES:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid route: {route}. Available: {_VALID_ROUTES}",
            )
    if not routes:
        raise HTTPException(status_code=400, detail="At least one route must be specified")


@app.post("/optimize", response_model=OptimizeResponse)
async def optimize_formulation(request: OptimizeRequest):
    """
    Run formulation optimization.

    Uses an iterative grid search to find the formulations that maximize the
    biodistribution for the target organ.

    Args:
        request: Validated optimization request.

    Returns:
        ``OptimizeResponse`` with the ranked formulations.

    Raises:
        HTTPException: 503 when no model is loaded; 400 for an unknown organ,
            an empty SMILES string, invalid routes or invalid composition
            ranges; 500 when the optimization itself fails.
    """
    # Check model state.
    if state.model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Validate the organ.
    if request.organ not in AVAILABLE_ORGANS:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid organ: {request.organ}. Available: {AVAILABLE_ORGANS}",
        )

    # Validate the SMILES string. The stripped value is what gets optimized,
    # so surrounding whitespace that passed validation never reaches the model.
    smiles = request.smiles.strip() if request.smiles else ""
    if not smiles:
        raise HTTPException(status_code=400, detail="SMILES string cannot be empty")

    # Validate the routes.
    _validate_routes(request.routes)

    logger.info(f"Optimization request: organ={request.organ}, routes={request.routes}, smiles={smiles[:50]}...")

    # Build the composition ranges outside the try block so that a bad
    # configuration yields a 400 rather than a 500.
    comp_ranges = None
    if request.comp_ranges is not None:
        comp_ranges = request.comp_ranges.to_comp_ranges()
        validation_error = comp_ranges.get_validation_error()
        if validation_error:
            raise HTTPException(
                status_code=400,
                detail=f"组分范围配置无效: {validation_error}",
            )

    # Build the scoring weights.
    scoring_weights = None
    if request.scoring_weights is not None:
        scoring_weights = request.scoring_weights.to_scoring_weights()

    try:
        from app.optimize import compute_formulation_score, DEFAULT_SCORING_WEIGHTS

        # NOTE(review): optimize() is called synchronously inside an async
        # endpoint, so the event loop is blocked while the search runs.
        # Consider a thread pool if concurrent requests matter.
        results = optimize(
            smiles=smiles,
            organ=request.organ,
            model=state.model,
            device=state.device,
            top_k=request.top_k,
            num_seeds=request.num_seeds,
            top_per_seed=request.top_per_seed,
            step_sizes=request.step_sizes,
            comp_ranges=comp_ranges,
            routes=request.routes,
            scoring_weights=scoring_weights,
            batch_size=256,
        )

        # Weights used for the composite score reported to the client.
        actual_scoring_weights = (
            scoring_weights if scoring_weights is not None else DEFAULT_SCORING_WEIGHTS
        )

        # Convert optimizer results into response models, ranked from 1.
        formulations = [
            FormulationResult(
                rank=rank,
                target_biodist=f.get_biodist(request.organ),
                composite_score=compute_formulation_score(f, request.organ, actual_scoring_weights),
                cationic_lipid_to_mrna_ratio=f.cationic_lipid_to_mrna_ratio,
                cationic_lipid_mol_ratio=f.cationic_lipid_mol_ratio,
                phospholipid_mol_ratio=f.phospholipid_mol_ratio,
                cholesterol_mol_ratio=f.cholesterol_mol_ratio,
                peg_lipid_mol_ratio=f.peg_lipid_mol_ratio,
                helper_lipid=f.helper_lipid,
                route=f.route,
                all_biodist={
                    col.replace("Biodistribution_", ""): f.biodist_predictions.get(col, 0.0)
                    for col in TARGET_BIODIST
                },
                # Additional predicted values.
                quantified_delivery=f.quantified_delivery,
                size=f.size,
                pdi_class=f.pdi_class,
                ee_class=f.ee_class,
                toxic_class=f.toxic_class,
            )
            for rank, f in enumerate(results, start=1)
        ]

        logger.success(f"Optimization completed: {len(formulations)} formulations")

        return OptimizeResponse(
            smiles=request.smiles,
            target_organ=request.organ,
            formulations=formulations,
            message=f"Successfully found top {len(formulations)} formulations for {request.organ}",
        )

    except HTTPException:
        # Keep deliberate HTTP errors intact instead of rewrapping them as 500.
        raise
    except Exception as e:
        # logger.exception records the traceback, which would otherwise be lost.
        logger.exception(f"Optimization failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||
|
||
|
||
if __name__ == "__main__":
    # Development entry point: serve the app with auto-reload on port 8000.
    import uvicorn

    uvicorn.run("app.api:app", host="0.0.0.0", port=8000, reload=True)
|
||
|