""" FastAPI 配方优化 API 启动服务: uvicorn app.api:app --host 0.0.0.0 --port 8000 --reload """ import os from pathlib import Path from typing import List, Dict, Optional from contextlib import asynccontextmanager import torch from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field from loguru import logger from lnp_ml.config import MODELS_DIR from lnp_ml.modeling.predict import load_model from app.optimize import ( optimize, format_results, AVAILABLE_ORGANS, TARGET_BIODIST, CompRanges, ScoringWeights, ) # ============ Pydantic Models ============ class CompRangesRequest(BaseModel): """组分范围配置(mol 比例为百分数 0-100)""" weight_ratio_min: float = Field(default=5.0, ge=1.0, le=50.0, description="阳离子脂质/mRNA 重量比最小值") weight_ratio_max: float = Field(default=30.0, ge=1.0, le=50.0, description="阳离子脂质/mRNA 重量比最大值") cationic_mol_min: float = Field(default=5.0, ge=0.0, le=100.0, description="阳离子脂质 mol 比例最小值 (%)") cationic_mol_max: float = Field(default=80.0, ge=0.0, le=100.0, description="阳离子脂质 mol 比例最大值 (%)") phospholipid_mol_min: float = Field(default=0.0, ge=0.0, le=100.0, description="磷脂 mol 比例最小值 (%)") phospholipid_mol_max: float = Field(default=80.0, ge=0.0, le=100.0, description="磷脂 mol 比例最大值 (%)") cholesterol_mol_min: float = Field(default=0.0, ge=0.0, le=100.0, description="胆固醇 mol 比例最小值 (%)") cholesterol_mol_max: float = Field(default=80.0, ge=0.0, le=100.0, description="胆固醇 mol 比例最大值 (%)") peg_mol_min: float = Field(default=0.0, ge=0.0, le=20.0, description="PEG 脂质 mol 比例最小值 (%)") peg_mol_max: float = Field(default=5.0, ge=0.0, le=20.0, description="PEG 脂质 mol 比例最大值 (%)") def to_comp_ranges(self) -> CompRanges: """转换为 CompRanges 对象""" return CompRanges( weight_ratio_min=self.weight_ratio_min, weight_ratio_max=self.weight_ratio_max, cationic_mol_min=self.cationic_mol_min, cationic_mol_max=self.cationic_mol_max, phospholipid_mol_min=self.phospholipid_mol_min, phospholipid_mol_max=self.phospholipid_mol_max, cholesterol_mol_min=self.cholesterol_mol_min, cholesterol_mol_max=self.cholesterol_mol_max, peg_mol_min=self.peg_mol_min, peg_mol_max=self.peg_mol_max, ) class ScoringWeightsRequest(BaseModel): """评分权重配置""" biodist_weight: float = Field(default=1.0, ge=0.0, description="目标器官分布权重") delivery_weight: float = Field(default=0.0, ge=0.0, description="量化递送权重") size_weight: float = Field(default=0.0, ge=0.0, description="粒径权重 (80-150nm)") ee_class_weights: List[float] = Field(default=[0.0, 0.0, 0.0], description="EE 分类权重 [class0, class1, class2]") pdi_class_weights: List[float] = Field(default=[0.0, 0.0, 0.0, 0.0], description="PDI 分类权重 [class0, class1, class2, class3]") toxic_class_weights: List[float] = Field(default=[0.0, 0.0], description="毒性分类权重 [无毒, 有毒]") def to_scoring_weights(self) -> ScoringWeights: """转换为 ScoringWeights 对象""" return ScoringWeights( biodist_weight=self.biodist_weight, delivery_weight=self.delivery_weight, size_weight=self.size_weight, ee_class_weights=self.ee_class_weights, pdi_class_weights=self.pdi_class_weights, toxic_class_weights=self.toxic_class_weights, ) class OptimizeRequest(BaseModel): """优化请求""" smiles: str = Field(..., description="Cationic lipid SMILES string") organ: str = Field(..., description="Target organ for optimization") top_k: int = Field(default=20, ge=1, le=100, description="Number of top formulations to return") num_seeds: Optional[int] = Field(default=None, ge=1, le=500, description="Number of seed points from first iteration (default: top_k * 5)") top_per_seed: int = Field(default=1, ge=1, le=10, description="Number of local best to keep per seed in refinement") step_sizes: Optional[List[float]] = Field(default=None, description="Mol ratio step sizes for each iteration (default: [10, 2, 1])") wr_step_sizes: Optional[List[float]] = Field(default=None, description="Weight ratio step sizes for each iteration (default: [5, 2, 1])") comp_ranges: Optional[CompRangesRequest] = Field(default=None, description="组分范围配置(默认使用标准范围)") routes: Optional[List[str]] = Field(default=None, description="给药途径列表 (default: ['intravenous', 'intramuscular'])") scoring_weights: Optional[ScoringWeightsRequest] = Field(default=None, description="评分权重配置(默认仅按 biodist 排序)") class Config: json_schema_extra = { "example": { "smiles": "CC(C)NCCNC(C)C", "organ": "liver", "top_k": 20, "num_seeds": None, "top_per_seed": 1, "step_sizes": None, "comp_ranges": None, "routes": None, "scoring_weights": None } } class FormulationResult(BaseModel): """单个配方结果""" rank: int target_biodist: float composite_score: Optional[float] = None # 综合评分 cationic_lipid_to_mrna_ratio: float cationic_lipid_mol_ratio: float phospholipid_mol_ratio: float cholesterol_mol_ratio: float peg_lipid_mol_ratio: float helper_lipid: str route: str all_biodist: Dict[str, float] # 额外预测值 quantified_delivery: Optional[float] = None size: Optional[float] = None pdi_class: Optional[int] = None # PDI 分类 (0: <0.2, 1: 0.2-0.3, 2: 0.3-0.4, 3: >0.4) ee_class: Optional[int] = None # EE 分类 (0: <80%, 1: 80-90%, 2: >90%) toxic_class: Optional[int] = None # 毒性分类 (0: 无毒, 1: 有毒) class OptimizeResponse(BaseModel): """优化响应""" smiles: str target_organ: str formulations: List[FormulationResult] message: str class HealthResponse(BaseModel): """健康检查响应""" status: str model_loaded: bool device: str available_organs: List[str] # ============ Global State ============ class ModelState: """模型状态管理""" model = None device = None model_path = None state = ModelState() # ============ Lifespan ============ @asynccontextmanager async def lifespan(app: FastAPI): """应用生命周期管理:启动时加载模型""" # Startup logger.info("Starting API server...") # 确定设备 if torch.cuda.is_available(): device_str = "cuda" elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): device_str = "mps" else: device_str = "cpu" # 可通过环境变量覆盖 device_str = os.environ.get("DEVICE", device_str) state.device = torch.device(device_str) logger.info(f"Using device: {state.device}") # 加载模型 model_path = Path(os.environ.get("MODEL_PATH", MODELS_DIR / "final" / "model.pt")) state.model_path = model_path logger.info(f"Loading model from {model_path}...") try: state.model = load_model(model_path, state.device) logger.success("Model loaded successfully!") except Exception as e: logger.error(f"Failed to load model: {e}") raise yield # Shutdown logger.info("Shutting down API server...") state.model = None torch.cuda.empty_cache() if torch.cuda.is_available() else None # ============ FastAPI App ============ app = FastAPI( title="LNP 配方优化 API", description="基于深度学习的 LNP 纳米颗粒配方优化服务", version="1.0.0", lifespan=lifespan, ) # CORS 配置 app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # ============ Endpoints ============ @app.get("/", response_model=HealthResponse) async def health_check(): """健康检查""" return HealthResponse( status="healthy" if state.model is not None else "model_not_loaded", model_loaded=state.model is not None, device=str(state.device), available_organs=AVAILABLE_ORGANS, ) @app.get("/organs", response_model=List[str]) async def get_available_organs(): """获取可用的目标器官列表""" return AVAILABLE_ORGANS @app.post("/optimize", response_model=OptimizeResponse) async def optimize_formulation(request: OptimizeRequest): """ 执行配方优化 通过迭代式 Grid Search 寻找最大化目标器官 Biodistribution 的最优配方。 """ # 验证模型状态 if state.model is None: raise HTTPException(status_code=503, detail="Model not loaded") # 验证器官 if request.organ not in AVAILABLE_ORGANS: raise HTTPException( status_code=400, detail=f"Invalid organ: {request.organ}. Available: {AVAILABLE_ORGANS}" ) # 验证 SMILES if not request.smiles or len(request.smiles.strip()) == 0: raise HTTPException(status_code=400, detail="SMILES string cannot be empty") # 验证 routes valid_routes = ["intravenous", "intramuscular"] if request.routes is not None: for r in request.routes: if r not in valid_routes: raise HTTPException( status_code=400, detail=f"Invalid route: {r}. Available: {valid_routes}" ) if len(request.routes) == 0: raise HTTPException(status_code=400, detail="At least one route must be specified") logger.info(f"Optimization request: organ={request.organ}, routes={request.routes}, smiles={request.smiles[:50]}...") # 构建组分范围配置(在 try 块外验证,确保返回 400 而非 500) comp_ranges = None if request.comp_ranges is not None: comp_ranges = request.comp_ranges.to_comp_ranges() # 验证范围是否合理 validation_error = comp_ranges.get_validation_error() if validation_error: raise HTTPException( status_code=400, detail=f"组分范围配置无效: {validation_error}" ) # 构建评分权重配置 scoring_weights = None if request.scoring_weights is not None: scoring_weights = request.scoring_weights.to_scoring_weights() try: results = optimize( smiles=request.smiles, organ=request.organ, model=state.model, device=state.device, top_k=request.top_k, num_seeds=request.num_seeds, top_per_seed=request.top_per_seed, step_sizes=request.step_sizes, wr_step_sizes=request.wr_step_sizes, comp_ranges=comp_ranges, routes=request.routes, scoring_weights=scoring_weights, batch_size=256, ) # 用于计算综合评分的权重 from app.optimize import compute_formulation_score, DEFAULT_SCORING_WEIGHTS actual_scoring_weights = scoring_weights if scoring_weights is not None else DEFAULT_SCORING_WEIGHTS # 转换结果 formulations = [] for i, f in enumerate(results): formulations.append(FormulationResult( rank=i + 1, target_biodist=f.get_biodist(request.organ), composite_score=compute_formulation_score(f, request.organ, actual_scoring_weights), cationic_lipid_to_mrna_ratio=f.cationic_lipid_to_mrna_ratio, cationic_lipid_mol_ratio=f.cationic_lipid_mol_ratio, phospholipid_mol_ratio=f.phospholipid_mol_ratio, cholesterol_mol_ratio=f.cholesterol_mol_ratio, peg_lipid_mol_ratio=f.peg_lipid_mol_ratio, helper_lipid=f.helper_lipid, route=f.route, all_biodist={ col.replace("Biodistribution_", ""): f.biodist_predictions.get(col, 0.0) for col in TARGET_BIODIST }, # 额外预测值 quantified_delivery=f.quantified_delivery, size=f.size, pdi_class=f.pdi_class, ee_class=f.ee_class, toxic_class=f.toxic_class, )) logger.success(f"Optimization completed: {len(formulations)} formulations") return OptimizeResponse( smiles=request.smiles, target_organ=request.organ, formulations=formulations, message=f"Successfully found top {len(formulations)} formulations for {request.organ}", ) except Exception as e: logger.error(f"Optimization failed: {e}") raise HTTPException(status_code=500, detail=str(e)) if __name__ == "__main__": import uvicorn uvicorn.run( "app.api:app", host="0.0.0.0", port=8000, reload=True, )