lnp_ml/app/api.py

365 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
FastAPI 配方优化 API
启动服务:
uvicorn app.api:app --host 0.0.0.0 --port 8000 --reload
"""
import os
import threading
from contextlib import asynccontextmanager
from pathlib import Path
from typing import List, Dict, Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from pydantic import BaseModel, Field

from lnp_ml.config import MODELS_DIR
from lnp_ml.modeling.predict import load_model
from app.optimize import (
    optimize,
    format_results,
    AVAILABLE_ORGANS,
    TARGET_BIODIST,
    CompRanges,
    ScoringWeights,
)
# ============ Pydantic Models ============
class CompRangesRequest(BaseModel):
    """Composition-range bounds for the formulation grid search.

    All mol ratios are percentages in [0, 100]. The weight ratio is the
    cationic lipid / mRNA ratio. Per-field bounds are enforced by ``Field``;
    this model does not enforce ``min <= max`` across fields.
    """

    weight_ratio_min: float = Field(default=5.0, ge=1.0, le=50.0, description="阳离子脂质/mRNA 重量比最小值")
    weight_ratio_max: float = Field(default=30.0, ge=1.0, le=50.0, description="阳离子脂质/mRNA 重量比最大值")
    cationic_mol_min: float = Field(default=5.0, ge=0.0, le=100.0, description="阳离子脂质 mol 比例最小值 (%)")
    cationic_mol_max: float = Field(default=80.0, ge=0.0, le=100.0, description="阳离子脂质 mol 比例最大值 (%)")
    phospholipid_mol_min: float = Field(default=0.0, ge=0.0, le=100.0, description="磷脂 mol 比例最小值 (%)")
    phospholipid_mol_max: float = Field(default=80.0, ge=0.0, le=100.0, description="磷脂 mol 比例最大值 (%)")
    cholesterol_mol_min: float = Field(default=0.0, ge=0.0, le=100.0, description="胆固醇 mol 比例最小值 (%)")
    cholesterol_mol_max: float = Field(default=80.0, ge=0.0, le=100.0, description="胆固醇 mol 比例最大值 (%)")
    peg_mol_min: float = Field(default=0.0, ge=0.0, le=20.0, description="PEG 脂质 mol 比例最小值 (%)")
    peg_mol_max: float = Field(default=5.0, ge=0.0, le=20.0, description="PEG 脂质 mol 比例最大值 (%)")

    def to_comp_ranges(self) -> CompRanges:
        """Build the optimizer's ``CompRanges`` from this request.

        Every field of this model is forwarded unchanged under its own name.
        The field names here mirror the ``CompRanges`` keyword arguments one-to-one.
        """
        bounds = self.model_dump()
        return CompRanges(**bounds)
class ScoringWeightsRequest(BaseModel):
    """Scoring-weight configuration used to rank candidate formulations.

    The scalar weights must be non-negative. Each ``*_class_weights`` list
    carries one weight per class of the matching classifier:

    - EE: 3 classes
    - PDI: 4 classes
    - toxicity: 2 classes

    Each list must have exactly that many entries. A malformed list is
    therefore rejected during request validation instead of reaching the
    optimizer.
    """

    biodist_weight: float = Field(default=1.0, ge=0.0, description="目标器官分布权重")
    delivery_weight: float = Field(default=0.0, ge=0.0, description="量化递送权重")
    size_weight: float = Field(default=0.0, ge=0.0, description="粒径权重 (80-150nm)")
    ee_class_weights: List[float] = Field(
        default=[0.0, 0.0, 0.0],
        min_length=3,
        max_length=3,
        description="EE 分类权重 [class0, class1, class2]",
    )
    pdi_class_weights: List[float] = Field(
        default=[0.0, 0.0, 0.0, 0.0],
        min_length=4,
        max_length=4,
        description="PDI 分类权重 [class0, class1, class2, class3]",
    )
    toxic_class_weights: List[float] = Field(
        default=[0.0, 0.0],
        min_length=2,
        max_length=2,
        description="毒性分类权重 [无毒, 有毒]",
    )

    def to_scoring_weights(self) -> ScoringWeights:
        """Convert this request payload into the optimizer's ``ScoringWeights``.

        Every field is forwarded unchanged under the same keyword name.
        """
        return ScoringWeights(
            biodist_weight=self.biodist_weight,
            delivery_weight=self.delivery_weight,
            size_weight=self.size_weight,
            ee_class_weights=self.ee_class_weights,
            pdi_class_weights=self.pdi_class_weights,
            toxic_class_weights=self.toxic_class_weights,
        )
class OptimizeRequest(BaseModel):
    """Request body for ``POST /optimize``.

    Only ``smiles`` and ``organ`` are required. Every optional field left as
    ``None`` falls back to the default named in its description, for example:

    - ``num_seeds``: ``top_k * 5``
    - ``step_sizes``: ``[10, 2, 1]``
    - ``routes``: both routes
    - ``scoring_weights``: rank by biodistribution only
    """

    smiles: str = Field(..., description="Cationic lipid SMILES string")
    organ: str = Field(..., description="Target organ for optimization")
    top_k: int = Field(default=20, ge=1, le=100, description="Number of top formulations to return")
    num_seeds: Optional[int] = Field(default=None, ge=1, le=500, description="Number of seed points from first iteration (default: top_k * 5)")
    top_per_seed: int = Field(default=1, ge=1, le=10, description="Number of local best to keep per seed in refinement")
    step_sizes: Optional[List[float]] = Field(default=None, description="Mol ratio step sizes for each iteration (default: [10, 2, 1])")
    wr_step_sizes: Optional[List[float]] = Field(default=None, description="Weight ratio step sizes for each iteration (default: [5, 2, 1])")
    comp_ranges: Optional[CompRangesRequest] = Field(default=None, description="组分范围配置(默认使用标准范围)")
    routes: Optional[List[str]] = Field(default=None, description="给药途径列表 (default: ['intravenous', 'intramuscular'])")
    scoring_weights: Optional[ScoringWeightsRequest] = Field(default=None, description="评分权重配置(默认仅按 biodist 排序)")

    # Pydantic v2 configuration; replaces the deprecated inner ``class Config``.
    # The OpenAPI example lists every optional field, including wr_step_sizes.
    model_config = {
        "json_schema_extra": {
            "example": {
                "smiles": "CC(C)NCCNC(C)C",
                "organ": "liver",
                "top_k": 20,
                "num_seeds": None,
                "top_per_seed": 1,
                "step_sizes": None,
                "wr_step_sizes": None,
                "comp_ranges": None,
                "routes": None,
                "scoring_weights": None,
            }
        }
    }
class FormulationResult(BaseModel):
    """One ranked formulation in the ``POST /optimize`` response.

    Field order here is the order of keys in the serialized response.
    """
    rank: int  # 1-based position in the returned list (the endpoint assigns i + 1)
    target_biodist: float  # predicted biodistribution for the requested organ
    composite_score: Optional[float] = None  # composite score under the active scoring weights
    cationic_lipid_to_mrna_ratio: float
    # Mol ratios; presumably percentages like CompRangesRequest (0-100) — TODO confirm
    cationic_lipid_mol_ratio: float
    phospholipid_mol_ratio: float
    cholesterol_mol_ratio: float
    peg_lipid_mol_ratio: float
    helper_lipid: str
    route: str
    all_biodist: Dict[str, float]  # organ name (without "Biodistribution_" prefix) -> predicted biodistribution
    # Additional predictions
    quantified_delivery: Optional[float] = None
    unnormalized_delivery: Optional[float] = None  # delivery mapped back to the raw scale (inverse z-score transform)
    size: Optional[float] = None  # particle size; presumably nm (cf. "80-150nm" in size_weight) — confirm
    pdi_class: Optional[int] = None  # PDI class (0: <0.2, 1: 0.2-0.3, 2: 0.3-0.4, 3: >0.4)
    ee_class: Optional[int] = None  # EE class (0: <80%, 1: 80-90%, 2: >90%)
    toxic_class: Optional[int] = None  # toxicity class (0: non-toxic, 1: toxic)
class OptimizeResponse(BaseModel):
    """Response body of ``POST /optimize``."""
    smiles: str  # the SMILES string echoed back from the request
    target_organ: str  # the organ the search optimized for
    formulations: List[FormulationResult]  # in the optimizer's returned order, ranked from 1
    message: str  # human-readable summary of the run
class HealthResponse(BaseModel):
    """Response body of the ``GET /`` health check."""
    status: str  # "healthy" when a model is loaded, otherwise "model_not_loaded"
    model_loaded: bool
    device: str  # str() of the torch device; "None" until startup has set it
    available_organs: List[str]  # organs accepted by POST /optimize
# ============ Global State ============
class ModelState:
    """Process-wide holder for the model and its runtime context.

    Every attribute defaults to ``None`` at class level. ``lifespan`` assigns
    them on the module-level ``state`` instance at startup, and clears
    ``model`` again at shutdown.
    """

    model = None  # object returned by load_model(); None before startup / after shutdown
    device: Optional[torch.device] = None  # device the model was loaded onto
    model_path: Optional[Path] = None  # file the model was loaded from


state = ModelState()
# ============ Lifespan ============
def _resolve_device() -> torch.device:
    """Choose the torch device used for inference.

    Auto-detection prefers CUDA, then Apple MPS, then falls back to CPU. A
    non-empty ``DEVICE`` environment variable overrides the detected choice.
    """
    if torch.cuda.is_available():
        device_str = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device_str = "mps"
    else:
        device_str = "cpu"
    # An empty DEVICE would make torch.device("") raise; treat it as unset.
    return torch.device(os.environ.get("DEVICE") or device_str)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the model at startup and release it at shutdown.

    Environment variables (empty values are treated as unset):
        DEVICE: torch device string that overrides auto-detection.
        MODEL_PATH: model file to load; defaults to ``MODELS_DIR / "final" / "model.pt"``.

    Raises:
        Exception: anything raised by ``load_model`` is logged and re-raised,
            so the server refuses to start rather than serve without a model.
    """
    # Startup
    logger.info("Starting API server...")

    state.device = _resolve_device()
    logger.info(f"Using device: {state.device}")

    model_path = Path(os.environ.get("MODEL_PATH") or MODELS_DIR / "final" / "model.pt")
    state.model_path = model_path
    logger.info(f"Loading model from {model_path}...")
    try:
        state.model = load_model(model_path, state.device)
        logger.success("Model loaded successfully!")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise

    try:
        yield
    finally:
        # Shutdown: runs even if an exception is thrown into the lifespan.
        logger.info("Shutting down API server...")
        state.model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# ============ FastAPI App ============
app = FastAPI(
    title="LNP 配方优化 API",
    description="基于深度学习的 LNP 纳米颗粒配方优化服务",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS configuration. CORS_ORIGINS is an optional comma-separated allow-list of
# origins. When it is unset or empty, the previous allow-any-origin behaviour
# ("*") is kept.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Credentials must not be combined with a wildcard origin: browsers reject
    # "*" for credentialed requests, and echoing arbitrary origins instead would
    # grant credentialed access to every site. Only enable them for an explicit list.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ============ Endpoints ============
@app.get("/", response_model=HealthResponse)
async def health_check():
    """Report whether the model is loaded, the device in use, and the valid organs."""
    model_loaded = state.model is not None
    if model_loaded:
        status = "healthy"
    else:
        status = "model_not_loaded"
    return HealthResponse(
        status=status,
        model_loaded=model_loaded,
        device=str(state.device),
        available_organs=AVAILABLE_ORGANS,
    )
@app.get("/organs", response_model=List[str])
async def get_available_organs():
    """Return the organ names that ``POST /optimize`` accepts as ``organ``."""
    organs = AVAILABLE_ORGANS
    return organs
# Serializes optimizer runs. The grid search runs in a worker thread so the
# event loop stays responsive. Requests still execute one at a time, as they
# did when the search blocked the loop, so concurrent searches never compete
# for the same model and device memory.
_optimize_lock = threading.Lock()


def _run_optimize_serialized(**kwargs):
    """Call ``optimize(**kwargs)`` while holding ``_optimize_lock``.

    Meant to be executed in a worker thread via ``run_in_threadpool``.
    """
    with _optimize_lock:
        return optimize(**kwargs)


def _to_formulation_result(rank: int, formulation, organ: str, composite_score: Optional[float]) -> FormulationResult:
    """Convert one optimizer result into the API's ``FormulationResult``.

    ``formulation`` is an item returned by ``optimize``; only the attributes and
    ``get_biodist`` method read below are relied upon. Organs missing from
    ``biodist_predictions`` are reported as 0.0.
    """
    return FormulationResult(
        rank=rank,
        target_biodist=formulation.get_biodist(organ),
        composite_score=composite_score,
        cationic_lipid_to_mrna_ratio=formulation.cationic_lipid_to_mrna_ratio,
        cationic_lipid_mol_ratio=formulation.cationic_lipid_mol_ratio,
        phospholipid_mol_ratio=formulation.phospholipid_mol_ratio,
        cholesterol_mol_ratio=formulation.cholesterol_mol_ratio,
        peg_lipid_mol_ratio=formulation.peg_lipid_mol_ratio,
        helper_lipid=formulation.helper_lipid,
        route=formulation.route,
        all_biodist={
            col.replace("Biodistribution_", ""): formulation.biodist_predictions.get(col, 0.0)
            for col in TARGET_BIODIST
        },
        # Additional predictions
        quantified_delivery=formulation.quantified_delivery,
        unnormalized_delivery=formulation.unnormalized_delivery,
        size=formulation.size,
        pdi_class=formulation.pdi_class,
        ee_class=formulation.ee_class,
        toxic_class=formulation.toxic_class,
    )


@app.post("/optimize", response_model=OptimizeResponse)
async def optimize_formulation(request: OptimizeRequest):
    """Optimize LNP formulations for one cationic lipid and target organ.

    Runs an iterative grid search for formulations that maximize the target
    organ's predicted biodistribution, scored with ``scoring_weights``
    (by default, biodistribution only). The search runs in a worker thread,
    so the event loop keeps serving other requests meanwhile.

    Raises:
        HTTPException(503): the model is not loaded.
        HTTPException(400): unknown organ, blank SMILES, invalid or empty
            ``routes``, or an inconsistent ``comp_ranges`` configuration.
        HTTPException(500): the optimizer or the result conversion failed.
    """
    # Validate model state
    if state.model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Validate organ
    if request.organ not in AVAILABLE_ORGANS:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid organ: {request.organ}. Available: {AVAILABLE_ORGANS}",
        )

    # Validate SMILES (whitespace-only counts as empty)
    if not request.smiles.strip():
        raise HTTPException(status_code=400, detail="SMILES string cannot be empty")

    # Validate routes; None is forwarded so the optimizer uses its default routes
    valid_routes = ["intravenous", "intramuscular"]
    if request.routes is not None:
        for r in request.routes:
            if r not in valid_routes:
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid route: {r}. Available: {valid_routes}",
                )
        if not request.routes:
            raise HTTPException(status_code=400, detail="At least one route must be specified")

    logger.info(f"Optimization request: organ={request.organ}, routes={request.routes}, smiles={request.smiles[:50]}...")

    # Build and validate composition ranges outside the try block, so an
    # invalid configuration yields a 400 instead of being turned into a 500.
    comp_ranges = None
    if request.comp_ranges is not None:
        comp_ranges = request.comp_ranges.to_comp_ranges()
        validation_error = comp_ranges.get_validation_error()
        if validation_error:
            raise HTTPException(
                status_code=400,
                detail=f"组分范围配置无效: {validation_error}",
            )

    # Build scoring weights (None -> optimizer defaults)
    scoring_weights = None
    if request.scoring_weights is not None:
        scoring_weights = request.scoring_weights.to_scoring_weights()

    try:
        from app.optimize import compute_formulation_score, DEFAULT_SCORING_WEIGHTS

        # The search is synchronous CPU/GPU work. Running it directly in this
        # coroutine would block the event loop, stalling health checks and all
        # other requests.
        results = await run_in_threadpool(
            _run_optimize_serialized,
            smiles=request.smiles,
            organ=request.organ,
            model=state.model,
            device=state.device,
            top_k=request.top_k,
            num_seeds=request.num_seeds,
            top_per_seed=request.top_per_seed,
            step_sizes=request.step_sizes,
            wr_step_sizes=request.wr_step_sizes,
            comp_ranges=comp_ranges,
            routes=request.routes,
            scoring_weights=scoring_weights,
            batch_size=256,
        )

        # Weights used for the reported composite score
        actual_scoring_weights = scoring_weights if scoring_weights is not None else DEFAULT_SCORING_WEIGHTS
        formulations = [
            _to_formulation_result(
                rank,
                f,
                request.organ,
                compute_formulation_score(f, request.organ, actual_scoring_weights),
            )
            for rank, f in enumerate(results, start=1)
        ]

        logger.success(f"Optimization completed: {len(formulations)} formulations")
        return OptimizeResponse(
            smiles=request.smiles,
            target_organ=request.organ,
            formulations=formulations,
            message=f"Successfully found top {len(formulations)} formulations for {request.organ}",
        )
    except Exception as e:
        # logger.exception keeps the traceback, which str(e) alone loses.
        logger.exception(f"Optimization failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Local development launcher with auto-reload. uvicorn needs the
    # "module:attribute" import string (not the app object) to support reload.
    import uvicorn

    uvicorn_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("app.api:app", **uvicorn_options)