From 28b181e1940bdacedbc22c6213900e5f582f2651 Mon Sep 17 00:00:00 2001 From: RYDE-WORK Date: Tue, 3 Mar 2026 15:48:59 +0800 Subject: [PATCH] =?UTF-8?q?=E6=95=B0=E6=8D=AE=E9=A2=84=E5=A4=84=E7=90=86?= =?UTF-8?q?=E6=97=B6=E8=87=AA=E5=8A=A8=E5=AD=98=E6=A1=A3=E6=80=BB=E8=8D=A7?= =?UTF-8?q?=E5=85=89=E5=BC=BA=E5=BA=A6=E7=9A=84=E5=9D=87=E5=80=BC=E4=B8=8E?= =?UTF-8?q?=E6=96=B9=E5=B7=AE,=E4=BB=A5=E4=BE=BF=E4=BA=8E=E9=A2=84?= =?UTF-8?q?=E6=B5=8B=E6=97=B6=E5=8F=8D=E6=BC=94=E8=8D=A7=E5=85=89=E5=BC=BA?= =?UTF-8?q?=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api.py | 2 ++ app/app.py | 2 ++ app/delivery_zscore_stats.json | 14 ++++++++++++ app/optimize.py | 40 +++++++++++++++++++++++++++++----- scripts/preprocess_internal.py | 26 ++++++++++++++++++++++ 5 files changed, 79 insertions(+), 5 deletions(-) create mode 100644 app/delivery_zscore_stats.json diff --git a/app/api.py b/app/api.py index 099eabb..6fed55a 100644 --- a/app/api.py +++ b/app/api.py @@ -124,6 +124,7 @@ class FormulationResult(BaseModel): all_biodist: Dict[str, float] # 额外预测值 quantified_delivery: Optional[float] = None + unnormalized_delivery: Optional[float] = None # 反推的原始递送值(z-score 逆变换) size: Optional[float] = None pdi_class: Optional[int] = None # PDI 分类 (0: <0.2, 1: 0.2-0.3, 2: 0.3-0.4, 3: >0.4) ee_class: Optional[int] = None # EE 分类 (0: <80%, 1: 80-90%, 2: >90%) @@ -331,6 +332,7 @@ async def optimize_formulation(request: OptimizeRequest): }, # 额外预测值 quantified_delivery=f.quantified_delivery, + unnormalized_delivery=f.unnormalized_delivery, size=f.size, pdi_class=f.pdi_class, ee_class=f.ee_class, diff --git a/app/app.py b/app/app.py index dbac7ba..ec819ea 100644 --- a/app/app.py +++ b/app/app.py @@ -227,6 +227,8 @@ def format_results_dataframe(results: dict, smiles_label: str = None) -> pd.Data # 添加额外预测值 if f.get("quantified_delivery") is not None: row["量化递送"] = f"{f['quantified_delivery']:.4f}" + if f.get("unnormalized_delivery") is not None: + row["总荧光强度"] = f"{f['unnormalized_delivery']:.4f}" if f.get("size") is not None: row["粒径(nm)"] = f"{f['size']:.1f}" if f.get("pdi_class") is not None: diff --git a/app/delivery_zscore_stats.json b/app/delivery_zscore_stats.json new file mode 100644 index 0000000..a26bd5f --- /dev/null +++ b/app/delivery_zscore_stats.json @@ -0,0 +1,14 @@ +{ + "intramuscular": { + "mean": 0.7281303554081238, + "std": 0.7006554090148486, + "qd_min": -1.0387570720282182, + "qd_max": 4.73706835052163 + }, + "intravenous": { + "mean": 0.29940387649347033, + "std": 0.37474351840219583, + "qd_min": -0.7985592911689305, + "qd_max": 4.497814051056962 + } +} \ No newline at end of file diff --git a/app/optimize.py b/app/optimize.py index 6d2d705..5a362ab 100644 --- a/app/optimize.py +++ b/app/optimize.py @@ -8,6 +8,7 @@ """ import itertools +import json from pathlib import Path from typing import Dict, List, Optional, Tuple from dataclasses import dataclass, field @@ -124,11 +125,25 @@ HELPER_LIPID_OPTIONS = ["DOPE", "DSPC"] # Route of administration 选项 ROUTE_OPTIONS = ["intravenous", "intramuscular"] -# quantified_delivery 归一化常量(按给药途径) -DELIVERY_NORM = { - "intravenous": {"min": -0.798559291, "max": 4.497814051056962}, - "intramuscular": {"min": -0.794912427, "max": 10.220042980012716}, -} +# delivery 统计量(由 preprocess_internal.py 生成) +# 包含: mean/std(z-score 逆变换)、qd_min/qd_max(评分归一化) +_DELIVERY_STATS_PATH = Path(__file__).resolve().parent / "delivery_zscore_stats.json" +if _DELIVERY_STATS_PATH.exists(): + with open(_DELIVERY_STATS_PATH) as _f: + DELIVERY_ZSCORE_STATS: Dict[str, Dict[str, float]] = json.load(_f) + logger.info(f"Loaded delivery stats from {_DELIVERY_STATS_PATH}") +else: + DELIVERY_ZSCORE_STATS = {} + logger.warning(f"delivery_zscore_stats.json not found at {_DELIVERY_STATS_PATH}, " + "run 'make preprocess' to generate it") + +# quantified_delivery 归一化常量(从统计量中提取 qd_min/qd_max,用于评分归一化到 [0,1]) +DELIVERY_NORM: Dict[str, Dict[str, float]] = {} +for _route, _stats in DELIVERY_ZSCORE_STATS.items(): + if "qd_min" in _stats and "qd_max" in _stats: + DELIVERY_NORM[_route] = {"min": _stats["qd_min"], "max": _stats["qd_max"]} +if not DELIVERY_NORM: + logger.warning("DELIVERY_NORM is empty — scoring normalization for delivery will be disabled") @dataclass @@ -282,6 +297,7 @@ class Formulation: biodist_predictions: Dict[str, float] = field(default_factory=dict) # 额外预测值 quantified_delivery: Optional[float] = None + unnormalized_delivery: Optional[float] = None # 反推的原始递送值(z-score 逆变换) size: Optional[float] = None pdi_class: Optional[int] = None # PDI 分类 (0-3) ee_class: Optional[int] = None # EE 分类 (0-2) @@ -587,6 +603,16 @@ def predict_all( df["pred_pdi_class"] = pdi_preds df["pred_ee_class"] = ee_preds df["pred_toxic_class"] = toxic_preds + + # 反推 unnormalized_delivery: value = z-score * std + mean + df["pred_unnorm_delivery"] = np.nan + if DELIVERY_ZSCORE_STATS: + for route_name, stats in DELIVERY_ZSCORE_STATS.items(): + mask = df["_route"] == route_name + if mask.any(): + df.loc[mask, "pred_unnorm_delivery"] = ( + delivery_preds[mask.values] * stats["std"] + stats["mean"] + ) return df @@ -645,6 +671,9 @@ def select_top_k( if key not in seen: seen.add(key) + unnorm_val = row.get("pred_unnorm_delivery") + unnorm_delivery = float(unnorm_val) if pd.notna(unnorm_val) else None + formulation = Formulation( cationic_lipid_to_mrna_ratio=row["Cationic_Lipid_to_mRNA_weight_ratio"], cationic_lipid_mol_ratio=row["Cationic_Lipid_Mol_Ratio"], @@ -658,6 +687,7 @@ def select_top_k( }, # 额外预测值 quantified_delivery=row.get("pred_delivery"), + unnormalized_delivery=unnorm_delivery, size=row.get("pred_size"), pdi_class=int(row.get("pred_pdi_class")) if row.get("pred_pdi_class") is not None else None, ee_class=int(row.get("pred_ee_class")) if row.get("pred_ee_class") is not None else None, diff --git a/scripts/preprocess_internal.py b/scripts/preprocess_internal.py index 2e54998..78f2e94 100644 --- a/scripts/preprocess_internal.py +++ b/scripts/preprocess_internal.py @@ -1,5 +1,6 @@ """数据清洗脚本:修正原始数据中的问题""" +import json from pathlib import Path import numpy as np @@ -9,6 +10,7 @@ from loguru import logger from lnp_ml.config import RAW_DATA_DIR, INTERIM_DATA_DIR +APP_DIR = Path(__file__).resolve().parents[1] / "app" app = typer.Typer() @@ -24,6 +26,7 @@ def main( 修正内容: 1. 按给药途径分组进行 z-score 标准化 2. 对 size 列取 log + 3. 将 z-score 的 mean/std 保存到 app/ 供推理时反推 """ logger.info(f"Loading data from {input_path}") df = pd.read_excel(input_path, header=2) @@ -32,11 +35,34 @@ def main( # 分别对肌肉注射组和静脉注射组重新进行 z-score 标准化 logger.info("Z-score normalizing delivery by Route_of_administration...") df["unnormalized_delivery"] = pd.to_numeric(df["unnormalized_delivery"], errors="coerce") + + # 计算并保存 per-route 统计量,用于推理时反推和评分归一化 + zscore_stats = {} + for route, group in df.groupby("Route_of_administration"): + vals = group["unnormalized_delivery"].dropna() + if len(vals) > 1: + zscore_stats[route] = {"mean": float(vals.mean()), "std": float(vals.std())} + logger.info(f" {route}: mean={vals.mean():.6f}, std={vals.std():.6f}, n={len(vals)}") + df["quantified_delivery"] = ( df.groupby("Route_of_administration")["unnormalized_delivery"] .transform(lambda x: (x - x.mean()) / x.std()) ) + # 补充 quantified_delivery 的 per-route min/max(用于评分时归一化到 [0,1]) + for route, group in df.groupby("Route_of_administration"): + qd = group["quantified_delivery"].dropna() + if len(qd) > 0 and route in zscore_stats: + zscore_stats[route]["qd_min"] = float(qd.min()) + zscore_stats[route]["qd_max"] = float(qd.max()) + logger.info(f" {route}: qd_min={qd.min():.6f}, qd_max={qd.max():.6f}") + + stats_path = APP_DIR / "delivery_zscore_stats.json" + stats_path.parent.mkdir(parents=True, exist_ok=True) + with open(stats_path, "w") as f: + json.dump(zscore_stats, f, indent=2) + logger.success(f"Saved delivery stats to {stats_path}") + # 对 size 列取 log logger.info("Log-transforming size column...") df["size"] = pd.to_numeric(df["size"], errors="coerce")