diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..260ed56 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,63 @@ +# LNP-ML Docker Image +# 多阶段构建,支持 API 和 Streamlit 两种服务 + +FROM python:3.8-slim AS base + +# 设置环境变量 +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PIP_NO_CACHE_DIR=1 \ + PIP_DISABLE_PIP_VERSION_CHECK=1 + +WORKDIR /app + +# 安装系统依赖 +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + libxrender1 \ + libxext6 \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# 复制依赖文件 +COPY requirements.txt . + +# 安装 Python 依赖 +RUN pip install --upgrade pip && \ + pip install -r requirements.txt + +# 复制项目代码 +COPY pyproject.toml . +COPY README.md . +COPY LICENSE . +COPY lnp_ml/ ./lnp_ml/ +COPY app/ ./app/ + +# 安装项目包 +RUN pip install -e . + +# 复制模型文件 +COPY models/final/ ./models/final/ + +# ============ API 服务 ============ +FROM base AS api + +EXPOSE 8000 + +ENV MODEL_PATH=/app/models/final/model.pt + +CMD ["uvicorn", "app.api:app", "--host", "0.0.0.0", "--port", "8000"] + +# ============ Streamlit 服务 ============ +FROM base AS streamlit + +EXPOSE 8501 + +# Streamlit 配置 +ENV STREAMLIT_SERVER_PORT=8501 \ + STREAMLIT_SERVER_ADDRESS=0.0.0.0 \ + STREAMLIT_SERVER_HEADLESS=true \ + STREAMLIT_BROWSER_GATHER_USAGE_STATS=false + +CMD ["streamlit", "run", "app/app.py"] + diff --git a/Makefile b/Makefile index 8d7182b..1608bbb 100644 --- a/Makefile +++ b/Makefile @@ -164,6 +164,44 @@ test_cv: requirements tune: requirements $(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --tune $(MPNN_FLAG) $(DEVICE_FLAG) +# ============ 嵌套 CV + Optuna 调参(StratifiedKFold + 类权重) ============ +# 通用参数: +# SEED: 随机种子 (默认: 42) +# N_TRIALS: Optuna 试验数 (默认: 20) +# EPOCHS_PER_TRIAL: 每个试验的最大 epoch (默认: 30) +# MIN_STRATUM_COUNT: 复合分层标签的最小样本数 (默认: 5) +# OUTPUT_DIR: 输出目录 (根据命令有不同默认值) +# INIT_PRETRAIN: 预训练权重路径 (默认: models/pretrain_delivery.pt) + +SEED_FLAG = $(if $(SEED),--seed $(SEED),) +N_TRIALS_FLAG = $(if $(N_TRIALS),--n-trials $(N_TRIALS),) +EPOCHS_PER_TRIAL_FLAG = $(if $(EPOCHS_PER_TRIAL),--epochs-per-trial $(EPOCHS_PER_TRIAL),) +MIN_STRATUM_FLAG = $(if $(MIN_STRATUM_COUNT),--min-stratum-count $(MIN_STRATUM_COUNT),) +OUTPUT_DIR_FLAG = $(if $(OUTPUT_DIR),--output-dir $(OUTPUT_DIR),) +USE_SWA_FLAG = $(if $(USE_SWA),--use-swa,) +# 默认使用预训练权重,设置 NO_PRETRAIN=1 可禁用 +INIT_PRETRAIN_FLAG = $(if $(NO_PRETRAIN),,--init-from-pretrain $(or $(INIT_PRETRAIN),models/pretrain_delivery.pt)) + +## Nested CV with Optuna: outer 5-fold (test) + inner 3-fold (tune) +## 用于模型评估:外层 5-fold 产生无偏性能估计,内层 3-fold 做超参搜索 +## 默认加载 models/pretrain_delivery.pt 预训练权重,使用 NO_PRETRAIN=1 禁用 +## 使用示例: make nested_cv_tune DEVICE=cuda N_TRIALS=30 +.PHONY: nested_cv_tune +nested_cv_tune: requirements + $(PYTHON_INTERPRETER) -m lnp_ml.modeling.nested_cv_optuna \ + $(DEVICE_FLAG) $(MPNN_FLAG) $(SEED_FLAG) $(INIT_PRETRAIN_FLAG) \ + $(N_TRIALS_FLAG) $(EPOCHS_PER_TRIAL_FLAG) $(MIN_STRATUM_FLAG) $(OUTPUT_DIR_FLAG) + +## Final training with Optuna: 3-fold CV tune + full data train +## 用于最终模型训练:3-fold 调参后用全量数据训练(无 early-stop) +## 默认加载 models/pretrain_delivery.pt 预训练权重,使用 NO_PRETRAIN=1 禁用 +## 使用示例: make final_optuna DEVICE=cuda N_TRIALS=30 USE_SWA=1 +.PHONY: final_optuna +final_optuna: requirements + $(PYTHON_INTERPRETER) -m lnp_ml.modeling.final_train_optuna_cv \ + $(DEVICE_FLAG) $(MPNN_FLAG) $(SEED_FLAG) $(INIT_PRETRAIN_FLAG) \ + $(N_TRIALS_FLAG) $(EPOCHS_PER_TRIAL_FLAG) $(MIN_STRATUM_FLAG) $(OUTPUT_DIR_FLAG) $(USE_SWA_FLAG) + ## Run predictions .PHONY: predict predict: requirements @@ -200,6 +238,48 @@ serve: @echo "然后访问: 
http://localhost:8501" +################################################################################# +# DOCKER COMMANDS # +################################################################################# + +## Build Docker images +.PHONY: docker-build +docker-build: + docker compose build + +## Start all services with Docker Compose +.PHONY: docker-up +docker-up: + docker compose up -d + +## Stop all Docker services +.PHONY: docker-down +docker-down: + docker compose down + +## View Docker logs +.PHONY: docker-logs +docker-logs: + docker compose logs -f + +## Build and start all services +.PHONY: docker-serve +docker-serve: docker-build docker-up + @echo "" + @echo "🚀 服务已启动!" + @echo " - API: http://localhost:8000" + @echo " - Web 应用: http://localhost:8501" + @echo "" + @echo "查看日志: make docker-logs" + @echo "停止服务: make docker-down" + +## Clean Docker resources (images, volumes, etc.) +.PHONY: docker-clean +docker-clean: + docker compose down -v --rmi local + docker system prune -f + + ################################################################################# # Self Documenting Commands # ################################################################################# diff --git a/app/SCORE.md b/app/SCORE.md new file mode 100644 index 0000000..5e3173d --- /dev/null +++ b/app/SCORE.md @@ -0,0 +1,15 @@ +## regression +biodistribution(selected organ only): score = y * weight, where weight=0.3 +quantified_delivery: score = (y-min)/(max-min)*weight, where weight=0.25, (min=-0.798559291, max=4.497814051056962) when route_of_administration=intravenous, (min=-0.794912427, max=10.220042980012716) when route_of_administration=intramuscular +size: score = 0 * weight if y<60, 1 * weight if 60<=y<=150, 0 * weight if y>150, where weight=0.05 + +## classification +encapsulation_efficiency_0: score = weight, where weight=0 +encapsulation_efficiency_1: score = weight, where weight=0.02 +encapsulation_efficiency_2: score = weight, where weight=0.08 +pdi_0: score = weight, where weight=0.08 +pdi_1: score = weight, where weight=0.02 +pdi_2: score = weight, where weight=0 +pdi_3: score = weight, where weight=0 +toxicity_0: score=weight, where weight=0.2 +toxicity_1: score=weight, where weight=0 diff --git a/app/api.py b/app/api.py index 61af646..04b91c8 100644 --- a/app/api.py +++ b/app/api.py @@ -23,23 +23,87 @@ from app.optimize import ( format_results, AVAILABLE_ORGANS, TARGET_BIODIST, + CompRanges, + ScoringWeights, ) # ============ Pydantic Models ============ +class CompRangesRequest(BaseModel): + """组分范围配置""" + weight_ratio_min: float = Field(default=0.05, ge=0.01, le=0.50, description="阳离子脂质/mRNA 重量比最小值") + weight_ratio_max: float = Field(default=0.30, ge=0.01, le=0.50, description="阳离子脂质/mRNA 重量比最大值") + cationic_mol_min: float = Field(default=0.05, ge=0.00, le=1.00, description="阳离子脂质 mol 比例最小值") + cationic_mol_max: float = Field(default=0.80, ge=0.00, le=1.00, description="阳离子脂质 mol 比例最大值") + phospholipid_mol_min: float = Field(default=0.00, ge=0.00, le=1.00, description="磷脂 mol 比例最小值") + phospholipid_mol_max: float = Field(default=0.80, ge=0.00, le=1.00, description="磷脂 mol 比例最大值") + cholesterol_mol_min: float = Field(default=0.00, ge=0.00, le=1.00, description="胆固醇 mol 比例最小值") + cholesterol_mol_max: float = Field(default=0.80, ge=0.00, le=1.00, description="胆固醇 mol 比例最大值") + peg_mol_min: float = Field(default=0.00, ge=0.00, le=0.20, description="PEG 脂质 mol 比例最小值") + peg_mol_max: float = Field(default=0.05, ge=0.00, le=0.20, description="PEG 脂质 mol 比例最大值") + + def to_comp_ranges(self) 
-> CompRanges:
+        """转换为 CompRanges 对象"""
+        return CompRanges(
+            weight_ratio_min=self.weight_ratio_min,
+            weight_ratio_max=self.weight_ratio_max,
+            cationic_mol_min=self.cationic_mol_min,
+            cationic_mol_max=self.cationic_mol_max,
+            phospholipid_mol_min=self.phospholipid_mol_min,
+            phospholipid_mol_max=self.phospholipid_mol_max,
+            cholesterol_mol_min=self.cholesterol_mol_min,
+            cholesterol_mol_max=self.cholesterol_mol_max,
+            peg_mol_min=self.peg_mol_min,
+            peg_mol_max=self.peg_mol_max,
+        )
+
+
+class ScoringWeightsRequest(BaseModel):
+    """评分权重配置"""
+    biodist_weight: float = Field(default=1.0, ge=0.0, description="目标器官分布权重")
+    delivery_weight: float = Field(default=0.0, ge=0.0, description="量化递送权重")
+    size_weight: float = Field(default=0.0, ge=0.0, description="粒径权重 (60-150nm,见 SCORE.md)")
+    ee_class_weights: List[float] = Field(default=[0.0, 0.0, 0.0], description="EE 分类权重 [class0, class1, class2]")
+    pdi_class_weights: List[float] = Field(default=[0.0, 0.0, 0.0, 0.0], description="PDI 分类权重 [class0, class1, class2, class3]")
+    toxic_class_weights: List[float] = Field(default=[0.0, 0.0], description="毒性分类权重 [无毒, 有毒]")
+
+    def to_scoring_weights(self) -> ScoringWeights:
+        """转换为 ScoringWeights 对象"""
+        return ScoringWeights(
+            biodist_weight=self.biodist_weight,
+            delivery_weight=self.delivery_weight,
+            size_weight=self.size_weight,
+            ee_class_weights=self.ee_class_weights,
+            pdi_class_weights=self.pdi_class_weights,
+            toxic_class_weights=self.toxic_class_weights,
+        )
+
+
 class OptimizeRequest(BaseModel):
     """优化请求"""
     smiles: str = Field(..., description="Cationic lipid SMILES string")
     organ: str = Field(..., description="Target organ for optimization")
-    top_k: int = Field(default=20, ge=1, le=100, description="Number of top formulations")
+    top_k: int = Field(default=20, ge=1, le=100, description="Number of top formulations to return")
+    num_seeds: Optional[int] = Field(default=None, ge=1, le=500, description="Number of seed points from first iteration (default: top_k * 5)")
+    top_per_seed: int = Field(default=1, ge=1, le=10, description="Number of local best to keep per seed in refinement")
+    step_sizes: Optional[List[float]] = Field(default=None, description="Step sizes for each iteration (default: [0.10, 0.02, 0.01])")
+    comp_ranges: Optional[CompRangesRequest] = Field(default=None, description="组分范围配置(默认使用标准范围)")
+    routes: Optional[List[str]] = Field(default=None, description="给药途径列表 (default: ['intravenous', 'intramuscular'])")
+    scoring_weights: Optional[ScoringWeightsRequest] = Field(default=None, description="评分权重配置(默认仅按 biodist 排序)")
 
     class Config:
         json_schema_extra = {
             "example": {
                 "smiles": "CC(C)NCCNC(C)C",
                 "organ": "liver",
-                "top_k": 20
+                "top_k": 20,
+                "num_seeds": None,
+                "top_per_seed": 1,
+                "step_sizes": None,
+                "comp_ranges": None,
+                "routes": None,
+                "scoring_weights": None
             }
         }
 
@@ -48,6 +112,7 @@ class FormulationResult(BaseModel):
     """单个配方结果"""
     rank: int
     target_biodist: float
+    composite_score: Optional[float] = None  # 综合评分
     cationic_lipid_to_mrna_ratio: float
     cationic_lipid_mol_ratio: float
     phospholipid_mol_ratio: float
@@ -56,6 +121,12 @@
     helper_lipid: str
     route: str
     all_biodist: Dict[str, float]
+    # 额外预测值
+    quantified_delivery: Optional[float] = None
+    size: Optional[float] = None
+    pdi_class: Optional[int] = None  # PDI 分类 (0: <0.2, 1: 0.2-0.3, 2: 0.3-0.4, 3: >0.4)
+    ee_class: Optional[int] = None  # EE 分类 (0: <50%, 1: 50-80%, 2: >80%,与 app.py 的 EE_CLASS_LABELS 一致)
+    toxic_class: Optional[int] = None  # 毒性分类 (0: 无毒, 1: 有毒)
 
 
 class OptimizeResponse(BaseModel):
@@ -187,25 +258,65 @@ 
async def optimize_formulation(request: OptimizeRequest): if not request.smiles or len(request.smiles.strip()) == 0: raise HTTPException(status_code=400, detail="SMILES string cannot be empty") - logger.info(f"Optimization request: organ={request.organ}, smiles={request.smiles[:50]}...") + # 验证 routes + valid_routes = ["intravenous", "intramuscular"] + if request.routes is not None: + for r in request.routes: + if r not in valid_routes: + raise HTTPException( + status_code=400, + detail=f"Invalid route: {r}. Available: {valid_routes}" + ) + if len(request.routes) == 0: + raise HTTPException(status_code=400, detail="At least one route must be specified") + + logger.info(f"Optimization request: organ={request.organ}, routes={request.routes}, smiles={request.smiles[:50]}...") + + # 构建组分范围配置(在 try 块外验证,确保返回 400 而非 500) + comp_ranges = None + if request.comp_ranges is not None: + comp_ranges = request.comp_ranges.to_comp_ranges() + # 验证范围是否合理 + validation_error = comp_ranges.get_validation_error() + if validation_error: + raise HTTPException( + status_code=400, + detail=f"组分范围配置无效: {validation_error}" + ) + + # 构建评分权重配置 + scoring_weights = None + if request.scoring_weights is not None: + scoring_weights = request.scoring_weights.to_scoring_weights() try: - # 执行优化 + # 执行优化(层级搜索策略) results = optimize( smiles=request.smiles, organ=request.organ, model=state.model, device=state.device, top_k=request.top_k, + num_seeds=request.num_seeds, + top_per_seed=request.top_per_seed, + step_sizes=request.step_sizes, + comp_ranges=comp_ranges, + routes=request.routes, + scoring_weights=scoring_weights, batch_size=256, ) + # 用于计算综合评分的权重 + from app.optimize import compute_formulation_score, DEFAULT_SCORING_WEIGHTS + actual_scoring_weights = scoring_weights if scoring_weights is not None else DEFAULT_SCORING_WEIGHTS + # 转换结果 formulations = [] for i, f in enumerate(results): formulations.append(FormulationResult( rank=i + 1, target_biodist=f.get_biodist(request.organ), + composite_score=compute_formulation_score(f, request.organ, actual_scoring_weights), cationic_lipid_to_mrna_ratio=f.cationic_lipid_to_mrna_ratio, cationic_lipid_mol_ratio=f.cationic_lipid_mol_ratio, phospholipid_mol_ratio=f.phospholipid_mol_ratio, @@ -217,6 +328,12 @@ async def optimize_formulation(request: OptimizeRequest): col.replace("Biodistribution_", ""): f.biodist_predictions.get(col, 0.0) for col in TARGET_BIODIST }, + # 额外预测值 + quantified_delivery=f.quantified_delivery, + size=f.size, + pdi_class=f.pdi_class, + ee_class=f.ee_class, + toxic_class=f.toxic_class, )) logger.success(f"Optimization completed: {len(formulations)} formulations") diff --git a/app/app.py b/app/app.py index 6b3a093..2dddec2 100644 --- a/app/app.py +++ b/app/app.py @@ -3,9 +3,13 @@ Streamlit 配方优化交互界面 启动应用: streamlit run app/app.py + +Docker 环境变量: + API_URL: API 服务地址 (默认: http://localhost:8000) """ import io +import os from datetime import datetime import httpx @@ -14,7 +18,8 @@ import streamlit as st # ============ 配置 ============ -API_URL = "http://localhost:8000" +# 从环境变量读取 API 地址,支持 Docker 环境 +API_URL = os.environ.get("API_URL", "http://localhost:8000") AVAILABLE_ORGANS = [ "liver", @@ -27,13 +32,23 @@ AVAILABLE_ORGANS = [ ] ORGAN_LABELS = { - "liver": "🫀 肝脏 (Liver)", - "spleen": "🟣 脾脏 (Spleen)", - "lung": "🫁 肺 (Lung)", - "heart": "❤️ 心脏 (Heart)", - "kidney": "🫘 肾脏 (Kidney)", - "muscle": "💪 肌肉 (Muscle)", - "lymph_nodes": "🔵 淋巴结 (Lymph Nodes)", + "liver": "肝脏 (Liver)", + "spleen": "脾脏 (Spleen)", + "lung": "肺 (Lung)", + "heart": "心脏 (Heart)", + "kidney": "肾脏 
(Kidney)", + "muscle": "肌肉 (Muscle)", + "lymph_nodes": "淋巴结 (Lymph Nodes)", +} + +AVAILABLE_ROUTES = [ + "intravenous", + "intramuscular", +] + +ROUTE_LABELS = { + "intravenous": "静脉注射 (Intravenous)", + "intramuscular": "肌肉注射 (Intramuscular)", } # ============ 页面配置 ============ @@ -122,43 +137,107 @@ def check_api_status() -> bool: return False -def call_optimize_api(smiles: str, organ: str, top_k: int = 20) -> dict: +def call_optimize_api( + smiles: str, + organ: str, + top_k: int = 20, + num_seeds: int = None, + top_per_seed: int = 1, + step_sizes: list = None, + comp_ranges: dict = None, + routes: list = None, + scoring_weights: dict = None, +) -> dict: """调用优化 API""" - with httpx.Client(timeout=300) as client: # 5 分钟超时 + payload = { + "smiles": smiles, + "organ": organ, + "top_k": top_k, + "num_seeds": num_seeds, + "top_per_seed": top_per_seed, + "step_sizes": step_sizes, + "comp_ranges": comp_ranges, + "routes": routes, + "scoring_weights": scoring_weights, + } + + with httpx.Client(timeout=600) as client: # 10 分钟超时(自定义参数可能需要更长时间) response = client.post( f"{API_URL}/optimize", - json={ - "smiles": smiles, - "organ": organ, - "top_k": top_k, - }, + json=payload, ) response.raise_for_status() return response.json() -def format_results_dataframe(results: dict) -> pd.DataFrame: +# PDI 分类标签 +PDI_CLASS_LABELS = { + 0: "<0.2 (优)", + 1: "0.2-0.3 (良)", + 2: "0.3-0.4 (中)", + 3: ">0.4 (差)", +} + +# EE 分类标签 +EE_CLASS_LABELS = { + 0: "<50% (低)", + 1: "50-80% (中)", + 2: ">80% (高)", +} + +# 毒性分类标签 +TOXIC_CLASS_LABELS = { + 0: "无毒 ✓", + 1: "有毒 ⚠", +} + + +def format_results_dataframe(results: dict, smiles_label: str = None) -> pd.DataFrame: """将 API 结果转换为 DataFrame""" formulations = results["formulations"] target_organ = results["target_organ"] rows = [] for f in formulations: - row = { + row = {} + + # 如果有 SMILES 标签,添加到首列 + if smiles_label: + row["SMILES"] = smiles_label + + row.update({ "排名": f["rank"], - f"Biodist_{target_organ}": f"{f['target_biodist']:.4f}", - "阳离子脂质/mRNA": f["cationic_lipid_to_mrna_ratio"], - "阳离子脂质(mol)": f["cationic_lipid_mol_ratio"], - "磷脂(mol)": f["phospholipid_mol_ratio"], - "胆固醇(mol)": f["cholesterol_mol_ratio"], - "PEG脂质(mol)": f["peg_lipid_mol_ratio"], + }) + # 如果有综合评分,显示在排名后面 + if f.get("composite_score") is not None: + row["综合评分"] = f"{f['composite_score']:.4f}" + row.update({ + f"{target_organ}分布": f"{f['target_biodist']*100:.8f}%", + "阳离子脂质/mRNA比例": f["cationic_lipid_to_mrna_ratio"], + "阳离子脂质(mol)比例": f["cationic_lipid_mol_ratio"], + "磷脂(mol)比例": f["phospholipid_mol_ratio"], + "胆固醇(mol)比例": f["cholesterol_mol_ratio"], + "PEG脂质(mol)比例": f["peg_lipid_mol_ratio"], "辅助脂质": f["helper_lipid"], "给药途径": f["route"], - } + }) + + # 添加额外预测值 + if f.get("quantified_delivery") is not None: + row["量化递送"] = f"{f['quantified_delivery']:.4f}" + if f.get("size") is not None: + row["粒径(nm)"] = f"{f['size']:.1f}" + if f.get("pdi_class") is not None: + row["PDI"] = PDI_CLASS_LABELS.get(f["pdi_class"], str(f["pdi_class"])) + if f.get("ee_class") is not None: + row["包封率"] = EE_CLASS_LABELS.get(f["ee_class"], str(f["ee_class"])) + if f.get("toxic_class") is not None: + row["毒性"] = TOXIC_CLASS_LABELS.get(f["toxic_class"], str(f["toxic_class"])) + # 添加其他器官的 biodist for organ, value in f["all_biodist"].items(): if organ != target_organ: - row[f"Biodist_{organ}"] = f"{value:.4f}" + row[f"{organ}分布"] = f"{value*100:.2f}%" rows.append(row) return pd.DataFrame(rows) @@ -184,7 +263,7 @@ def main(): # ========== 侧边栏 ========== with st.sidebar: - st.header("⚙️ 参数设置") + # st.header("⚙️ 参数设置") # API 状态 
if api_online:
@@ -193,7 +272,7 @@ def main():
             st.error("🔴 API 服务离线")
             st.info("请先启动 API 服务:\n```\nuvicorn app.api:app --port 8000\n```")
 
-        st.divider()
+        # st.divider()
 
         # SMILES 输入
         st.subheader("🔬 分子结构")
@@ -201,23 +280,23 @@ def main():
             "输入阳离子脂质 SMILES",
             value="",
             height=100,
-            placeholder="例如: CC(C)NCCNC(C)C",
-            help="输入阳离子脂质的 SMILES 字符串",
+            placeholder="例如: CC(C)NCCNC(C)C\n多条SMILES用英文逗号分隔: SMI1,SMI2,SMI3",
+            help="输入阳离子脂质的 SMILES 字符串。支持多条 SMILES,用英文逗号 (,) 分隔",
         )
 
         # 示例 SMILES
-        with st.expander("📋 示例 SMILES"):
-            example_smiles = {
-                "DLin-MC3-DMA": "CC(C)=CCCC(C)=CCCC(C)=CCN(C)CCCCCCCCOC(=O)CCCCCCC/C=C\\CCCCCCCC",
-                "简单胺": "CC(C)NCCNC(C)C",
-                "长链胺": "CCCCCCCCCCCCNCCNCCCCCCCCCCCC",
-            }
-            for name, smi in example_smiles.items():
-                if st.button(f"使用 {name}", key=f"example_{name}"):
-                    st.session_state["smiles_input"] = smi
-                    st.rerun()
+        # with st.expander("📋 示例 SMILES"):
+        #     example_smiles = {
+        #         "DLin-MC3-DMA": "CC(C)=CCCC(C)=CCCC(C)=CCN(C)CCCCCCCCOC(=O)CCCCCCC/C=C\\CCCCCCCC",
+        #         "简单胺": "CC(C)NCCNC(C)C",
+        #         "长链胺": "CCCCCCCCCCCCNCCNCCCCCCCCCCCC",
+        #     }
+        #     for name, smi in example_smiles.items():
+        #         if st.button(f"使用 {name}", key=f"example_{name}"):
+        #             st.session_state["smiles_input"] = smi
+        #             st.rerun()
 
-        st.divider()
+        # st.divider()
 
         # 目标器官选择
         st.subheader("🎯 目标器官")
@@ -228,17 +307,226 @@ def main():
             index=0,
         )
 
-        st.divider()
+        # 给药途径选择
+        st.subheader("💉 给药途径")
+        selected_routes = st.multiselect(
+            "选择给药途径",
+            options=AVAILABLE_ROUTES,
+            default=AVAILABLE_ROUTES,
+            format_func=lambda x: ROUTE_LABELS.get(x, x),
+            help="选择要搜索的给药途径,可多选。至少选择一种。",
+        )
+        if not selected_routes:
+            st.warning("⚠️ 请至少选择一种给药途径")
 
         # 高级选项
         with st.expander("🔧 高级选项"):
+            st.markdown("**输出设置**")
             top_k = st.slider(
-                "返回配方数量",
+                "返回配方数量 (top_k)",
                 min_value=5,
-                max_value=50,
+                max_value=100,
                 value=20,
                 step=5,
+                help="最终返回的最优配方数量",
             )
+
+            st.markdown("**搜索策略**")
+            num_seeds = st.slider(
+                "种子点数量 (num_seeds)",
+                min_value=10,
+                max_value=200,
+                value=min(top_k * 5, 200),  # 默认 top_k * 5,但不能超过滑块上限,否则 Streamlit 会抛异常
+                step=10,
+                help="第一轮迭代后保留的种子点数量,更多种子点意味着更广泛的搜索",
+            )
+
+            top_per_seed = st.slider(
+                "每个种子的局部最优数 (top_per_seed)",
+                min_value=1,
+                max_value=5,
+                value=1,
+                step=1,
+                help="后续迭代中,每个种子点邻域保留的局部最优数量",
+            )
+
+            st.markdown("**迭代步长与轮数**")
+            use_custom_steps = st.checkbox(
+                "自定义迭代步长",
+                value=False,
+                help="默认步长为 [0.10, 0.02, 0.01],共3轮逐步精细化搜索。将某轮步长设为0可减少迭代轮数。",
+            )
+
+            if use_custom_steps:
+                col1, col2, col3 = st.columns(3)
+                with col1:
+                    step1 = st.number_input(
+                        "第1轮步长",
+                        min_value=0.01, max_value=0.20, value=0.10,
+                        step=0.01, format="%.2f",
+                        help="第1轮为全局粗搜索,步长必须大于0",
+                    )
+                with col2:
+                    step2 = st.number_input(
+                        "第2轮步长",
+                        min_value=0.00, max_value=0.10, value=0.02,
+                        step=0.01, format="%.2f",
+                        help="设为0则只进行1轮搜索",
+                    )
+                with col3:
+                    step3 = st.number_input(
+                        "第3轮步长",
+                        min_value=0.00, max_value=0.05, value=0.01,
+                        step=0.01, format="%.2f",
+                        help="设为0则只进行2轮搜索",
+                    )
+
+                # 根据步长值构建实际的 step_sizes 列表
+                # step2 为 0 → 只保留 [step1](1轮)
+                # step3 为 0 → 只保留 [step1, step2](2轮)
+                # 都不为 0 → [step1, step2, step3](3轮)
+                if step2 == 0.0:
+                    step_sizes = [step1]
+                elif step3 == 0.0:
+                    step_sizes = [step1, step2]
+                else:
+                    step_sizes = [step1, step2, step3]
+
+                # 显示实际迭代轮数提示
+                st.caption(f"📌 实际迭代轮数: {len(step_sizes)} 轮,步长: {step_sizes}")
+            else:
+                step_sizes = None  # 使用默认值
+
+            st.markdown("**组分范围限制**")
+            use_custom_ranges = st.checkbox(
+                "自定义组分取值范围",
+                value=False,
+                help="限制各组分的取值范围(mol 比例加起来仍为 100%)",
+            )
+
+            if use_custom_ranges:
+                st.caption("阳离子脂质/mRNA 重量比")
+                col1, col2 = st.columns(2)
+                with col1:
+                    weight_ratio_min = st.number_input("最小", min_value=0.01, max_value=0.50, value=0.05, step=0.01, format="%.2f", key="wr_min")
+                with col2:
+                    weight_ratio_max = st.number_input("最大", min_value=0.01, max_value=0.50, value=0.30, step=0.01, format="%.2f", key="wr_max")
+
+                st.caption("阳离子脂质 mol 比例")
+                col1, col2 = st.columns(2)
+                with col1:
+                    cationic_mol_min = st.number_input("最小", min_value=0.00, max_value=1.00, value=0.05, step=0.05, format="%.2f", key="cat_min")
+                with col2:
+                    cationic_mol_max = st.number_input("最大", min_value=0.00, max_value=1.00, value=0.80, step=0.05, format="%.2f", key="cat_max")
+
+                st.caption("磷脂 mol 比例")
+                col1, col2 = st.columns(2)
+                with col1:
+                    phospholipid_mol_min = st.number_input("最小", min_value=0.00, max_value=1.00, value=0.00, step=0.05, format="%.2f", key="phos_min")
+                with col2:
+                    phospholipid_mol_max = st.number_input("最大", min_value=0.00, max_value=1.00, value=0.80, step=0.05, format="%.2f", key="phos_max")
+
+                st.caption("胆固醇 mol 比例")
+                col1, col2 = st.columns(2)
+                with col1:
+                    cholesterol_mol_min = st.number_input("最小", min_value=0.00, max_value=1.00, value=0.00, step=0.05, format="%.2f", key="chol_min")
+                with col2:
+                    cholesterol_mol_max = st.number_input("最大", min_value=0.00, max_value=1.00, value=0.80, step=0.05, format="%.2f", key="chol_max")
+
+                st.caption("PEG 脂质 mol 比例")
+                col1, col2 = st.columns(2)
+                with col1:
+                    peg_mol_min = st.number_input("最小", min_value=0.00, max_value=0.20, value=0.00, step=0.01, format="%.2f", key="peg_min")
+                with col2:
+                    peg_mol_max = st.number_input("最大", min_value=0.00, max_value=0.20, value=0.05, step=0.01, format="%.2f", key="peg_max")
+
+                comp_ranges = {
+                    "weight_ratio_min": weight_ratio_min,
+                    "weight_ratio_max": weight_ratio_max,
+                    "cationic_mol_min": cationic_mol_min,
+                    "cationic_mol_max": cationic_mol_max,
+                    "phospholipid_mol_min": phospholipid_mol_min,
+                    "phospholipid_mol_max": phospholipid_mol_max,
+                    "cholesterol_mol_min": cholesterol_mol_min,
+                    "cholesterol_mol_max": cholesterol_mol_max,
+                    "peg_mol_min": peg_mol_min,
+                    "peg_mol_max": peg_mol_max,
+                }
+
+                # 简单验证
+                min_sum = cationic_mol_min + phospholipid_mol_min + cholesterol_mol_min + peg_mol_min
+                max_sum = cationic_mol_max + phospholipid_mol_max + cholesterol_mol_max + peg_mol_max
+                if min_sum > 1.0 or max_sum < 1.0:
+                    st.warning("⚠️ 当前范围设置可能无法生成有效配方(mol 比例需加起来为 100%)")
+            else:
+                comp_ranges = None  # 使用默认值
+
+            st.markdown("**评分/排序权重**")
+            use_custom_scoring = st.checkbox(
+                "自定义评分权重",
+                value=False,
+                help="默认仅按目标器官分布排序。开启后可自定义多目标加权评分,总分 = 各项score之和。",
+            )
+
+            if use_custom_scoring:
+                st.caption("**回归任务权重**")
+
+                sw_biodist = st.number_input(
+                    "器官分布 (Biodistribution)",
+                    min_value=0.00, max_value=10.00, value=0.30,
+                    step=0.05, format="%.2f", key="sw_biodist",
+                    help="score = biodist_value × weight",
+                )
+                sw_delivery = st.number_input(
+                    "量化递送 (Quantified Delivery)",
+                    min_value=0.00, max_value=10.00, value=0.25,
+                    step=0.05, format="%.2f", key="sw_delivery",
+                    help="score = normalize(delivery, route) × weight",
+                )
+                sw_size = st.number_input(
+                    "粒径 (Size, 60-150nm)",
+                    min_value=0.00, max_value=10.00, value=0.05,
+                    step=0.05, format="%.2f", key="sw_size",
+                    help="score = (1 if 60≤size≤150 else 0) × weight",
+                )
+
+                st.caption("**包封率 (EE) 分类权重**")
+                col1, col2, col3 = st.columns(3)
+                with col1:
+                    sw_ee0 = st.number_input("<50% (低)", min_value=0.00, max_value=1.00, value=0.00, step=0.01, format="%.2f", key="sw_ee0")
+                with col2:
+                    sw_ee1 = st.number_input("50-80% (中)", min_value=0.00, max_value=1.00, value=0.02, step=0.01, format="%.2f", key="sw_ee1")
+                with col3:
+                    sw_ee2 = st.number_input(">80% (高)", min_value=0.00, 
max_value=1.00, value=0.08, step=0.01, format="%.2f", key="sw_ee2") + + st.caption("**PDI 分类权重**") + col1, col2, col3, col4 = st.columns(4) + with col1: + sw_pdi0 = st.number_input("<0.2 (优)", min_value=0.00, max_value=1.00, value=0.08, step=0.01, format="%.2f", key="sw_pdi0") + with col2: + sw_pdi1 = st.number_input("0.2-0.3 (良)", min_value=0.00, max_value=1.00, value=0.02, step=0.01, format="%.2f", key="sw_pdi1") + with col3: + sw_pdi2 = st.number_input("0.3-0.4 (中)", min_value=0.00, max_value=1.00, value=0.00, step=0.01, format="%.2f", key="sw_pdi2") + with col4: + sw_pdi3 = st.number_input(">0.4 (差)", min_value=0.00, max_value=1.00, value=0.00, step=0.01, format="%.2f", key="sw_pdi3") + + st.caption("**毒性分类权重**") + col1, col2 = st.columns(2) + with col1: + sw_toxic0 = st.number_input("无毒", min_value=0.00, max_value=1.00, value=0.20, step=0.05, format="%.2f", key="sw_toxic0") + with col2: + sw_toxic1 = st.number_input("有毒", min_value=0.00, max_value=1.00, value=0.00, step=0.05, format="%.2f", key="sw_toxic1") + + scoring_weights = { + "biodist_weight": sw_biodist, + "delivery_weight": sw_delivery, + "size_weight": sw_size, + "ee_class_weights": [sw_ee0, sw_ee1, sw_ee2], + "pdi_class_weights": [sw_pdi0, sw_pdi1, sw_pdi2, sw_pdi3], + "toxic_class_weights": [sw_toxic0, sw_toxic1], + } + else: + scoring_weights = None # 使用默认值(仅按 biodist 排序) st.divider() @@ -247,7 +535,7 @@ def main(): "🚀 开始配方优选", type="primary", use_container_width=True, - disabled=not api_online or not smiles_input.strip(), + disabled=not api_online or not smiles_input.strip() or not selected_routes, ) # ========== 主内容区 ========== @@ -260,49 +548,125 @@ def main(): # 执行优化 if optimize_button and smiles_input.strip(): - with st.spinner("🔄 正在优化配方,请稍候..."): - try: - results = call_optimize_api( - smiles=smiles_input.strip(), - organ=selected_organ, - top_k=top_k, - ) - st.session_state["results"] = results - st.session_state["results_df"] = format_results_dataframe(results) - st.session_state["smiles_used"] = smiles_input.strip() + # 解析多条 SMILES(用逗号分隔) + smiles_list = [s.strip() for s in smiles_input.split(",") if s.strip()] + + if not smiles_list: + st.error("❌ 请输入有效的 SMILES 字符串") + else: + is_multi_smiles = len(smiles_list) > 1 + all_results = [] + all_dfs = [] + errors = [] + + # 进度条 + progress_bar = st.progress(0) + status_text = st.empty() + + for idx, smiles in enumerate(smiles_list): + status_text.text(f"🔄 正在优化 SMILES {idx + 1}/{len(smiles_list)}...") + progress_bar.progress((idx) / len(smiles_list)) + + try: + results = call_optimize_api( + smiles=smiles, + organ=selected_organ, + top_k=top_k, + num_seeds=num_seeds, + top_per_seed=top_per_seed, + step_sizes=step_sizes, + comp_ranges=comp_ranges, + routes=selected_routes, + scoring_weights=scoring_weights, + ) + all_results.append({"smiles": smiles, "results": results}) + + # 为多 SMILES 模式添加 SMILES 标签 + smiles_label = smiles[:30] + "..." 
if len(smiles) > 30 else smiles + df = format_results_dataframe(results, smiles_label if is_multi_smiles else None) + all_dfs.append(df) + + except httpx.HTTPStatusError as e: + try: + error_detail = e.response.json().get("detail", str(e)) + except: + error_detail = str(e) + errors.append(f"SMILES {idx + 1}: {error_detail}") + except httpx.RequestError as e: + errors.append(f"SMILES {idx + 1}: API 连接失败 - {e}") + except Exception as e: + errors.append(f"SMILES {idx + 1}: {e}") + + progress_bar.progress(1.0) + status_text.empty() + progress_bar.empty() + + # 显示错误 + for err in errors: + st.error(f"❌ {err}") + + # 保存结果 + if all_results: + st.session_state["results"] = all_results[0]["results"] if len(all_results) == 1 else all_results + st.session_state["results_df"] = pd.concat(all_dfs, ignore_index=True) if all_dfs else None + st.session_state["smiles_used"] = smiles_list st.session_state["organ_used"] = selected_organ - st.success("✅ 优化完成!") - except httpx.RequestError as e: - st.error(f"❌ API 请求失败: {e}") - except Exception as e: - st.error(f"❌ 发生错误: {e}") + st.session_state["is_multi_smiles"] = is_multi_smiles + st.success(f"✅ 优化完成!成功处理 {len(all_results)}/{len(smiles_list)} 条 SMILES") # 显示结果 - if st.session_state["results"] is not None: + if st.session_state["results"] is not None and st.session_state["results_df"] is not None: results = st.session_state["results"] df = st.session_state["results_df"] + is_multi_smiles = st.session_state.get("is_multi_smiles", False) # 结果概览 - col1, col2, col3 = st.columns(3) - - with col1: - st.metric( - "目标器官", - ORGAN_LABELS.get(results["target_organ"], results["target_organ"]).split(" ")[0], - ) - - with col2: - best_score = results["formulations"][0]["target_biodist"] - st.metric( - "最优 Biodistribution", - f"{best_score:.4f}", - ) - - with col3: - st.metric( - "优选配方数", - len(results["formulations"]), - ) + if is_multi_smiles: + # 多 SMILES 模式 + col1, col2, col3 = st.columns(3) + + with col1: + # 获取 target_organ(从第一个结果) + first_result = results[0]["results"] if isinstance(results, list) else results + target_organ = first_result["target_organ"] + st.metric( + "目标器官", + ORGAN_LABELS.get(target_organ, target_organ).split(" ")[0], + ) + + with col2: + st.metric( + "SMILES 数量", + len(results) if isinstance(results, list) else 1, + ) + + with col3: + st.metric( + "总配方数", + len(df), + ) + else: + # 单 SMILES 模式 + col1, col2, col3 = st.columns(3) + + with col1: + st.metric( + "目标器官", + ORGAN_LABELS.get(results["target_organ"], results["target_organ"]).split(" ")[0], + ) + + with col2: + best_score = results["formulations"][0]["target_biodist"] + st.metric( + "最优分布", + f"{best_score*100:.2f}%", + ) + + with col3: + st.metric( + "优选配方数", + len(results["formulations"]), + ) st.divider() @@ -312,15 +676,26 @@ def main(): # 导出按钮行 col_export, col_spacer = st.columns([1, 4]) with col_export: + smiles_used = st.session_state.get("smiles_used", "") + if isinstance(smiles_used, list): + smiles_used = ",".join(smiles_used) + csv_content = create_export_csv( df, - st.session_state.get("smiles_used", ""), + smiles_used, st.session_state.get("organ_used", ""), ) + + # 获取 target_organ + if is_multi_smiles: + target_organ = results[0]["results"]["target_organ"] if isinstance(results, list) else results["target_organ"] + else: + target_organ = results["target_organ"] + st.download_button( label="📥 导出 CSV", data=csv_content, - file_name=f"lnp_optimization_{results['target_organ']}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv", + 
file_name=f"lnp_optimization_{target_organ}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv", mime="text/csv", ) @@ -333,61 +708,61 @@ def main(): ) # 详细信息 - with st.expander("🔍 查看最优配方详情"): - best = results["formulations"][0] + # with st.expander("🔍 查看最优配方详情"): + # best = results["formulations"][0] - col1, col2 = st.columns(2) + # col1, col2 = st.columns(2) - with col1: - st.markdown("**配方参数**") - st.json({ - "阳离子脂质/mRNA 比例": best["cationic_lipid_to_mrna_ratio"], - "阳离子脂质 (mol%)": best["cationic_lipid_mol_ratio"], - "磷脂 (mol%)": best["phospholipid_mol_ratio"], - "胆固醇 (mol%)": best["cholesterol_mol_ratio"], - "PEG 脂质 (mol%)": best["peg_lipid_mol_ratio"], - "辅助脂质": best["helper_lipid"], - "给药途径": best["route"], - }) + # with col1: + # st.markdown("**配方参数**") + # st.json({ + # "阳离子脂质/mRNA 比例": best["cationic_lipid_to_mrna_ratio"], + # "阳离子脂质 (mol%)": best["cationic_lipid_mol_ratio"], + # "磷脂 (mol%)": best["phospholipid_mol_ratio"], + # "胆固醇 (mol%)": best["cholesterol_mol_ratio"], + # "PEG 脂质 (mol%)": best["peg_lipid_mol_ratio"], + # "辅助脂质": best["helper_lipid"], + # "给药途径": best["route"], + # }) - with col2: - st.markdown("**各器官 Biodistribution 预测**") - biodist_df = pd.DataFrame([ - {"器官": ORGAN_LABELS.get(k, k), "Biodistribution": f"{v:.4f}"} - for k, v in best["all_biodist"].items() - ]) - st.dataframe(biodist_df, hide_index=True, use_container_width=True) + # with col2: + # st.markdown("**各器官 Biodistribution 预测**") + # biodist_df = pd.DataFrame([ + # {"器官": ORGAN_LABELS.get(k, k), "Biodistribution": f"{v:.4f}"} + # for k, v in best["all_biodist"].items() + # ]) + # st.dataframe(biodist_df, hide_index=True, use_container_width=True) else: # 欢迎信息 st.info("👈 请在左侧输入 SMILES 并选择目标器官,然后点击「开始配方优选」") # 使用说明 - with st.expander("📖 使用说明"): - st.markdown(""" - ### 如何使用 + # with st.expander("📖 使用说明"): + # st.markdown(""" + # ### 如何使用 - 1. **输入 SMILES**: 在左侧输入框中输入阳离子脂质的 SMILES 字符串 - 2. **选择目标器官**: 选择您希望优化的器官靶向 - 3. **点击优选**: 系统将自动搜索最优配方组合 - 4. **查看结果**: 右侧将显示 Top-20 优选配方 - 5. **导出数据**: 点击导出按钮将结果保存为 CSV 文件 + # 1. **输入 SMILES**: 在左侧输入框中输入阳离子脂质的 SMILES 字符串 + # 2. **选择目标器官**: 选择您希望优化的器官靶向 + # 3. **点击优选**: 系统将自动搜索最优配方组合 + # 4. **查看结果**: 右侧将显示 Top-20 优选配方 + # 5. 
**导出数据**: 点击导出按钮将结果保存为 CSV 文件 - ### 优化参数 + # ### 优化参数 - 系统会优化以下配方参数: - - **阳离子脂质/mRNA 比例**: 0.05 - 0.30 - - **阳离子脂质 mol 比例**: 0.05 - 0.80 - - **磷脂 mol 比例**: 0.00 - 0.80 - - **胆固醇 mol 比例**: 0.00 - 0.80 - - **PEG 脂质 mol 比例**: 0.00 - 0.05 - - **辅助脂质**: DOPE / DSPC / DOTAP - - **给药途径**: 静脉注射 / 肌肉注射 + # 系统会优化以下配方参数: + # - **阳离子脂质/mRNA 比例**: 0.05 - 0.30 + # - **阳离子脂质 mol 比例**: 0.05 - 0.80 + # - **磷脂 mol 比例**: 0.00 - 0.80 + # - **胆固醇 mol 比例**: 0.00 - 0.80 + # - **PEG 脂质 mol 比例**: 0.00 - 0.05 + # - **辅助脂质**: DOPE / DSPC / DOTAP + # - **给药途径**: 静脉注射 / 肌肉注射 - ### 约束条件 + # ### 约束条件 - mol 比例之和 = 1 (阳离子脂质 + 磷脂 + 胆固醇 + PEG 脂质) - """) + # mol 比例之和 = 1 (阳离子脂质 + 磷脂 + 胆固醇 + PEG 脂质) + # """) if __name__ == "__main__": diff --git a/app/optimize.py b/app/optimize.py index f383af1..8c80937 100644 --- a/app/optimize.py +++ b/app/optimize.py @@ -41,14 +41,73 @@ app = typer.Typer() # 可用的目标器官 AVAILABLE_ORGANS = ["lymph_nodes", "heart", "liver", "spleen", "lung", "kidney", "muscle"] -# comp token 参数范围 -COMP_PARAM_RANGES = { - "Cationic_Lipid_to_mRNA_weight_ratio": (0.05, 0.30), - "Cationic_Lipid_Mol_Ratio": (0.05, 0.80), - "Phospholipid_Mol_Ratio": (0.00, 0.80), - "Cholesterol_Mol_Ratio": (0.00, 0.80), - "PEG_Lipid_Mol_Ratio": (0.00, 0.05), -} + +@dataclass +class CompRanges: + """组分参数范围配置""" + # 阳离子脂质/mRNA 重量比 + weight_ratio_min: float = 0.05 + weight_ratio_max: float = 0.30 + # 阳离子脂质 mol 比例 + cationic_mol_min: float = 0.05 + cationic_mol_max: float = 0.80 + # 磷脂 mol 比例 + phospholipid_mol_min: float = 0.00 + phospholipid_mol_max: float = 0.80 + # 胆固醇 mol 比例 + cholesterol_mol_min: float = 0.00 + cholesterol_mol_max: float = 0.80 + # PEG 脂质 mol 比例 + peg_mol_min: float = 0.00 + peg_mol_max: float = 0.05 + + def to_dict(self) -> Dict: + """转换为字典""" + return { + "weight_ratio": (self.weight_ratio_min, self.weight_ratio_max), + "cationic_mol": (self.cationic_mol_min, self.cationic_mol_max), + "phospholipid_mol": (self.phospholipid_mol_min, self.phospholipid_mol_max), + "cholesterol_mol": (self.cholesterol_mol_min, self.cholesterol_mol_max), + "peg_mol": (self.peg_mol_min, self.peg_mol_max), + } + + def get_validation_error(self) -> Optional[str]: + """ + 验证范围是否合理,返回错误信息(如果有)。 + + Returns: + 错误信息字符串,如果验证通过则返回 None + """ + # 检查各范围是否有效(最小值不能大于最大值) + if self.weight_ratio_min > self.weight_ratio_max: + return f"阳离子脂质/mRNA重量比:最小值({self.weight_ratio_min})不能大于最大值({self.weight_ratio_max})" + if self.cationic_mol_min > self.cationic_mol_max: + return f"阳离子脂质mol比例:最小值({self.cationic_mol_min})不能大于最大值({self.cationic_mol_max})" + if self.phospholipid_mol_min > self.phospholipid_mol_max: + return f"磷脂mol比例:最小值({self.phospholipid_mol_min})不能大于最大值({self.phospholipid_mol_max})" + if self.cholesterol_mol_min > self.cholesterol_mol_max: + return f"胆固醇mol比例:最小值({self.cholesterol_mol_min})不能大于最大值({self.cholesterol_mol_max})" + if self.peg_mol_min > self.peg_mol_max: + return f"PEG脂质mol比例:最小值({self.peg_mol_min})不能大于最大值({self.peg_mol_max})" + + # 检查 mol ratio 是否可能加起来为 1 + min_sum = self.cationic_mol_min + self.phospholipid_mol_min + self.cholesterol_mol_min + self.peg_mol_min + max_sum = self.cationic_mol_max + self.phospholipid_mol_max + self.cholesterol_mol_max + self.peg_mol_max + + if min_sum > 1.0: + return f"mol比例最小值之和({min_sum:.2f})超过100%,无法生成有效配方" + if max_sum < 1.0: + return f"mol比例最大值之和({max_sum:.2f})不足100%,无法生成有效配方" + + return None + + def validate(self) -> bool: + """验证范围是否合理(至少存在一个可行解)""" + return self.get_validation_error() is None + + +# 默认组分范围 +DEFAULT_COMP_RANGES = CompRanges() # 最小 step size MIN_STEP_SIZE = 0.01 @@ -56,12 
+115,153 @@ MIN_STEP_SIZE = 0.01
 
 # 迭代策略:每个迭代的 step_size
 ITERATION_STEP_SIZES = [0.10, 0.02, 0.01]
 
-# Helper lipid 选项
-HELPER_LIPID_OPTIONS = ["DOPE", "DSPC", "DOTAP"]
+# Helper lipid 选项(不包含 DOTAP)
+HELPER_LIPID_OPTIONS = ["DOPE", "DSPC"]
 
 # Route of administration 选项
 ROUTE_OPTIONS = ["intravenous", "intramuscular"]
 
+# quantified_delivery 归一化常量(按给药途径)
+DELIVERY_NORM = {
+    "intravenous": {"min": -0.798559291, "max": 4.497814051056962},
+    "intramuscular": {"min": -0.794912427, "max": 10.220042980012716},
+}
+
+
+@dataclass
+class ScoringWeights:
+    """
+    评分权重配置。
+
+    总分 = biodist_score + delivery_score + size_score + ee_score + pdi_score + toxic_score
+    各项计算方式参见 SCORE.md。
+    默认值:仅按目标器官 biodistribution 排序(向后兼容)。
+    """
+    # 回归任务权重
+    biodist_weight: float = 1.0  # score = biodist_value * weight
+    delivery_weight: float = 0.0  # score = normalize(delivery, route) * weight
+    size_weight: float = 0.0  # score = (1 if 60<=size<=150 else 0) * weight,见 SCORE.md
+    # 分类任务:per-class 权重(预测为该类时,得分 = 对应权重)
+    ee_class_weights: List[float] = field(default_factory=lambda: [0.0, 0.0, 0.0])  # EE class 0, 1, 2
+    pdi_class_weights: List[float] = field(default_factory=lambda: [0.0, 0.0, 0.0, 0.0])  # PDI class 0, 1, 2, 3
+    toxic_class_weights: List[float] = field(default_factory=lambda: [0.0, 0.0])  # Toxic class 0, 1
+
+
+# 默认评分权重(仅按 biodist 排序)
+DEFAULT_SCORING_WEIGHTS = ScoringWeights()
+
+
+def compute_formulation_score(
+    f: 'Formulation',
+    organ: str,
+    weights: ScoringWeights,
+) -> float:
+    """
+    计算单个 Formulation 的综合评分。
+
+    Args:
+        f: 配方对象
+        organ: 目标器官
+        weights: 评分权重
+
+    Returns:
+        综合评分
+    """
+    score = 0.0
+
+    # 1. biodistribution
+    score += f.get_biodist(organ) * weights.biodist_weight
+
+    # 2. quantified_delivery(按给药途径归一化到 [0, 1])
+    if f.quantified_delivery is not None and weights.delivery_weight != 0:
+        norm = DELIVERY_NORM.get(f.route, DELIVERY_NORM["intravenous"])
+        d_range = norm["max"] - norm["min"]
+        if d_range > 0:
+            delivery_normalized = (f.quantified_delivery - norm["min"]) / d_range
+            delivery_normalized = max(0.0, min(1.0, delivery_normalized))
+        else:
+            delivery_normalized = 0.0
+        score += delivery_normalized * weights.delivery_weight
+
+    # 3. size(60-150nm 为理想范围)
+    if f.size is not None and weights.size_weight != 0:
+        if 60 <= f.size <= 150:
+            score += 1.0 * weights.size_weight
+
+    # 4. EE 分类
+    if f.ee_class is not None and 0 <= f.ee_class < len(weights.ee_class_weights):
+        score += weights.ee_class_weights[f.ee_class]
+
+    # 5. PDI 分类
+    if f.pdi_class is not None and 0 <= f.pdi_class < len(weights.pdi_class_weights):
+        score += weights.pdi_class_weights[f.pdi_class]
+
+    # 6. 毒性分类
+    if f.toxic_class is not None and 0 <= f.toxic_class < len(weights.toxic_class_weights):
+        score += weights.toxic_class_weights[f.toxic_class]
+
+    return score
+
+
+def compute_df_score(
+    df: pd.DataFrame,
+    organ: str,
+    weights: ScoringWeights,
+) -> pd.Series:
+    """
+    为 DataFrame 中的所有行计算综合评分。
+
+    Args:
+        df: 包含预测结果的 DataFrame
+        organ: 目标器官
+        weights: 评分权重
+
+    Returns:
+        评分 Series
+    """
+    score = pd.Series(0.0, index=df.index)
+
+    # 1. biodistribution
+    pred_col = f"pred_Biodistribution_{organ}"
+    if pred_col in df.columns:
+        score += df[pred_col] * weights.biodist_weight
+
+    # 2. 
quantified_delivery(按给药途径归一化) + if weights.delivery_weight != 0 and "pred_delivery" in df.columns: + for route_name, norm in DELIVERY_NORM.items(): + mask = df["_route"] == route_name + if mask.any(): + d_range = norm["max"] - norm["min"] + if d_range > 0: + delivery_normalized = (df.loc[mask, "pred_delivery"] - norm["min"]) / d_range + delivery_normalized = delivery_normalized.clip(0.0, 1.0) + score.loc[mask] += delivery_normalized * weights.delivery_weight + + # 3. size + if weights.size_weight != 0 and "pred_size" in df.columns: + size_ok = (df["pred_size"] >= 60) & (df["pred_size"] <= 150) + score += size_ok.astype(float) * weights.size_weight + + # 4. EE 分类 + if "pred_ee_class" in df.columns: + for cls, w in enumerate(weights.ee_class_weights): + if w != 0: + score += (df["pred_ee_class"] == cls).astype(float) * w + + # 5. PDI 分类 + if "pred_pdi_class" in df.columns: + for cls, w in enumerate(weights.pdi_class_weights): + if w != 0: + score += (df["pred_pdi_class"] == cls).astype(float) * w + + # 6. 毒性分类 + if "pred_toxic_class" in df.columns: + for cls, w in enumerate(weights.toxic_class_weights): + if w != 0: + score += (df["pred_toxic_class"] == cls).astype(float) * w + + return score + @dataclass class Formulation: @@ -77,6 +277,12 @@ class Formulation: route: str = "intravenous" # 预测结果(填充后) biodist_predictions: Dict[str, float] = field(default_factory=dict) + # 额外预测值 + quantified_delivery: Optional[float] = None + size: Optional[float] = None + pdi_class: Optional[int] = None # PDI 分类 (0-3) + ee_class: Optional[int] = None # EE 分类 (0-2) + toxic_class: Optional[int] = None # 毒性分类 (0: 无毒, 1: 有毒) def to_dict(self) -> Dict: """转换为字典""" @@ -94,6 +300,18 @@ class Formulation: """获取指定器官的 biodistribution 预测值""" col = f"Biodistribution_{organ}" return self.biodist_predictions.get(col, 0.0) + + def unique_key(self) -> tuple: + """生成唯一标识键,用于去重""" + return ( + round(self.cationic_lipid_to_mrna_ratio, 4), + round(self.cationic_lipid_mol_ratio, 4), + round(self.phospholipid_mol_ratio, 4), + round(self.cholesterol_mol_ratio, 4), + round(self.peg_lipid_mol_ratio, 4), + self.helper_lipid, + self.route, + ) def generate_grid_values( @@ -124,20 +342,38 @@ def generate_grid_values( return sorted(set(values)) -def generate_initial_grid(step_size: float) -> List[Tuple[float, float, float, float, float]]: +def generate_initial_grid( + step_size: float, + comp_ranges: CompRanges = None, +) -> List[Tuple[float, float, float, float, float]]: """ 生成初始搜索网格(满足 mol ratio 和为 1 的约束)。 + Args: + step_size: 搜索步长 + comp_ranges: 组分范围配置(默认使用 DEFAULT_COMP_RANGES) + Returns: List of (cationic_ratio, cationic_mol, phospholipid_mol, cholesterol_mol, peg_mol) """ + if comp_ranges is None: + comp_ranges = DEFAULT_COMP_RANGES + grid = [] # Cationic_Lipid_to_mRNA_weight_ratio - weight_ratios = np.arange(0.05, 0.31, step_size) + weight_ratios = np.arange( + comp_ranges.weight_ratio_min, + comp_ranges.weight_ratio_max + 0.001, + step_size + ) - # PEG: 单独处理,范围很小 - peg_values = np.arange(0.00, 0.06, MIN_STEP_SIZE) # PEG 始终用 0.01 步长 + # PEG: 单独处理,范围很小,始终用最小步长 + peg_values = np.arange( + comp_ranges.peg_mol_min, + comp_ranges.peg_mol_max + 0.001, + MIN_STEP_SIZE + ) # 其他三个 mol ratio 需要满足和为 1 - PEG mol_step = step_size @@ -146,11 +382,13 @@ def generate_initial_grid(step_size: float) -> List[Tuple[float, float, float, f for peg in peg_values: remaining = 1.0 - peg # 生成满足约束的组合 - for cationic_mol in np.arange(0.05, min(0.81, remaining + 0.001), mol_step): - for phospholipid_mol in np.arange(0.00, min(0.81, remaining - cationic_mol + 
0.001), mol_step): + cationic_max = min(comp_ranges.cationic_mol_max, remaining) + 0.001 + for cationic_mol in np.arange(comp_ranges.cationic_mol_min, cationic_max, mol_step): + phospholipid_max = min(comp_ranges.phospholipid_mol_max, remaining - cationic_mol) + 0.001 + for phospholipid_mol in np.arange(comp_ranges.phospholipid_mol_min, phospholipid_max, mol_step): cholesterol_mol = remaining - cationic_mol - phospholipid_mol # 检查约束 - if 0.00 <= cholesterol_mol <= 0.80: + if (comp_ranges.cholesterol_mol_min <= cholesterol_mol <= comp_ranges.cholesterol_mol_max): grid.append(( round(weight_ratio, 4), round(cationic_mol, 4), @@ -166,6 +404,7 @@ def generate_refined_grid( seeds: List[Formulation], step_size: float, radius: int = 2, + comp_ranges: CompRanges = None, ) -> List[Tuple[float, float, float, float, float]]: """ 围绕种子点生成精细化网格。 @@ -174,29 +413,37 @@ def generate_refined_grid( seeds: 种子配方列表 step_size: 步长 radius: 扩展半径 + comp_ranges: 组分范围配置(默认使用 DEFAULT_COMP_RANGES) Returns: 新的网格点列表 """ + if comp_ranges is None: + comp_ranges = DEFAULT_COMP_RANGES + grid_set = set() for seed in seeds: # Weight ratio weight_ratios = generate_grid_values( - seed.cationic_lipid_to_mrna_ratio, step_size, 0.05, 0.30, radius + seed.cationic_lipid_to_mrna_ratio, step_size, + comp_ranges.weight_ratio_min, comp_ranges.weight_ratio_max, radius ) # PEG (始终用最小步长) peg_values = generate_grid_values( - seed.peg_lipid_mol_ratio, MIN_STEP_SIZE, 0.00, 0.05, radius + seed.peg_lipid_mol_ratio, MIN_STEP_SIZE, + comp_ranges.peg_mol_min, comp_ranges.peg_mol_max, radius ) # Mol ratios cationic_mols = generate_grid_values( - seed.cationic_lipid_mol_ratio, step_size, 0.05, 0.80, radius + seed.cationic_lipid_mol_ratio, step_size, + comp_ranges.cationic_mol_min, comp_ranges.cationic_mol_max, radius ) phospholipid_mols = generate_grid_values( - seed.phospholipid_mol_ratio, step_size, 0.00, 0.80, radius + seed.phospholipid_mol_ratio, step_size, + comp_ranges.phospholipid_mol_min, comp_ranges.phospholipid_mol_max, radius ) for weight_ratio in weight_ratios: @@ -206,10 +453,10 @@ def generate_refined_grid( for phospholipid_mol in phospholipid_mols: cholesterol_mol = remaining - cationic_mol - phospholipid_mol # 检查约束 - if (0.05 <= cationic_mol <= 0.80 and - 0.00 <= phospholipid_mol <= 0.80 and - 0.00 <= cholesterol_mol <= 0.80 and - 0.00 <= peg <= 0.05): + if (comp_ranges.cationic_mol_min <= cationic_mol <= comp_ranges.cationic_mol_max and + comp_ranges.phospholipid_mol_min <= phospholipid_mol <= comp_ranges.phospholipid_mol_max and + comp_ranges.cholesterol_mol_min <= cholesterol_mol <= comp_ranges.cholesterol_mol_max and + comp_ranges.peg_mol_min <= peg <= comp_ranges.peg_mol_max): grid_set.add(( round(weight_ratio, 4), round(cationic_mol, 4), @@ -283,14 +530,14 @@ def create_dataframe_from_formulations( return pd.DataFrame(rows) -def predict_biodist( +def predict_all( model: torch.nn.Module, df: pd.DataFrame, device: torch.device, batch_size: int = 256, ) -> pd.DataFrame: """ - 使用模型预测 biodistribution。 + 使用模型预测所有输出(biodistribution、size、delivery、PDI、EE)。 Returns: 添加了预测列的 DataFrame @@ -301,6 +548,11 @@ def predict_biodist( ) all_biodist_preds = [] + all_size_preds = [] + all_delivery_preds = [] + all_pdi_preds = [] + all_ee_preds = [] + all_toxic_preds = [] with torch.no_grad(): for batch in dataloader: @@ -312,20 +564,53 @@ def predict_biodist( # biodist 输出是 softmax 后的概率分布 [B, 7] biodist_pred = outputs["biodist"].cpu().numpy() all_biodist_preds.append(biodist_pred) + + # size 和 delivery 是回归值 + 
all_size_preds.append(outputs["size"].squeeze(-1).cpu().numpy()) + all_delivery_preds.append(outputs["delivery"].squeeze(-1).cpu().numpy()) + + # PDI、EE 和 toxic 是分类,取 argmax + all_pdi_preds.append(outputs["pdi"].argmax(dim=-1).cpu().numpy()) + all_ee_preds.append(outputs["ee"].argmax(dim=-1).cpu().numpy()) + all_toxic_preds.append(outputs["toxic"].argmax(dim=-1).cpu().numpy()) biodist_preds = np.concatenate(all_biodist_preds, axis=0) + size_preds = np.concatenate(all_size_preds, axis=0) + delivery_preds = np.concatenate(all_delivery_preds, axis=0) + pdi_preds = np.concatenate(all_pdi_preds, axis=0) + ee_preds = np.concatenate(all_ee_preds, axis=0) + toxic_preds = np.concatenate(all_toxic_preds, axis=0) # 添加到 DataFrame for i, col in enumerate(TARGET_BIODIST): df[f"pred_{col}"] = biodist_preds[:, i] + # size 模型输出为 log(size),转换回真实粒径 (nm) + df["pred_size"] = np.exp(size_preds) + df["pred_delivery"] = delivery_preds + df["pred_pdi_class"] = pdi_preds + df["pred_ee_class"] = ee_preds + df["pred_toxic_class"] = toxic_preds + return df +# 保持向后兼容 +def predict_biodist( + model: torch.nn.Module, + df: pd.DataFrame, + device: torch.device, + batch_size: int = 256, +) -> pd.DataFrame: + """向后兼容的别名""" + return predict_all(model, df, device, batch_size) + + def select_top_k( df: pd.DataFrame, organ: str, k: int = 20, + scoring_weights: Optional[ScoringWeights] = None, ) -> List[Formulation]: """ 选择 top-k 配方。 @@ -334,16 +619,18 @@ def select_top_k( df: 包含预测结果的 DataFrame organ: 目标器官 k: 选择数量 + scoring_weights: 评分权重(默认仅按 biodist 排序) Returns: Top-k 配方列表 """ - pred_col = f"pred_Biodistribution_{organ}" - if pred_col not in df.columns: - raise ValueError(f"Prediction column {pred_col} not found") + if scoring_weights is None: + scoring_weights = DEFAULT_SCORING_WEIGHTS - # 排序并去重 - df_sorted = df.sort_values(pred_col, ascending=False) + # 计算综合评分并排序 + df = df.copy() + df["_composite_score"] = compute_df_score(df, organ, scoring_weights) + df_sorted = df.sort_values("_composite_score", ascending=False) # 创建配方对象 formulations = [] @@ -373,6 +660,12 @@ def select_top_k( biodist_predictions={ col: row[f"pred_{col}"] for col in TARGET_BIODIST }, + # 额外预测值 + quantified_delivery=row.get("pred_delivery"), + size=row.get("pred_size"), + pdi_class=int(row.get("pred_pdi_class")) if row.get("pred_pdi_class") is not None else None, + ee_class=int(row.get("pred_ee_class")) if row.get("pred_ee_class") is not None else None, + toxic_class=int(row.get("pred_toxic_class")) if row.get("pred_toxic_class") is not None else None, ) formulations.append(formulation) @@ -382,72 +675,242 @@ def select_top_k( return formulations +def generate_single_seed_grid( + seed: Formulation, + step_size: float, + radius: int = 2, + comp_ranges: CompRanges = None, +) -> List[Tuple[float, float, float, float, float]]: + """ + 为单个种子点生成邻域网格。 + + Args: + seed: 种子配方 + step_size: 步长 + radius: 扩展半径 + comp_ranges: 组分范围配置(默认使用 DEFAULT_COMP_RANGES) + + Returns: + 网格点列表 + """ + if comp_ranges is None: + comp_ranges = DEFAULT_COMP_RANGES + + grid_set = set() + + # Weight ratio + weight_ratios = generate_grid_values( + seed.cationic_lipid_to_mrna_ratio, step_size, + comp_ranges.weight_ratio_min, comp_ranges.weight_ratio_max, radius + ) + + # PEG (始终用最小步长) + peg_values = generate_grid_values( + seed.peg_lipid_mol_ratio, MIN_STEP_SIZE, + comp_ranges.peg_mol_min, comp_ranges.peg_mol_max, radius + ) + + # Mol ratios + cationic_mols = generate_grid_values( + seed.cationic_lipid_mol_ratio, step_size, + comp_ranges.cationic_mol_min, comp_ranges.cationic_mol_max, radius + 
) + phospholipid_mols = generate_grid_values( + seed.phospholipid_mol_ratio, step_size, + comp_ranges.phospholipid_mol_min, comp_ranges.phospholipid_mol_max, radius + ) + + for weight_ratio in weight_ratios: + for peg in peg_values: + remaining = 1.0 - peg + for cationic_mol in cationic_mols: + for phospholipid_mol in phospholipid_mols: + cholesterol_mol = remaining - cationic_mol - phospholipid_mol + # 检查约束 + if (comp_ranges.cationic_mol_min <= cationic_mol <= comp_ranges.cationic_mol_max and + comp_ranges.phospholipid_mol_min <= phospholipid_mol <= comp_ranges.phospholipid_mol_max and + comp_ranges.cholesterol_mol_min <= cholesterol_mol <= comp_ranges.cholesterol_mol_max and + comp_ranges.peg_mol_min <= peg <= comp_ranges.peg_mol_max): + grid_set.add(( + round(weight_ratio, 4), + round(cationic_mol, 4), + round(phospholipid_mol, 4), + round(cholesterol_mol, 4), + round(peg, 4), + )) + + return list(grid_set) + + def optimize( smiles: str, organ: str, model: torch.nn.Module, device: torch.device, top_k: int = 20, + num_seeds: Optional[int] = None, + top_per_seed: int = 1, + step_sizes: Optional[List[float]] = None, + comp_ranges: Optional[CompRanges] = None, + routes: Optional[List[str]] = None, + scoring_weights: Optional[ScoringWeights] = None, batch_size: int = 256, ) -> List[Formulation]: """ - 执行配方优化。 + 执行配方优化(层级搜索策略)。 + + 采用层级搜索策略: + 1. 第一次迭代:全局稀疏搜索,选择 top num_seeds 个分散的种子点 + 2. 后续迭代:对每个种子点分别在其邻域内搜索,各自保留 top_per_seed 个局部最优 + 3. 这样可以保持搜索的多样性,避免结果集中在单一区域 Args: smiles: SMILES 字符串 organ: 目标器官 model: 训练好的模型 device: 计算设备 - top_k: 每轮保留的最优配方数 + top_k: 最终返回的最优配方数 + num_seeds: 第一次迭代后保留的种子点数量(默认为 top_k * 5) + top_per_seed: 每个种子点的邻域搜索后保留的局部最优点数量 + step_sizes: 每轮迭代的步长列表(默认为 [0.10, 0.02, 0.01]) + comp_ranges: 组分范围配置(默认使用 DEFAULT_COMP_RANGES) + routes: 给药途径列表(默认使用 ROUTE_OPTIONS) + scoring_weights: 评分权重配置(默认仅按 biodist 排序) batch_size: 预测批次大小 Returns: 最终 top-k 配方列表 """ + # 默认 num_seeds 为 top_k * 5 + if num_seeds is None: + num_seeds = top_k * 5 + + # 默认步长 + if step_sizes is None: + step_sizes = ITERATION_STEP_SIZES + + # 默认组分范围 + if comp_ranges is None: + comp_ranges = DEFAULT_COMP_RANGES + + # 默认给药途径 + if routes is None: + routes = ROUTE_OPTIONS + + # 默认评分权重 + if scoring_weights is None: + scoring_weights = DEFAULT_SCORING_WEIGHTS + + # 评分函数(用于 Formulation 对象排序) + def _score(f: Formulation) -> float: + return compute_formulation_score(f, organ, scoring_weights) + logger.info(f"Starting optimization for organ: {organ}") logger.info(f"SMILES: {smiles}") + logger.info(f"Strategy: num_seeds={num_seeds}, top_per_seed={top_per_seed}, top_k={top_k}") + logger.info(f"Step sizes: {step_sizes}") + logger.info(f"Routes: {routes}") + logger.info(f"Scoring weights: biodist={scoring_weights.biodist_weight}, delivery={scoring_weights.delivery_weight}, size={scoring_weights.size_weight}") + logger.info(f"Comp ranges: {comp_ranges.to_dict()}") seeds = None - for iteration, step_size in enumerate(ITERATION_STEP_SIZES): + for iteration, step_size in enumerate(step_sizes): logger.info(f"\n{'='*60}") - logger.info(f"Iteration {iteration + 1}/{len(ITERATION_STEP_SIZES)}, step_size={step_size}") + logger.info(f"Iteration {iteration + 1}/{len(step_sizes)}, step_size={step_size}") logger.info(f"{'='*60}") - # 生成网格 if seeds is None: - # 第一次迭代:生成完整初始网格 - logger.info("Generating initial grid...") - grid = generate_initial_grid(step_size) + # ==================== 第一次迭代:全局稀疏搜索 ==================== + logger.info("Generating initial grid (global sparse search)...") + grid = generate_initial_grid(step_size, comp_ranges) + + 
logger.info(f"Grid size: {len(grid)} comp combinations") + + # 扩展到所有 helper lipid 和 route 组合 + total_combinations = len(grid) * len(HELPER_LIPID_OPTIONS) * len(routes) + logger.info(f"Total combinations: {total_combinations}") + + # 创建 DataFrame + df = create_dataframe_from_formulations( + smiles, grid, HELPER_LIPID_OPTIONS, routes + ) + + # 预测 + logger.info("Running predictions...") + df = predict_biodist(model, df, device, batch_size) + + # 选择 top num_seeds 个种子点 + seeds = select_top_k(df, organ, num_seeds, scoring_weights) + + logger.info(f"Selected {len(seeds)} seeds for next iteration") + else: - # 后续迭代:围绕种子点精细化 - logger.info(f"Generating refined grid around {len(seeds)} seeds...") - grid = generate_refined_grid(seeds, step_size, radius=2) - - logger.info(f"Grid size: {len(grid)} comp combinations") - - # 扩展到所有 helper lipid 和 route 组合 - total_combinations = len(grid) * len(HELPER_LIPID_OPTIONS) * len(ROUTE_OPTIONS) - logger.info(f"Total combinations: {total_combinations}") - - # 创建 DataFrame - df = create_dataframe_from_formulations( - smiles, grid, HELPER_LIPID_OPTIONS, ROUTE_OPTIONS - ) - - # 预测 - logger.info("Running predictions...") - df = predict_biodist(model, df, device, batch_size) - - # 选择 top-k - seeds = select_top_k(df, organ, top_k) + # ==================== 后续迭代:层级局部搜索 ==================== + # 对每个种子点分别搜索,各自保留局部最优 + logger.info(f"Hierarchical local search around {len(seeds)} seeds...") + + all_local_best = [] + + for seed_idx, seed in enumerate(seeds): + # 为当前种子点生成邻域网格 + local_grid = generate_single_seed_grid(seed, step_size, radius=2, comp_ranges=comp_ranges) + + if len(local_grid) == 0: + # 如果没有新的网格点,保留原种子 + all_local_best.append(seed) + continue + + # 创建 DataFrame + df = create_dataframe_from_formulations( + smiles, local_grid, [seed.helper_lipid], [seed.route] + ) + + # 预测 + df = predict_biodist(model, df, device, batch_size) + + # 选择该种子邻域内的 top top_per_seed 个局部最优 + local_top = select_top_k(df, organ, top_per_seed, scoring_weights) + all_local_best.extend(local_top) + + if seed_idx == 0 or (seed_idx + 1) % 5 == 0: + logger.info(f" Seed {seed_idx + 1}/{len(seeds)}: local grid size={len(local_grid)}, " + f"local best score={_score(local_top[0]):.4f}") + + # 更新种子为所有局部最优点(去重) + seen_keys = set() + unique_local_best = [] + # 先按综合评分排序,确保保留最优的 + all_local_best_sorted = sorted(all_local_best, key=_score, reverse=True) + for f in all_local_best_sorted: + key = f.unique_key() + if key not in seen_keys: + seen_keys.add(key) + unique_local_best.append(f) + + seeds = unique_local_best + logger.info(f"Collected {len(seeds)} unique local best formulations (from {len(all_local_best)} candidates)") # 显示当前最优 - best = seeds[0] - logger.info(f"Current best Biodistribution_{organ}: {best.get_biodist(organ):.4f}") + best = max(seeds, key=_score) + logger.info(f"Current best score: {_score(best):.4f} (biodist_{organ}={best.get_biodist(organ):.4f})") logger.info(f"Best formulation: {best.to_dict()}") - return seeds + # 最终去重、按综合评分排序并返回 top_k + seeds_sorted = sorted(seeds, key=_score, reverse=True) + + # 去重:保留每个唯一配方中得分最高的(已排序,所以第一个出现的就是最高的) + seen_keys = set() + unique_results = [] + for f in seeds_sorted: + key = f.unique_key() + if key not in seen_keys: + seen_keys.add(key) + unique_results.append(f) + + logger.info(f"Final results: {len(unique_results)} unique formulations (from {len(seeds)} candidates)") + + return unique_results[:top_k] def format_results(formulations: List[Formulation], organ: str) -> pd.DataFrame: @@ -481,33 +944,54 @@ def main( help="Output CSV path (optional)" ), top_k: 
int = typer.Option(20, "--top-k", "-k", help="Number of top formulations to return"),
+    num_seeds: Optional[int] = typer.Option(None, "--num-seeds", "-n", help="Number of seed points from first iteration (default: top_k * 5)"),
+    top_per_seed: int = typer.Option(1, "--top-per-seed", "-t", help="Number of local best to keep per seed"),
+    step_sizes: Optional[str] = typer.Option(None, "--step-sizes", "-S", help="Comma-separated step sizes (e.g., '0.10,0.02,0.01')"),
     batch_size: int = typer.Option(256, "--batch-size", "-b", help="Prediction batch size"),
     device: str = typer.Option("cuda" if torch.cuda.is_available() else "cpu", "--device", "-d", help="Device"),
 ):
     """
     配方优化程序:寻找最大化目标器官 Biodistribution 的最优 LNP 配方。
 
+    Hierarchical search strategy:
+    1. First iteration: sparse global search; select the top num_seeds well-dispersed seed points
+    2. Later iterations: search each seed's neighborhood independently, keeping top_per_seed local optima per seed
+    3. This preserves search diversity and keeps results from collapsing into a single region
+
     示例:
         python -m app.optimize --smiles "CC(C)..." --organ liver
-        python -m app.optimize -s "CC(C)..." -o spleen -k 10
+        python -m app.optimize -s "CC(C)..." -o spleen -k 10 -n 30 -t 2
+        python -m app.optimize -s "CC(C)..." -o liver -S "0.10,0.05,0.02"
     """
     # 验证器官
     if organ not in AVAILABLE_ORGANS:
         logger.error(f"Invalid organ: {organ}. Available: {AVAILABLE_ORGANS}")
         raise typer.Exit(1)
 
+    # Parse step sizes
+    parsed_step_sizes = None
+    if step_sizes:
+        try:
+            parsed_step_sizes = [float(s.strip()) for s in step_sizes.split(",")]
+        except ValueError:
+            logger.error(f"Invalid step sizes format: {step_sizes}")
+            raise typer.Exit(1)
+
     # 加载模型
     logger.info(f"Loading model from {model_path}")
     device = torch.device(device)
     model = load_model(model_path, device)
 
-    # 执行优化
+    # Run the optimization (hierarchical search strategy)
     results = optimize(
         smiles=smiles,
         organ=organ,
         model=model,
         device=device,
         top_k=top_k,
+        num_seeds=num_seeds,
+        top_per_seed=top_per_seed,
+        step_sizes=parsed_step_sizes,
         batch_size=batch_size,
     )
diff --git a/data/external/all_amine_split_for_LiON/cv_0/fold_0/model_0/model.pt b/data/external/all_amine_split_for_LiON/cv_0/fold_0/model_0/model.pt
index f830a1b..b8a65fb 100644
Binary files a/data/external/all_amine_split_for_LiON/cv_0/fold_0/model_0/model.pt and b/data/external/all_amine_split_for_LiON/cv_0/fold_0/model_0/model.pt differ
diff --git a/data/external/all_amine_split_for_LiON/cv_1/fold_0/model_0/model.pt b/data/external/all_amine_split_for_LiON/cv_1/fold_0/model_0/model.pt
index ef8c294..97ac616 100644
Binary files a/data/external/all_amine_split_for_LiON/cv_1/fold_0/model_0/model.pt and b/data/external/all_amine_split_for_LiON/cv_1/fold_0/model_0/model.pt differ
diff --git a/data/external/all_amine_split_for_LiON/cv_2/fold_0/model_0/model.pt b/data/external/all_amine_split_for_LiON/cv_2/fold_0/model_0/model.pt
index f97a0c6..6d6dc04 100644
Binary files a/data/external/all_amine_split_for_LiON/cv_2/fold_0/model_0/model.pt and b/data/external/all_amine_split_for_LiON/cv_2/fold_0/model_0/model.pt differ
diff --git a/data/external/all_amine_split_for_LiON/cv_3/fold_0/model_0/model.pt b/data/external/all_amine_split_for_LiON/cv_3/fold_0/model_0/model.pt
index 3f54b88..95ae8c9 100644
Binary files a/data/external/all_amine_split_for_LiON/cv_3/fold_0/model_0/model.pt and b/data/external/all_amine_split_for_LiON/cv_3/fold_0/model_0/model.pt differ
diff --git a/data/external/all_amine_split_for_LiON/cv_4/fold_0/model_0/model.pt b/data/external/all_amine_split_for_LiON/cv_4/fold_0/model_0/model.pt
index 3bfbe0a..c6393be 100644
Binary files a/data/external/all_amine_split_for_LiON/cv_4/fold_0/model_0/model.pt and b/data/external/all_amine_split_for_LiON/cv_4/fold_0/model_0/model.pt differ
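The dedup passes in optimize() above both depend on sorting candidates best-first, so the first occurrence of each unique_key() is that formulation's highest-scoring variant. A minimal, self-contained sketch of that idiom (Candidate, key, score, and dedup_keep_best are illustrative stand-ins for Formulation, unique_key(), the composite _score, and the inline loops; none of these names come from this diff):

    from dataclasses import dataclass
    from typing import List

    @dataclass(frozen=True)
    class Candidate:
        key: str      # stand-in for Formulation.unique_key()
        score: float  # stand-in for the composite _score(f)

    def dedup_keep_best(cands: List[Candidate]) -> List[Candidate]:
        # Sort best-first; the first time a key appears, it is that key's best entry.
        seen, out = set(), []
        for c in sorted(cands, key=lambda c: c.score, reverse=True):
            if c.key not in seen:
                seen.add(c.key)
                out.append(c)
        return out

    cands = [Candidate("a", 0.3), Candidate("a", 0.9), Candidate("b", 0.5)]
    assert [c.score for c in dedup_keep_best(cands)] == [0.9, 0.5]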
diff --git a/data/raw/.~internal_deleted_uncorrected.xlsx b/data/raw/.~internal_deleted_uncorrected.xlsx
new file mode 100644
index 0000000..8b5617d
Binary files /dev/null and b/data/raw/.~internal_deleted_uncorrected.xlsx differ
diff --git a/docker-compose-gpu.yml b/docker-compose-gpu.yml
new file mode 100644
index 0000000..01cc47b
--- /dev/null
+++ b/docker-compose-gpu.yml
@@ -0,0 +1,56 @@
+services:
+  # FastAPI backend service
+  api:
+    build:
+      context: .
+      dockerfile: Dockerfile
+      target: api
+    container_name: lnp-api
+    ports:
+      - "8000:8000"  # publish the API port advertised by `make docker-serve`
+    environment:
+      - MODEL_PATH=/app/models/final/model.pt
+    volumes:
+      # Mount the model directories so models can be updated without rebuilding
+      - ./models/final:/app/models/final:ro
+      - ./models/mpnn:/app/models/mpnn:ro
+    restart: unless-stopped
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:8000/"]
+      interval: 30s
+      timeout: 10s
+      retries: 3
+      start_period: 60s
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: 1
+              capabilities: [gpu]
+
+  # Streamlit frontend service
+  streamlit:
+    build:
+      context: .
+      dockerfile: Dockerfile
+      target: streamlit
+    container_name: lnp-streamlit
+    ports:
+      - "8501:8501"
+    environment:
+      - API_URL=http://api:8000
+    depends_on:
+      api:
+        condition: service_started
+    restart: unless-stopped
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:8501/_stcore/health"]
+      interval: 30s
+      timeout: 10s
+      retries: 3
+      start_period: 30s
+
+networks:
+  default:
+    name: lnp-network
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 0000000..9185138
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,51 @@
+services:
+  # FastAPI backend service
+  api:
+    build:
+      context: .
+      dockerfile: Dockerfile
+      target: api
+    container_name: lnp-api
+    ports:
+      - "8000:8000"  # publish the API port advertised by `make docker-serve`
+    environment:
+      - MODEL_PATH=/app/models/final/model.pt
+      - DEVICE=cpu
+    volumes:
+      # Mount the model directories so models can be updated without rebuilding
+      - ./models/final:/app/models/final:ro
+      - ./models/mpnn:/app/models/mpnn:ro
+    restart: unless-stopped
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:8000/"]
+      interval: 30s
+      timeout: 10s
+      retries: 3
+      start_period: 60s
+
+  # Streamlit frontend service
+  streamlit:
+    build:
+      context: .
+      dockerfile: Dockerfile
+      target: streamlit
+    container_name: lnp-streamlit
+    ports:
+      - "8501:8501"
+    environment:
+      - API_URL=http://api:8000
+    depends_on:
+      api:
+        condition: service_started
+    restart: unless-stopped
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:8501/_stcore/health"]
+      interval: 30s
+      timeout: 10s
+      retries: 3
+      start_period: 30s
+
+networks:
+  default:
+    name: lnp-network
+
diff --git a/lnp_ml/modeling/final_train_optuna_cv.py b/lnp_ml/modeling/final_train_optuna_cv.py
new file mode 100644
index 0000000..0c039c4
--- /dev/null
+++ b/lnp_ml/modeling/final_train_optuna_cv.py
@@ -0,0 +1,612 @@
+"""
+最终训练脚本:3-fold Optuna 调参 + 全量数据训练
+
+1. 使用全量数据做 3-fold StratifiedKFold Optuna 超参搜索
+2. 固定最优超参后,使用全量数据训练(不使用 early-stopping)
+3.
使用 CosineAnnealingLR + 可选 SWA 防止过拟合 + +使用方法: + python -m lnp_ml.modeling.final_train_optuna_cv + +或通过 Makefile: + make final_optuna DEVICE=cuda +""" + +import json +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import DataLoader, Subset +from sklearn.model_selection import StratifiedKFold +from loguru import logger +import typer + +try: + import optuna + from optuna.samplers import TPESampler +except ImportError: + optuna = None + TPESampler = None + +from lnp_ml.config import MODELS_DIR, INTERIM_DATA_DIR +from lnp_ml.dataset import ( + LNPDataset, + collate_fn, + process_dataframe, + TARGET_CLASSIFICATION_PDI, + TARGET_CLASSIFICATION_EE, + TARGET_TOXIC, +) +from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN +from lnp_ml.modeling.trainer_balanced import ( + ClassWeights, + LossWeightsBalanced, + compute_class_weights_from_loader, + train_with_early_stopping, + train_fixed_epochs, +) +from lnp_ml.modeling.visualization import plot_multitask_loss_curves + +# MPNN ensemble 默认路径 +DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON" + +app = typer.Typer() + + +# ============ CompositeStrata 复合分层标签(复用 nested_cv_optuna 的逻辑) ============ + +def build_composite_strata( + df: pd.DataFrame, + min_stratum_count: int = 5, +) -> Tuple[np.ndarray, Dict]: + """ + 构建复合分层标签(toxic × PDI × EE)。 + + Args: + df: 处理后的 DataFrame + min_stratum_count: 每个 stratum 最少样本数,低于此值合并为 RARE + + Returns: + (strata_array, strata_info) + """ + n = len(df) + strata_labels = [] + + for i in range(n): + # Toxic stratum + if TARGET_TOXIC in df.columns: + toxic_val = df[TARGET_TOXIC].iloc[i] + if pd.notna(toxic_val) and toxic_val >= 0: + toxic_str = str(int(toxic_val)) + else: + toxic_str = "NA" + else: + toxic_str = "NA" + + # PDI stratum + if all(col in df.columns for col in TARGET_CLASSIFICATION_PDI): + pdi_vals = df[TARGET_CLASSIFICATION_PDI].iloc[i].values + if pdi_vals.sum() > 0: + pdi_str = str(int(np.argmax(pdi_vals))) + else: + pdi_str = "NA" + else: + pdi_str = "NA" + + # EE stratum + if all(col in df.columns for col in TARGET_CLASSIFICATION_EE): + ee_vals = df[TARGET_CLASSIFICATION_EE].iloc[i].values + if ee_vals.sum() > 0: + ee_str = str(int(np.argmax(ee_vals))) + else: + ee_str = "NA" + else: + ee_str = "NA" + + strata_labels.append(f"T{toxic_str}|P{pdi_str}|E{ee_str}") + + # 统计各 stratum 的样本数 + unique_strata, counts = np.unique(strata_labels, return_counts=True) + strata_counts = dict(zip(unique_strata, counts)) + + # 将稀疏 strata 合并为 RARE + rare_strata = [s for s, c in strata_counts.items() if c < min_stratum_count] + + final_labels = [] + for label in strata_labels: + if label in rare_strata: + final_labels.append("RARE") + else: + final_labels.append(label) + + # 编码为整数 + unique_final, encoded = np.unique(final_labels, return_inverse=True) + + strata_info = { + "original_strata_counts": strata_counts, + "rare_strata": rare_strata, + "final_strata": list(unique_final), + "final_strata_counts": dict(zip(*np.unique(final_labels, return_counts=True))), + "n_rare_merged": sum(strata_counts[s] for s in rare_strata) if rare_strata else 0, + } + + logger.info(f"Built composite strata: {len(unique_final)} unique strata") + logger.info(f" Rare strata merged: {len(rare_strata)} types, {strata_info['n_rare_merged']} samples") + + return encoded.astype(np.int64), strata_info + + +# ============ 模型创建 ============ + +def find_mpnn_ensemble_paths(base_dir: 
Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]: + """自动查找 MPNN ensemble 的 model.pt 文件。""" + model_paths = sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt")) + if not model_paths: + raise FileNotFoundError(f"No model.pt files found in {base_dir}") + return [str(p) for p in model_paths] + + +def create_model( + d_model: int = 256, + num_heads: int = 8, + n_attn_layers: int = 4, + fusion_strategy: str = "attention", + head_hidden_dim: int = 128, + dropout: float = 0.1, + use_mpnn: bool = False, + mpnn_device: str = "cpu", +) -> Union[LNPModel, LNPModelWithoutMPNN]: + """创建模型""" + if use_mpnn: + ensemble_paths = find_mpnn_ensemble_paths() + return LNPModel( + d_model=d_model, + num_heads=num_heads, + n_attn_layers=n_attn_layers, + fusion_strategy=fusion_strategy, + head_hidden_dim=head_hidden_dim, + dropout=dropout, + mpnn_ensemble_paths=ensemble_paths, + mpnn_device=mpnn_device, + ) + else: + return LNPModelWithoutMPNN( + d_model=d_model, + num_heads=num_heads, + n_attn_layers=n_attn_layers, + fusion_strategy=fusion_strategy, + head_hidden_dim=head_hidden_dim, + dropout=dropout, + ) + + +# ============ 预训练权重加载 ============ + +def load_pretrain_weights_to_model( + model: Union[LNPModel, LNPModelWithoutMPNN], + pretrain_state_dict: Dict, + d_model: int, + pretrain_config: Dict, + load_delivery_head: bool = True, +) -> bool: + """ + 加载预训练权重到模型。 + + Returns: + 是否成功加载 + """ + if pretrain_config.get("d_model") != d_model: + logger.warning( + f"d_model mismatch: pretrain={pretrain_config.get('d_model')}, " + f"current={d_model}. Skipping pretrain loading." + ) + return False + + model.load_pretrain_weights( + pretrain_state_dict=pretrain_state_dict, + load_delivery_head=load_delivery_head, + strict=False, + ) + return True + + +# ============ 3-fold Optuna 调参 ============ + +def run_optuna_cv( + full_dataset: LNPDataset, + strata: np.ndarray, + device: torch.device, + n_trials: int = 20, + epochs_per_trial: int = 30, + patience: int = 10, + batch_size: int = 32, + n_folds: int = 3, + use_mpnn: bool = False, + seed: int = 42, + study_path: Optional[Path] = None, + pretrain_state_dict: Optional[Dict] = None, + pretrain_config: Optional[Dict] = None, + load_delivery_head: bool = True, +) -> Tuple[Dict, int, optuna.Study]: + """ + 使用全量数据做 3-fold CV Optuna 超参搜索。 + + Args: + full_dataset: 完整数据集 + strata: 每个样本的分层标签 + device: 设备 + n_trials: Optuna 试验数 + epochs_per_trial: 每个试验的最大 epoch + patience: 早停耐心值 + batch_size: 批次大小 + n_folds: 折数 + use_mpnn: 是否使用 MPNN + seed: 随机种子 + study_path: 可选的 study 持久化路径 + pretrain_state_dict: 预训练权重 + pretrain_config: 预训练配置 + load_delivery_head: 是否加载 delivery head 权重 + + Returns: + (best_params, epoch_mean, study) + """ + if optuna is None: + raise ImportError("Optuna not installed. 
Run: pip install optuna") + + n_samples = len(full_dataset) + indices = np.arange(n_samples) + + def objective(trial: optuna.Trial) -> float: + # 采样超参数 + d_model = trial.suggest_categorical("d_model", [128, 256, 512]) + num_heads = trial.suggest_categorical("num_heads", [4, 8]) + n_attn_layers = trial.suggest_int("n_attn_layers", 2, 6) + fusion_strategy = trial.suggest_categorical( + "fusion_strategy", ["attention", "avg", "max"] + ) + head_hidden_dim = trial.suggest_categorical("head_hidden_dim", [64, 128, 256]) + dropout = trial.suggest_float("dropout", 0.05, 0.3) + lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True) + weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True) + + # 3-fold CV + cv = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed) + + fold_val_losses = [] + fold_best_epochs = [] + + for fold, (train_idx, val_idx) in enumerate(cv.split(indices, strata)): + # 创建 DataLoader + train_subset = Subset(full_dataset, train_idx.tolist()) + val_subset = Subset(full_dataset, val_idx.tolist()) + + train_loader = DataLoader( + train_subset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn + ) + val_loader = DataLoader( + val_subset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn + ) + + # 计算类权重 + class_weights = compute_class_weights_from_loader(train_loader) + + # 创建模型 + model = create_model( + d_model=d_model, + num_heads=num_heads, + n_attn_layers=n_attn_layers, + fusion_strategy=fusion_strategy, + head_hidden_dim=head_hidden_dim, + dropout=dropout, + use_mpnn=use_mpnn, + mpnn_device=device.type, + ) + + # 加载预训练权重 + if pretrain_state_dict is not None and pretrain_config is not None: + load_pretrain_weights_to_model( + model, pretrain_state_dict, d_model, pretrain_config, load_delivery_head + ) + + # 训练(带早停) + result = train_with_early_stopping( + model=model, + train_loader=train_loader, + val_loader=val_loader, + device=device, + lr=lr, + weight_decay=weight_decay, + epochs=epochs_per_trial, + patience=patience, + class_weights=class_weights, + ) + + fold_val_losses.append(result["best_val_loss"]) + fold_best_epochs.append(result["best_epoch"]) + + # 记录 epoch_mean 到 trial + epoch_mean = int(round(np.mean(fold_best_epochs))) + trial.set_user_attr("epoch_mean", epoch_mean) + trial.set_user_attr("fold_best_epochs", fold_best_epochs) + trial.set_user_attr("fold_val_losses", fold_val_losses) + + return np.mean(fold_val_losses) + + # 创建 study + storage = None + if study_path is not None: + storage = f"sqlite:///{study_path}" + + study = optuna.create_study( + direction="minimize", + sampler=TPESampler(seed=seed), + storage=storage, + study_name="final_optuna_cv", + load_if_exists=True, + ) + + study.optimize(objective, n_trials=n_trials, show_progress_bar=True) + + best_params = study.best_trial.params + epoch_mean = study.best_trial.user_attrs.get("epoch_mean", epochs_per_trial) + + logger.info(f"Best trial: {study.best_trial.number}") + logger.info(f"Best val_loss: {study.best_trial.value:.4f}") + logger.info(f"Best params: {best_params}") + logger.info(f"Epoch mean from best trial: {epoch_mean}") + + return best_params, epoch_mean, study + + +# ============ 主流程 ============ + +@app.command() +def main( + input_path: Path = INTERIM_DATA_DIR / "internal_corrected.csv", + output_dir: Path = MODELS_DIR / "final_optuna", + # CV 参数 + n_folds: int = 3, + min_stratum_count: int = 5, + seed: int = 42, + # Optuna 参数 + n_trials: int = 20, + epochs_per_trial: int = 30, + patience: int = 10, + # 训练参数 + batch_size: int = 32, + # 最终训练参数 + 
use_swa: bool = False, + swa_start_ratio: float = 0.75, + # 预训练权重 + init_from_pretrain: Optional[Path] = None, + load_delivery_head: bool = True, + # MPNN + use_mpnn: bool = False, + # 设备 + device: str = "cuda" if torch.cuda.is_available() else "cpu", +): + """ + 最终训练:3-fold Optuna 调参 + 全量数据训练。 + + 1. 使用全量数据做 3-fold StratifiedKFold Optuna 超参搜索 + 2. 固定最优超参后,使用全量数据训练(不使用 early-stopping) + 3. epoch 数使用 3-fold CV 中 best trial 的 epoch_mean + 4. 使用 CosineAnnealingLR,可选 SWA + + 使用 --init-from-pretrain 从预训练 checkpoint 初始化模型权重。 + """ + if optuna is None: + logger.error("Optuna not installed. Run: pip install optuna") + raise typer.Exit(1) + + logger.info(f"Using device: {device}") + device = torch.device(device) + + # 加载预训练权重(如果指定) + pretrain_state_dict = None + pretrain_config = None + if init_from_pretrain is not None: + if init_from_pretrain.exists(): + logger.info(f"Loading pretrain weights from {init_from_pretrain}") + checkpoint = torch.load(init_from_pretrain, map_location="cpu") + pretrain_state_dict = checkpoint["model_state_dict"] + pretrain_config = checkpoint.get("config", {}) + logger.success(f"Loaded pretrain checkpoint (d_model={pretrain_config.get('d_model')})") + else: + logger.warning(f"Pretrain checkpoint not found: {init_from_pretrain}, skipping") + + # 创建输出目录 + output_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Output directory: {output_dir}") + + # 加载数据 + logger.info(f"Loading data from {input_path}") + df = pd.read_csv(input_path) + logger.info(f"Loaded {len(df)} samples") + + # 处理数据 + logger.info("Processing dataframe...") + df = process_dataframe(df) + + # 构建复合分层标签 + logger.info("Building composite strata...") + strata, strata_info = build_composite_strata(df, min_stratum_count) + + # 保存 strata 信息 + with open(output_dir / "strata_info.json", "w") as f: + json.dump(strata_info, f, indent=2, default=str) + + # 创建完整数据集 + full_dataset = LNPDataset(df) + + # 运行 Optuna 调参 + logger.info(f"\nRunning {n_folds}-fold Optuna with {n_trials} trials...") + study_path = output_dir / "optuna_study.sqlite3" + + best_params, epoch_mean, study = run_optuna_cv( + full_dataset=full_dataset, + strata=strata, + device=device, + n_trials=n_trials, + epochs_per_trial=epochs_per_trial, + patience=patience, + batch_size=batch_size, + n_folds=n_folds, + use_mpnn=use_mpnn, + seed=seed, + study_path=study_path, + pretrain_state_dict=pretrain_state_dict, + pretrain_config=pretrain_config, + load_delivery_head=load_delivery_head, + ) + + # 保存最佳参数 + with open(output_dir / "best_params.json", "w") as f: + json.dump(best_params, f, indent=2) + + with open(output_dir / "epoch_mean.json", "w") as f: + json.dump({"epoch_mean": epoch_mean}, f) + + # 保存 Optuna 试验历史 + trials_history = [] + for trial in study.trials: + trials_history.append({ + "number": trial.number, + "value": trial.value, + "params": trial.params, + "user_attrs": trial.user_attrs, + "state": str(trial.state), + }) + + with open(output_dir / "optuna_trials.json", "w") as f: + json.dump(trials_history, f, indent=2) + + # 全量数据训练 + logger.info(f"\n{'='*60}") + logger.info("FINAL TRAINING ON FULL DATA") + logger.info(f"{'='*60}") + logger.info(f"Using best params with epochs={epoch_mean}") + logger.info(f"SWA: {use_swa}") + + # 创建全量 DataLoader + full_loader = DataLoader( + full_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn + ) + + # 计算类权重 + class_weights = compute_class_weights_from_loader(full_loader) + + # 保存类权重信息 + class_weights_info = { + "pdi": class_weights.pdi.tolist() if class_weights.pdi is not None else 
None,
+        "ee": class_weights.ee.tolist() if class_weights.ee is not None else None,
+        "toxic": class_weights.toxic.tolist() if class_weights.toxic is not None else None,
+    }
+    with open(output_dir / "class_weights.json", "w") as f:
+        json.dump(class_weights_info, f, indent=2)
+
+    # Build the model
+    model = create_model(
+        d_model=best_params["d_model"],
+        num_heads=best_params["num_heads"],
+        n_attn_layers=best_params["n_attn_layers"],
+        fusion_strategy=best_params["fusion_strategy"],
+        head_hidden_dim=best_params["head_hidden_dim"],
+        dropout=best_params["dropout"],
+        use_mpnn=use_mpnn,
+        mpnn_device=device.type,
+    )
+
+    # Load pretrained weights
+    if pretrain_state_dict is not None and pretrain_config is not None:
+        loaded = load_pretrain_weights_to_model(
+            model, pretrain_state_dict, best_params["d_model"],
+            pretrain_config, load_delivery_head
+        )
+        if loaded:
+            logger.info("Loaded pretrain weights for final training")
+
+    # Log model info
+    n_params_total = sum(p.numel() for p in model.parameters())
+    n_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    logger.info(f"Model parameters: {n_params_total:,} total, {n_params_trainable:,} trainable")
+
+    # Train for a fixed number of epochs (no early stopping)
+    swa_start = int(epoch_mean * swa_start_ratio) if use_swa else None
+
+    train_result = train_fixed_epochs(
+        model=model,
+        train_loader=full_loader,
+        val_loader=None,  # full-data training, no validation set
+        device=device,
+        lr=best_params["lr"],
+        weight_decay=best_params["weight_decay"],
+        epochs=epoch_mean,
+        class_weights=class_weights,
+        use_cosine_annealing=True,
+        use_swa=use_swa,
+        swa_start_epoch=swa_start,
+    )
+
+    # Load the final weights
+    model.load_state_dict(train_result["final_state"])
+
+    # Save the model
+    config = {
+        "d_model": best_params["d_model"],
+        "num_heads": best_params["num_heads"],
+        "n_attn_layers": best_params["n_attn_layers"],
+        "fusion_strategy": best_params["fusion_strategy"],
+        "head_hidden_dim": best_params["head_hidden_dim"],
+        "dropout": best_params["dropout"],
+        "use_mpnn": use_mpnn,
+    }
+
+    torch.save({
+        "model_state_dict": train_result["final_state"],
+        "config": config,
+        "best_params": best_params,
+        "epoch_mean": epoch_mean,
+        "use_swa": use_swa,
+    }, output_dir / "model.pt")
+
+    logger.success(f"Saved model to {output_dir / 'model.pt'}")
+
+    # Save training history
+    with open(output_dir / "history.json", "w") as f:
+        json.dump(train_result["history"], f, indent=2)
+
+    # Plot loss curves
+    if train_result["history"]["train"]:
+        try:
+            plot_multitask_loss_curves(
+                history=train_result["history"],
+                output_path=output_dir / "loss_curves.png",
+                title="Final Training Loss Curves",
+            )
+            logger.info(f"Saved loss curves to {output_dir / 'loss_curves.png'}")
+        except Exception as e:
+            logger.warning(f"Failed to plot loss curves: {e}")
+
+    # Log final summary
+    logger.info(f"\n{'='*60}")
+    logger.info("FINAL TRAINING COMPLETE")
+    logger.info(f"{'='*60}")
+    logger.info(f"Best params: {best_params}")
+    logger.info(f"Epochs trained: {epoch_mean}")
+    logger.info(f"Output directory: {output_dir}")
+
+    # Usage hints
+    logger.info("\n[How to use the trained model]")
+    logger.info(f" 1. Set environment variable: MODEL_PATH={output_dir / 'model.pt'}")
+    logger.info(f" 2. Or copy to default location: cp {output_dir}/model.pt models/final/model.pt")
+    logger.info(" 3.
Start API: make api") + + +if __name__ == "__main__": + app() + diff --git a/lnp_ml/modeling/nested_cv_optuna.py b/lnp_ml/modeling/nested_cv_optuna.py new file mode 100644 index 0000000..e96054f --- /dev/null +++ b/lnp_ml/modeling/nested_cv_optuna.py @@ -0,0 +1,774 @@ +""" +嵌套交叉验证 + Optuna 超参调优 + +外层 5-fold StratifiedKFold(20% test / 80% train) +内层 3-fold StratifiedKFold(在 80% 上做 Optuna 超参搜索) + +使用方法: + python -m lnp_ml.modeling.nested_cv_optuna + +或通过 Makefile: + make nested_cv_tune DEVICE=cuda +""" + +import json +import math +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import DataLoader, Subset +from sklearn.model_selection import StratifiedKFold +from loguru import logger +import typer + +try: + import optuna + from optuna.samplers import TPESampler +except ImportError: + optuna = None + TPESampler = None + +from lnp_ml.config import MODELS_DIR, INTERIM_DATA_DIR +from lnp_ml.dataset import ( + LNPDataset, + collate_fn, + process_dataframe, + TARGET_CLASSIFICATION_PDI, + TARGET_CLASSIFICATION_EE, + TARGET_TOXIC, +) +from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN +from lnp_ml.modeling.trainer_balanced import ( + ClassWeights, + LossWeightsBalanced, + compute_class_weights_from_loader, + train_with_early_stopping, + train_fixed_epochs, + validate_balanced, +) + +# MPNN ensemble 默认路径 +DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON" + +app = typer.Typer() + + +# ============ CompositeStrata 复合分层标签 ============ + +def build_composite_strata( + df: pd.DataFrame, + min_stratum_count: int = 5, +) -> Tuple[np.ndarray, Dict]: + """ + 构建复合分层标签(toxic × PDI × EE)。 + + Args: + df: 处理后的 DataFrame + min_stratum_count: 每个 stratum 最少样本数,低于此值合并为 RARE + + Returns: + (strata_array, strata_info) + - strata_array: 每个样本的 stratum 编码(整数) + - strata_info: 统计信息 + """ + n = len(df) + strata_labels = [] + + for i in range(n): + # Toxic stratum + if TARGET_TOXIC in df.columns: + toxic_val = df[TARGET_TOXIC].iloc[i] + if pd.notna(toxic_val) and toxic_val >= 0: + toxic_str = str(int(toxic_val)) + else: + toxic_str = "NA" + else: + toxic_str = "NA" + + # PDI stratum + if all(col in df.columns for col in TARGET_CLASSIFICATION_PDI): + pdi_vals = df[TARGET_CLASSIFICATION_PDI].iloc[i].values + if pdi_vals.sum() > 0: + pdi_str = str(int(np.argmax(pdi_vals))) + else: + pdi_str = "NA" + else: + pdi_str = "NA" + + # EE stratum + if all(col in df.columns for col in TARGET_CLASSIFICATION_EE): + ee_vals = df[TARGET_CLASSIFICATION_EE].iloc[i].values + if ee_vals.sum() > 0: + ee_str = str(int(np.argmax(ee_vals))) + else: + ee_str = "NA" + else: + ee_str = "NA" + + strata_labels.append(f"T{toxic_str}|P{pdi_str}|E{ee_str}") + + # 统计各 stratum 的样本数 + unique_strata, counts = np.unique(strata_labels, return_counts=True) + strata_counts = dict(zip(unique_strata, counts)) + + # 将稀疏 strata 合并为 RARE + rare_strata = [s for s, c in strata_counts.items() if c < min_stratum_count] + + final_labels = [] + for label in strata_labels: + if label in rare_strata: + final_labels.append("RARE") + else: + final_labels.append(label) + + # 编码为整数 + unique_final, encoded = np.unique(final_labels, return_inverse=True) + + strata_info = { + "original_strata_counts": strata_counts, + "rare_strata": rare_strata, + "final_strata": list(unique_final), + "final_strata_counts": dict(zip(*np.unique(final_labels, return_counts=True))), + "n_rare_merged": sum(strata_counts[s] 
for s in rare_strata) if rare_strata else 0, + } + + logger.info(f"Built composite strata: {len(unique_final)} unique strata") + logger.info(f" Rare strata merged: {len(rare_strata)} types, {strata_info['n_rare_merged']} samples") + + return encoded.astype(np.int64), strata_info + + +# ============ 模型创建 ============ + +def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]: + """自动查找 MPNN ensemble 的 model.pt 文件。""" + model_paths = sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt")) + if not model_paths: + raise FileNotFoundError(f"No model.pt files found in {base_dir}") + return [str(p) for p in model_paths] + + +def create_model( + d_model: int = 256, + num_heads: int = 8, + n_attn_layers: int = 4, + fusion_strategy: str = "attention", + head_hidden_dim: int = 128, + dropout: float = 0.1, + use_mpnn: bool = False, + mpnn_device: str = "cpu", +) -> Union[LNPModel, LNPModelWithoutMPNN]: + """创建模型""" + if use_mpnn: + ensemble_paths = find_mpnn_ensemble_paths() + return LNPModel( + d_model=d_model, + num_heads=num_heads, + n_attn_layers=n_attn_layers, + fusion_strategy=fusion_strategy, + head_hidden_dim=head_hidden_dim, + dropout=dropout, + mpnn_ensemble_paths=ensemble_paths, + mpnn_device=mpnn_device, + ) + else: + return LNPModelWithoutMPNN( + d_model=d_model, + num_heads=num_heads, + n_attn_layers=n_attn_layers, + fusion_strategy=fusion_strategy, + head_hidden_dim=head_hidden_dim, + dropout=dropout, + ) + + +# ============ 评估指标 ============ + +def evaluate_on_test( + model: torch.nn.Module, + test_loader: DataLoader, + device: torch.device, +) -> Dict: + """在测试集上评估模型""" + from scipy.special import rel_entr + from sklearn.metrics import ( + mean_squared_error, + mean_absolute_error, + r2_score, + accuracy_score, + precision_score, + recall_score, + f1_score, + ) + + model.eval() + + preds = { + "size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [], "biodist": [] + } + targets = { + "size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [], "biodist": [] + } + + with torch.no_grad(): + for batch in test_loader: + smiles = batch["smiles"] + tabular = {k: v.to(device) for k, v in batch["tabular"].items()} + tgts = batch["targets"] + masks = batch["mask"] + + outputs = model(smiles, tabular) + + # 收集预测和真实值 + for task in ["size", "delivery"]: + if task in masks and masks[task].any(): + m = masks[task] + key = task if task == "size" else "delivery" + preds[task].extend(outputs[key].squeeze(-1)[m].cpu().numpy().tolist()) + targets[task].extend(tgts[key][m].cpu().numpy().tolist()) + + for task in ["pdi", "ee", "toxic"]: + if task in masks and masks[task].any(): + m = masks[task] + preds[task].extend(outputs[task][m].argmax(dim=-1).cpu().numpy().tolist()) + targets[task].extend(tgts[task][m].cpu().numpy().tolist()) + + if "biodist" in masks and masks["biodist"].any(): + m = masks["biodist"] + preds["biodist"].extend(outputs["biodist"][m].cpu().numpy().tolist()) + targets["biodist"].extend(tgts["biodist"][m].cpu().numpy().tolist()) + + # 计算指标 + results = {} + + # 回归任务 + for task in ["size", "delivery"]: + if preds[task]: + p = np.array(preds[task]) + t = np.array(targets[task]) + results[task] = { + "n_samples": len(p), + "mse": float(mean_squared_error(t, p)), + "rmse": float(np.sqrt(mean_squared_error(t, p))), + "mae": float(mean_absolute_error(t, p)), + "r2": float(r2_score(t, p)), + } + + # 分类任务 + for task in ["pdi", "ee", "toxic"]: + if preds[task]: + p = np.array(preds[task]) + t = np.array(targets[task]) + results[task] = { + "n_samples": len(p), + 
"accuracy": float(accuracy_score(t, p)), + "precision": float(precision_score(t, p, average="macro", zero_division=0)), + "recall": float(recall_score(t, p, average="macro", zero_division=0)), + "f1": float(f1_score(t, p, average="macro", zero_division=0)), + } + + # 分布任务 + if preds["biodist"]: + p = np.array(preds["biodist"]) + t = np.array(targets["biodist"]) + + def kl_divergence(p_arr, q_arr, eps=1e-10): + p_arr = np.clip(p_arr, eps, 1.0) + q_arr = np.clip(q_arr, eps, 1.0) + return float(np.sum(rel_entr(p_arr, q_arr), axis=-1).mean()) + + def js_divergence(p_arr, q_arr, eps=1e-10): + p_arr = np.clip(p_arr, eps, 1.0) + q_arr = np.clip(q_arr, eps, 1.0) + m = 0.5 * (p_arr + q_arr) + return float(0.5 * (np.sum(rel_entr(p_arr, m), axis=-1) + np.sum(rel_entr(q_arr, m), axis=-1)).mean()) + + results["biodist"] = { + "n_samples": len(p), + "kl_divergence": kl_divergence(t, p), + "js_divergence": js_divergence(t, p), + } + + return results + + +# ============ 预训练权重加载 ============ + +def load_pretrain_weights_to_model( + model: Union[LNPModel, LNPModelWithoutMPNN], + pretrain_state_dict: Dict, + d_model: int, + pretrain_config: Dict, + load_delivery_head: bool = True, +) -> bool: + """ + 加载预训练权重到模型。 + + Returns: + 是否成功加载 + """ + if pretrain_config.get("d_model") != d_model: + logger.warning( + f"d_model mismatch: pretrain={pretrain_config.get('d_model')}, " + f"current={d_model}. Skipping pretrain loading." + ) + return False + + model.load_pretrain_weights( + pretrain_state_dict=pretrain_state_dict, + load_delivery_head=load_delivery_head, + strict=False, + ) + return True + + +# ============ 内层 Optuna 调参 ============ + +def run_inner_optuna( + full_dataset: LNPDataset, + inner_train_indices: np.ndarray, + strata: np.ndarray, + device: torch.device, + n_trials: int = 20, + epochs_per_trial: int = 30, + patience: int = 10, + batch_size: int = 32, + n_inner_folds: int = 3, + use_mpnn: bool = False, + seed: int = 42, + study_path: Optional[Path] = None, + pretrain_state_dict: Optional[Dict] = None, + pretrain_config: Optional[Dict] = None, + load_delivery_head: bool = True, +) -> Tuple[Dict, int, optuna.Study]: + """ + 在内层数据上运行 Optuna 超参搜索。 + + Args: + full_dataset: 完整数据集 + inner_train_indices: 内层训练数据的索引(相对于 full_dataset) + strata: 每个样本的分层标签 + device: 设备 + n_trials: Optuna 试验数 + epochs_per_trial: 每个试验的最大 epoch + patience: 早停耐心值 + batch_size: 批次大小 + n_inner_folds: 内层折数 + use_mpnn: 是否使用 MPNN + seed: 随机种子 + study_path: 可选的 study 持久化路径 + pretrain_state_dict: 预训练权重 + pretrain_config: 预训练配置 + load_delivery_head: 是否加载 delivery head 权重 + + Returns: + (best_params, epoch_mean, study) + """ + if optuna is None: + raise ImportError("Optuna not installed. 
Run: pip install optuna") + + inner_strata = strata[inner_train_indices] + + def objective(trial: optuna.Trial) -> float: + # 采样超参数 + d_model = trial.suggest_categorical("d_model", [128, 256, 512]) + num_heads = trial.suggest_categorical("num_heads", [4, 8]) + n_attn_layers = trial.suggest_int("n_attn_layers", 2, 6) + fusion_strategy = trial.suggest_categorical( + "fusion_strategy", ["attention", "avg", "max"] + ) + head_hidden_dim = trial.suggest_categorical("head_hidden_dim", [64, 128, 256]) + dropout = trial.suggest_float("dropout", 0.05, 0.3) + lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True) + weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True) + + # 内层 3-fold CV + inner_cv = StratifiedKFold( + n_splits=n_inner_folds, shuffle=True, random_state=seed + ) + + fold_val_losses = [] + fold_best_epochs = [] + + for inner_fold, (inner_train_idx, inner_val_idx) in enumerate( + inner_cv.split(inner_train_indices, inner_strata) + ): + # 获取实际的数据集索引 + actual_train_idx = inner_train_indices[inner_train_idx] + actual_val_idx = inner_train_indices[inner_val_idx] + + # 创建 DataLoader + train_subset = Subset(full_dataset, actual_train_idx.tolist()) + val_subset = Subset(full_dataset, actual_val_idx.tolist()) + + train_loader = DataLoader( + train_subset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn + ) + val_loader = DataLoader( + val_subset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn + ) + + # 计算类权重 + class_weights = compute_class_weights_from_loader(train_loader) + + # 创建模型 + model = create_model( + d_model=d_model, + num_heads=num_heads, + n_attn_layers=n_attn_layers, + fusion_strategy=fusion_strategy, + head_hidden_dim=head_hidden_dim, + dropout=dropout, + use_mpnn=use_mpnn, + mpnn_device=device.type, + ) + + # 加载预训练权重 + if pretrain_state_dict is not None and pretrain_config is not None: + load_pretrain_weights_to_model( + model, pretrain_state_dict, d_model, pretrain_config, load_delivery_head + ) + + # 训练(带早停) + result = train_with_early_stopping( + model=model, + train_loader=train_loader, + val_loader=val_loader, + device=device, + lr=lr, + weight_decay=weight_decay, + epochs=epochs_per_trial, + patience=patience, + class_weights=class_weights, + ) + + fold_val_losses.append(result["best_val_loss"]) + fold_best_epochs.append(result["best_epoch"]) + + # 记录 epoch_mean 到 trial + epoch_mean = int(round(np.mean(fold_best_epochs))) + trial.set_user_attr("epoch_mean", epoch_mean) + trial.set_user_attr("fold_best_epochs", fold_best_epochs) + + return np.mean(fold_val_losses) + + # 创建 study + storage = None + if study_path is not None: + storage = f"sqlite:///{study_path}" + + study = optuna.create_study( + direction="minimize", + sampler=TPESampler(seed=seed), + storage=storage, + study_name="inner_optuna", + load_if_exists=True, + ) + + study.optimize(objective, n_trials=n_trials, show_progress_bar=True) + + best_params = study.best_trial.params + epoch_mean = study.best_trial.user_attrs.get("epoch_mean", epochs_per_trial) + + logger.info(f"Best trial: {study.best_trial.number}") + logger.info(f"Best val_loss: {study.best_trial.value:.4f}") + logger.info(f"Best params: {best_params}") + logger.info(f"Epoch mean: {epoch_mean}") + + return best_params, epoch_mean, study + + +# ============ 主流程 ============ + +@app.command() +def main( + input_path: Path = INTERIM_DATA_DIR / "internal_corrected.csv", + output_dir: Path = MODELS_DIR / "nested_cv", + # CV 参数 + n_outer_folds: int = 5, + n_inner_folds: int = 3, + min_stratum_count: int = 5, + seed: int 
= 42, + # Optuna 参数 + n_trials: int = 20, + epochs_per_trial: int = 30, + inner_patience: int = 10, + # 训练参数 + batch_size: int = 32, + # 预训练权重 + init_from_pretrain: Optional[Path] = None, + load_delivery_head: bool = True, + # MPNN + use_mpnn: bool = False, + # 设备 + device: str = "cuda" if torch.cuda.is_available() else "cpu", +): + """ + 嵌套交叉验证 + Optuna 超参调优。 + + 外层 5-fold(20% test / 80% train),内层 3-fold Optuna 调参。 + 外层训练不使用 early-stopping,epoch 数使用内层 best trial 的 epoch_mean。 + + 使用 --init-from-pretrain 从预训练 checkpoint 初始化模型权重。 + """ + if optuna is None: + logger.error("Optuna not installed. Run: pip install optuna") + raise typer.Exit(1) + + logger.info(f"Using device: {device}") + device = torch.device(device) + + # 加载预训练权重(如果指定) + pretrain_state_dict = None + pretrain_config = None + if init_from_pretrain is not None: + if init_from_pretrain.exists(): + logger.info(f"Loading pretrain weights from {init_from_pretrain}") + checkpoint = torch.load(init_from_pretrain, map_location="cpu") + pretrain_state_dict = checkpoint["model_state_dict"] + pretrain_config = checkpoint.get("config", {}) + logger.success(f"Loaded pretrain checkpoint (d_model={pretrain_config.get('d_model')})") + else: + logger.warning(f"Pretrain checkpoint not found: {init_from_pretrain}, skipping") + + # 创建输出目录(带时间戳) + run_name = datetime.now().strftime("%Y%m%d_%H%M%S") + run_dir = output_dir / run_name + run_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Output directory: {run_dir}") + + # 加载数据 + logger.info(f"Loading data from {input_path}") + df = pd.read_csv(input_path) + logger.info(f"Loaded {len(df)} samples") + + # 处理数据 + logger.info("Processing dataframe...") + df = process_dataframe(df) + + # 构建复合分层标签 + logger.info("Building composite strata...") + strata, strata_info = build_composite_strata(df, min_stratum_count) + + # 保存 strata 信息 + with open(run_dir / "strata_info.json", "w") as f: + json.dump(strata_info, f, indent=2, default=str) + + # 创建完整数据集 + full_dataset = LNPDataset(df) + n_samples = len(full_dataset) + + # 外层 CV + outer_cv = StratifiedKFold( + n_splits=n_outer_folds, shuffle=True, random_state=seed + ) + + outer_results = [] + + for outer_fold, (outer_train_idx, outer_test_idx) in enumerate( + outer_cv.split(np.arange(n_samples), strata) + ): + logger.info(f"\n{'='*60}") + logger.info(f"OUTER FOLD {outer_fold}") + logger.info(f"{'='*60}") + logger.info(f"Train: {len(outer_train_idx)}, Test: {len(outer_test_idx)}") + + fold_dir = run_dir / f"outer_fold_{outer_fold}" + fold_dir.mkdir(parents=True, exist_ok=True) + + # 保存 split indices + splits = { + "outer_train_idx": outer_train_idx.tolist(), + "outer_test_idx": outer_test_idx.tolist(), + } + with open(fold_dir / "splits.json", "w") as f: + json.dump(splits, f) + + # 内层 Optuna 调参 + logger.info(f"\nRunning inner Optuna with {n_trials} trials...") + study_path = fold_dir / "optuna_study.sqlite3" + + best_params, epoch_mean, study = run_inner_optuna( + full_dataset=full_dataset, + inner_train_indices=outer_train_idx, + strata=strata, + device=device, + n_trials=n_trials, + epochs_per_trial=epochs_per_trial, + patience=inner_patience, + batch_size=batch_size, + n_inner_folds=n_inner_folds, + use_mpnn=use_mpnn, + seed=seed + outer_fold, + study_path=study_path, + pretrain_state_dict=pretrain_state_dict, + pretrain_config=pretrain_config, + load_delivery_head=load_delivery_head, + ) + + # 保存最佳参数 + with open(fold_dir / "best_params.json", "w") as f: + json.dump(best_params, f, indent=2) + + with open(fold_dir / "epoch_mean.json", "w") as f: + 
json.dump({"epoch_mean": epoch_mean}, f) + + # 外层训练(使用最优超参,固定 epoch 数,不 early-stop) + logger.info(f"\nTraining outer fold with best params, epochs={epoch_mean}...") + + # 创建 DataLoader + train_subset = Subset(full_dataset, outer_train_idx.tolist()) + test_subset = Subset(full_dataset, outer_test_idx.tolist()) + + train_loader = DataLoader( + train_subset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn + ) + test_loader = DataLoader( + test_subset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn + ) + + # 计算类权重 + class_weights = compute_class_weights_from_loader(train_loader) + + # 创建模型 + model = create_model( + d_model=best_params["d_model"], + num_heads=best_params["num_heads"], + n_attn_layers=best_params["n_attn_layers"], + fusion_strategy=best_params["fusion_strategy"], + head_hidden_dim=best_params["head_hidden_dim"], + dropout=best_params["dropout"], + use_mpnn=use_mpnn, + mpnn_device=device.type, + ) + + # 加载预训练权重 + if pretrain_state_dict is not None and pretrain_config is not None: + loaded = load_pretrain_weights_to_model( + model, pretrain_state_dict, best_params["d_model"], + pretrain_config, load_delivery_head + ) + if loaded: + logger.info(f"Loaded pretrain weights for outer fold {outer_fold}") + + # 训练(固定 epoch,不 early-stop) + train_result = train_fixed_epochs( + model=model, + train_loader=train_loader, + val_loader=None, # 外层不用验证集 + device=device, + lr=best_params["lr"], + weight_decay=best_params["weight_decay"], + epochs=epoch_mean, + class_weights=class_weights, + use_cosine_annealing=True, + ) + + # 加载最终权重 + model.load_state_dict(train_result["final_state"]) + model = model.to(device) + + # 保存模型 + config = { + "d_model": best_params["d_model"], + "num_heads": best_params["num_heads"], + "n_attn_layers": best_params["n_attn_layers"], + "fusion_strategy": best_params["fusion_strategy"], + "head_hidden_dim": best_params["head_hidden_dim"], + "dropout": best_params["dropout"], + "use_mpnn": use_mpnn, + } + + torch.save({ + "model_state_dict": train_result["final_state"], + "config": config, + "epoch_mean": epoch_mean, + "best_params": best_params, + }, fold_dir / "model.pt") + + # 保存训练历史 + with open(fold_dir / "history.json", "w") as f: + json.dump(train_result["history"], f, indent=2) + + # 在测试集上评估 + logger.info("Evaluating on outer test set...") + test_metrics = evaluate_on_test(model, test_loader, device) + + with open(fold_dir / "test_metrics.json", "w") as f: + json.dump(test_metrics, f, indent=2) + + # 打印测试结果 + logger.info(f"\nOuter Fold {outer_fold} Test Results:") + for task, metrics in test_metrics.items(): + if "rmse" in metrics: + logger.info(f" {task}: RMSE={metrics['rmse']:.4f}, R²={metrics['r2']:.4f}") + elif "accuracy" in metrics: + logger.info(f" {task}: Acc={metrics['accuracy']:.4f}, F1={metrics['f1']:.4f}") + elif "kl_divergence" in metrics: + logger.info(f" {task}: KL={metrics['kl_divergence']:.4f}, JS={metrics['js_divergence']:.4f}") + + outer_results.append({ + "fold": outer_fold, + "best_params": best_params, + "epoch_mean": epoch_mean, + "test_metrics": test_metrics, + }) + + # 汇总结果 + logger.info("\n" + "=" * 60) + logger.info("NESTED CV COMPLETE") + logger.info("=" * 60) + + # 计算汇总统计 + summary = {"fold_results": outer_results} + + # 对每个任务计算均值和标准差 + tasks_with_metrics = {} + for result in outer_results: + for task, metrics in result["test_metrics"].items(): + if task not in tasks_with_metrics: + tasks_with_metrics[task] = {k: [] for k in metrics.keys() if k != "n_samples"} + for k, v in metrics.items(): + if k != "n_samples": + 
tasks_with_metrics[task][k].append(v) + + summary["summary_stats"] = {} + for task, metrics_dict in tasks_with_metrics.items(): + summary["summary_stats"][task] = {} + for metric_name, values in metrics_dict.items(): + summary["summary_stats"][task][f"{metric_name}_mean"] = float(np.mean(values)) + summary["summary_stats"][task][f"{metric_name}_std"] = float(np.std(values)) + + # 打印汇总 + logger.info("\n[Summary Statistics]") + for task, stats in summary["summary_stats"].items(): + if "rmse_mean" in stats: + logger.info( + f" {task}: RMSE={stats['rmse_mean']:.4f}±{stats['rmse_std']:.4f}, " + f"R²={stats['r2_mean']:.4f}±{stats['r2_std']:.4f}" + ) + elif "accuracy_mean" in stats: + logger.info( + f" {task}: Acc={stats['accuracy_mean']:.4f}±{stats['accuracy_std']:.4f}, " + f"F1={stats['f1_mean']:.4f}±{stats['f1_std']:.4f}" + ) + elif "kl_divergence_mean" in stats: + logger.info( + f" {task}: KL={stats['kl_divergence_mean']:.4f}±{stats['kl_divergence_std']:.4f}, " + f"JS={stats['js_divergence_mean']:.4f}±{stats['js_divergence_std']:.4f}" + ) + + # 保存汇总 + with open(run_dir / "summary.json", "w") as f: + json.dump(summary, f, indent=2) + + logger.success(f"\nAll results saved to {run_dir}") + + +if __name__ == "__main__": + app() + diff --git a/lnp_ml/modeling/trainer_balanced.py b/lnp_ml/modeling/trainer_balanced.py new file mode 100644 index 0000000..bb2b803 --- /dev/null +++ b/lnp_ml/modeling/trainer_balanced.py @@ -0,0 +1,474 @@ +"""带类权重的训练器:处理分类任务的数据不均衡问题""" + +from typing import Dict, List, Optional, Tuple +from dataclasses import dataclass, field + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from tqdm import tqdm + + +@dataclass +class ClassWeights: + """分类任务的类权重""" + pdi: Optional[torch.Tensor] = None # [4] for 4 PDI classes + ee: Optional[torch.Tensor] = None # [3] for 3 EE classes + toxic: Optional[torch.Tensor] = None # [2] for binary toxic + + +@dataclass +class LossWeightsBalanced: + """各任务的损失权重(与 trainer.py 的 LossWeights 兼容)""" + size: float = 1.0 + pdi: float = 1.0 + ee: float = 1.0 + delivery: float = 1.0 + biodist: float = 1.0 + toxic: float = 1.0 + + +def compute_class_weights_from_loader( + loader: DataLoader, + n_pdi_classes: int = 4, + n_ee_classes: int = 3, + n_toxic_classes: int = 2, + smoothing: float = 0.1, +) -> ClassWeights: + """ + 从 DataLoader 统计类别频次并计算类权重。 + + 使用 inverse frequency 方式:weight_c = N / (n_classes * count_c) + 加 smoothing 避免极端权重。 + + Args: + loader: DataLoader(需要遍历一次) + n_pdi_classes: PDI 类别数 + n_ee_classes: EE 类别数 + n_toxic_classes: toxic 类别数 + smoothing: 平滑系数(防止除零和极端权重) + + Returns: + ClassWeights 对象,包含各分类任务的类权重张量 + """ + pdi_counts = torch.zeros(n_pdi_classes) + ee_counts = torch.zeros(n_ee_classes) + toxic_counts = torch.zeros(n_toxic_classes) + + for batch in loader: + targets = batch["targets"] + mask = batch["mask"] + + # PDI + if "pdi" in targets and "pdi" in mask: + m = mask["pdi"] + if m.any(): + labels = targets["pdi"][m] + for c in range(n_pdi_classes): + pdi_counts[c] += (labels == c).sum().item() + + # EE + if "ee" in targets and "ee" in mask: + m = mask["ee"] + if m.any(): + labels = targets["ee"][m] + for c in range(n_ee_classes): + ee_counts[c] += (labels == c).sum().item() + + # Toxic + if "toxic" in targets and "toxic" in mask: + m = mask["toxic"] + if m.any(): + labels = targets["toxic"][m] + for c in range(n_toxic_classes): + toxic_counts[c] += (labels == c).sum().item() + + def counts_to_weights(counts: torch.Tensor, n_classes: int) -> 
Optional[torch.Tensor]: + """将计数转换为类权重""" + total = counts.sum().item() + if total == 0: + return None + # Inverse frequency with smoothing + counts = counts + smoothing + weights = total / (n_classes * counts) + # Normalize to mean=1 + weights = weights / weights.mean() + return weights + + return ClassWeights( + pdi=counts_to_weights(pdi_counts, n_pdi_classes), + ee=counts_to_weights(ee_counts, n_ee_classes), + toxic=counts_to_weights(toxic_counts, n_toxic_classes), + ) + + +def compute_multitask_loss_balanced( + outputs: Dict[str, torch.Tensor], + targets: Dict[str, torch.Tensor], + mask: Dict[str, torch.Tensor], + task_weights: Optional[LossWeightsBalanced] = None, + class_weights: Optional[ClassWeights] = None, +) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + 计算带类权重的多任务损失。 + + Args: + outputs: 模型输出 + targets: 真实标签 + mask: 有效样本掩码 + task_weights: 各任务权重 + class_weights: 分类任务的类权重 + + Returns: + (total_loss, loss_dict) 总损失和各任务损失 + """ + task_weights = task_weights or LossWeightsBalanced() + class_weights = class_weights or ClassWeights() + + losses = {} + device = next(iter(outputs.values())).device + total_loss = torch.tensor(0.0, device=device) + + # size: MSE loss(回归任务,不需要类权重) + if "size" in targets and mask["size"].any(): + m = mask["size"] + pred = outputs["size"][m].squeeze(-1) + tgt = targets["size"][m] + losses["size"] = F.mse_loss(pred, tgt) + total_loss = total_loss + task_weights.size * losses["size"] + + # delivery: MSE loss(回归任务,不需要类权重) + if "delivery" in targets and mask["delivery"].any(): + m = mask["delivery"] + pred = outputs["delivery"][m].squeeze(-1) + tgt = targets["delivery"][m] + losses["delivery"] = F.mse_loss(pred, tgt) + total_loss = total_loss + task_weights.delivery * losses["delivery"] + + # pdi: CrossEntropy with class weights + if "pdi" in targets and mask["pdi"].any(): + m = mask["pdi"] + pred = outputs["pdi"][m] + tgt = targets["pdi"][m] + weight = class_weights.pdi.to(device) if class_weights.pdi is not None else None + losses["pdi"] = F.cross_entropy(pred, tgt, weight=weight) + total_loss = total_loss + task_weights.pdi * losses["pdi"] + + # ee: CrossEntropy with class weights + if "ee" in targets and mask["ee"].any(): + m = mask["ee"] + pred = outputs["ee"][m] + tgt = targets["ee"][m] + weight = class_weights.ee.to(device) if class_weights.ee is not None else None + losses["ee"] = F.cross_entropy(pred, tgt, weight=weight) + total_loss = total_loss + task_weights.ee * losses["ee"] + + # toxic: CrossEntropy with class weights + if "toxic" in targets and mask["toxic"].any(): + m = mask["toxic"] + pred = outputs["toxic"][m] + tgt = targets["toxic"][m] + weight = class_weights.toxic.to(device) if class_weights.toxic is not None else None + losses["toxic"] = F.cross_entropy(pred, tgt, weight=weight) + total_loss = total_loss + task_weights.toxic * losses["toxic"] + + # biodist: KL divergence(分布任务,不需要类权重) + if "biodist" in targets and mask["biodist"].any(): + m = mask["biodist"] + pred = outputs["biodist"][m] + tgt = targets["biodist"][m] + losses["biodist"] = F.kl_div( + pred.log().clamp(min=-100), + tgt, + reduction="batchmean", + ) + total_loss = total_loss + task_weights.biodist * losses["biodist"] + + return total_loss, losses + + +def train_epoch_balanced( + model: nn.Module, + loader: DataLoader, + optimizer: torch.optim.Optimizer, + device: torch.device, + task_weights: Optional[LossWeightsBalanced] = None, + class_weights: Optional[ClassWeights] = None, +) -> Dict[str, float]: + """带类权重的训练一个 epoch""" + model.train() + total_loss = 0.0 + 
task_losses = {k: 0.0 for k in ["size", "pdi", "ee", "delivery", "biodist", "toxic"]} + n_batches = 0 + + for batch in tqdm(loader, desc="Training", leave=False): + smiles = batch["smiles"] + tabular = {k: v.to(device) for k, v in batch["tabular"].items()} + targets = {k: v.to(device) for k, v in batch["targets"].items()} + mask = {k: v.to(device) for k, v in batch["mask"].items()} + + optimizer.zero_grad() + outputs = model(smiles, tabular) + loss, losses = compute_multitask_loss_balanced( + outputs, targets, mask, task_weights, class_weights + ) + + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + + total_loss += loss.item() + for k, v in losses.items(): + task_losses[k] += v.item() + n_batches += 1 + + return { + "loss": total_loss / n_batches, + **{f"loss_{k}": v / n_batches for k, v in task_losses.items() if v > 0}, + } + + +@torch.no_grad() +def validate_balanced( + model: nn.Module, + loader: DataLoader, + device: torch.device, + task_weights: Optional[LossWeightsBalanced] = None, + class_weights: Optional[ClassWeights] = None, +) -> Dict[str, float]: + """带类权重的验证""" + model.eval() + total_loss = 0.0 + task_losses = {k: 0.0 for k in ["size", "pdi", "ee", "delivery", "biodist", "toxic"]} + n_batches = 0 + + # 用于计算准确率 + correct = {k: 0 for k in ["pdi", "ee", "toxic"]} + total = {k: 0 for k in ["pdi", "ee", "toxic"]} + + for batch in tqdm(loader, desc="Validating", leave=False): + smiles = batch["smiles"] + tabular = {k: v.to(device) for k, v in batch["tabular"].items()} + targets = {k: v.to(device) for k, v in batch["targets"].items()} + mask = {k: v.to(device) for k, v in batch["mask"].items()} + + outputs = model(smiles, tabular) + loss, losses = compute_multitask_loss_balanced( + outputs, targets, mask, task_weights, class_weights + ) + + total_loss += loss.item() + for k, v in losses.items(): + task_losses[k] += v.item() + n_batches += 1 + + # 计算分类准确率 + for k in ["pdi", "ee", "toxic"]: + if k in targets and mask[k].any(): + m = mask[k] + pred = outputs[k][m].argmax(dim=-1) + tgt = targets[k][m] + correct[k] += (pred == tgt).sum().item() + total[k] += m.sum().item() + + metrics = { + "loss": total_loss / n_batches, + **{f"loss_{k}": v / n_batches for k, v in task_losses.items() if v > 0}, + } + + # 添加准确率 + for k in ["pdi", "ee", "toxic"]: + if total[k] > 0: + metrics[f"acc_{k}"] = correct[k] / total[k] + + return metrics + + +class EarlyStoppingBalanced: + """早停机制(与 trainer.py 的 EarlyStopping 兼容)""" + + def __init__(self, patience: int = 10, min_delta: float = 0.0): + self.patience = patience + self.min_delta = min_delta + self.counter = 0 + self.best_loss = float("inf") + self.best_epoch = 0 + self.should_stop = False + + def __call__(self, val_loss: float, epoch: int = 0) -> bool: + if val_loss < self.best_loss - self.min_delta: + self.best_loss = val_loss + self.best_epoch = epoch + self.counter = 0 + else: + self.counter += 1 + if self.counter >= self.patience: + self.should_stop = True + return self.should_stop + + def get_best_epoch(self) -> int: + """获取最佳 epoch(1-indexed)""" + return self.best_epoch + 1 + + +def train_with_early_stopping( + model: nn.Module, + train_loader: DataLoader, + val_loader: DataLoader, + device: torch.device, + lr: float = 1e-4, + weight_decay: float = 1e-5, + epochs: int = 100, + patience: int = 15, + task_weights: Optional[LossWeightsBalanced] = None, + class_weights: Optional[ClassWeights] = None, +) -> Dict: + """ + 带早停的完整训练流程。 + + Returns: + Dict with keys: history, best_val_loss, 
best_epoch, best_state + """ + model = model.to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=5 + ) + early_stopping = EarlyStoppingBalanced(patience=patience) + + history = {"train": [], "val": []} + best_val_loss = float("inf") + best_state = None + + for epoch in range(epochs): + # Train + train_metrics = train_epoch_balanced( + model, train_loader, optimizer, device, task_weights, class_weights + ) + + # Validate + val_metrics = validate_balanced( + model, val_loader, device, task_weights, class_weights + ) + + history["train"].append(train_metrics) + history["val"].append(val_metrics) + + # Learning rate scheduling + scheduler.step(val_metrics["loss"]) + + # Save best model + if val_metrics["loss"] < best_val_loss: + best_val_loss = val_metrics["loss"] + best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()} + + # Early stopping + if early_stopping(val_metrics["loss"], epoch): + break + + # Restore best model + if best_state is not None: + model.load_state_dict(best_state) + + return { + "history": history, + "best_val_loss": best_val_loss, + "best_epoch": early_stopping.get_best_epoch(), + "best_state": best_state, + "epochs_trained": len(history["train"]), + } + + +def train_fixed_epochs( + model: nn.Module, + train_loader: DataLoader, + val_loader: Optional[DataLoader], + device: torch.device, + lr: float = 1e-4, + weight_decay: float = 1e-5, + epochs: int = 50, + task_weights: Optional[LossWeightsBalanced] = None, + class_weights: Optional[ClassWeights] = None, + use_cosine_annealing: bool = True, + use_swa: bool = False, + swa_start_epoch: Optional[int] = None, +) -> Dict: + """ + 固定 epoch 数的训练(不使用 early stopping)。 + + 用于外层 CV 训练和最终训练。 + + Args: + model: 模型 + train_loader: 训练数据 + val_loader: 验证数据(可选,仅用于监控) + device: 设备 + lr: 学习率 + weight_decay: 权重衰减 + epochs: 训练轮数 + task_weights: 任务权重 + class_weights: 类权重 + use_cosine_annealing: 是否使用 CosineAnnealingLR + use_swa: 是否使用 SWA + swa_start_epoch: SWA 开始的 epoch(默认为 epochs * 0.75) + + Returns: + Dict with keys: history, final_state + """ + model = model.to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) + + if use_cosine_annealing: + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) + else: + scheduler = None + + # SWA setup + swa_model = None + swa_scheduler = None + if use_swa: + from torch.optim.swa_utils import AveragedModel, SWALR + swa_model = AveragedModel(model) + swa_start = swa_start_epoch or int(epochs * 0.75) + swa_scheduler = SWALR(optimizer, swa_lr=lr * 0.1) + + history = {"train": [], "val": []} + + for epoch in range(epochs): + # Train + train_metrics = train_epoch_balanced( + model, train_loader, optimizer, device, task_weights, class_weights + ) + history["train"].append(train_metrics) + + # Validate (optional) + if val_loader is not None: + val_metrics = validate_balanced( + model, val_loader, device, task_weights, class_weights + ) + history["val"].append(val_metrics) + + # Scheduler step + if use_swa and epoch >= swa_start: + swa_model.update_parameters(model) + swa_scheduler.step() + elif scheduler is not None: + scheduler.step() + + # Finalize SWA + final_state = None + if use_swa and swa_model is not None: + # Update batch normalization statistics + torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device) + final_state = {k: v.cpu().clone() for k, v in 
+
+
+def train_fixed_epochs(
+    model: nn.Module,
+    train_loader: DataLoader,
+    val_loader: Optional[DataLoader],
+    device: torch.device,
+    lr: float = 1e-4,
+    weight_decay: float = 1e-5,
+    epochs: int = 50,
+    task_weights: Optional[LossWeightsBalanced] = None,
+    class_weights: Optional[ClassWeights] = None,
+    use_cosine_annealing: bool = True,
+    use_swa: bool = False,
+    swa_start_epoch: Optional[int] = None,
+) -> Dict:
+    """
+    Train for a fixed number of epochs (no early stopping).
+
+    Used for outer-CV training and final training.
+
+    Args:
+        model: Model
+        train_loader: Training data
+        val_loader: Validation data (optional, used for monitoring only)
+        device: Device
+        lr: Learning rate
+        weight_decay: Weight decay
+        epochs: Number of training epochs
+        task_weights: Task weights
+        class_weights: Class weights
+        use_cosine_annealing: Whether to use CosineAnnealingLR
+        use_swa: Whether to use SWA
+        swa_start_epoch: Epoch at which SWA starts (defaults to epochs * 0.75)
+
+    Returns:
+        Dict with keys: history, final_state, epochs_trained
+    """
+    model = model.to(device)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+
+    if use_cosine_annealing:
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+    else:
+        scheduler = None
+
+    # SWA setup
+    swa_model = None
+    swa_scheduler = None
+    if use_swa:
+        from torch.optim.swa_utils import AveragedModel, SWALR
+        swa_model = AveragedModel(model)
+        swa_start = swa_start_epoch or int(epochs * 0.75)
+        swa_scheduler = SWALR(optimizer, swa_lr=lr * 0.1)
+
+    history = {"train": [], "val": []}
+
+    for epoch in range(epochs):
+        # Train
+        train_metrics = train_epoch_balanced(
+            model, train_loader, optimizer, device, task_weights, class_weights
+        )
+        history["train"].append(train_metrics)
+
+        # Validate (optional)
+        if val_loader is not None:
+            val_metrics = validate_balanced(
+                model, val_loader, device, task_weights, class_weights
+            )
+            history["val"].append(val_metrics)
+
+        # Scheduler step
+        if use_swa and epoch >= swa_start:
+            swa_model.update_parameters(model)
+            swa_scheduler.step()
+        elif scheduler is not None:
+            scheduler.step()
+
+    # Finalize SWA
+    final_state = None
+    if use_swa and swa_model is not None:
+        # Update batch normalization statistics
+        torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
+        final_state = {k: v.cpu().clone() for k, v in swa_model.module.state_dict().items()}
+    else:
+        final_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
+
+    return {
+        "history": history,
+        "final_state": final_state,
+        "epochs_trained": epochs,
+    }
+
diff --git a/models/final/model.pt b/models/final/model.pt
index 0626353..00cd556 100644
Binary files a/models/final/model.pt and b/models/final/model.pt differ
diff --git a/models/finetune_cv/fold_0/model.pt b/models/finetune_cv/fold_0/model.pt
index 962a403..a0c0f61 100644
Binary files a/models/finetune_cv/fold_0/model.pt and b/models/finetune_cv/fold_0/model.pt differ
diff --git a/models/finetune_cv/fold_1/model.pt b/models/finetune_cv/fold_1/model.pt
index 1e89fc1..32fa929 100644
Binary files a/models/finetune_cv/fold_1/model.pt and b/models/finetune_cv/fold_1/model.pt differ
diff --git a/models/finetune_cv/fold_2/model.pt b/models/finetune_cv/fold_2/model.pt
index 6a15f9e..e6b183c 100644
Binary files a/models/finetune_cv/fold_2/model.pt and b/models/finetune_cv/fold_2/model.pt differ
diff --git a/models/finetune_cv/fold_3/model.pt b/models/finetune_cv/fold_3/model.pt
index 5f3a4b9..e93eaff 100644
Binary files a/models/finetune_cv/fold_3/model.pt and b/models/finetune_cv/fold_3/model.pt differ
diff --git a/models/finetune_cv/fold_4/model.pt b/models/finetune_cv/fold_4/model.pt
index f33ca81..1d902d3 100644
Binary files a/models/finetune_cv/fold_4/model.pt and b/models/finetune_cv/fold_4/model.pt differ
diff --git a/models/model.pt b/models/model.pt
index 7023800..50c124f 100644
Binary files a/models/model.pt and b/models/model.pt differ
diff --git a/models/mpnn/all_amine_split_for_LiON/cv_0/fold_0/model_0/model.pt b/models/mpnn/all_amine_split_for_LiON/cv_0/fold_0/model_0/model.pt
index f830a1b..b8a65fb 100644
Binary files a/models/mpnn/all_amine_split_for_LiON/cv_0/fold_0/model_0/model.pt and b/models/mpnn/all_amine_split_for_LiON/cv_0/fold_0/model_0/model.pt differ
diff --git a/models/mpnn/all_amine_split_for_LiON/cv_1/fold_0/model_0/model.pt b/models/mpnn/all_amine_split_for_LiON/cv_1/fold_0/model_0/model.pt
index ef8c294..97ac616 100644
Binary files a/models/mpnn/all_amine_split_for_LiON/cv_1/fold_0/model_0/model.pt and b/models/mpnn/all_amine_split_for_LiON/cv_1/fold_0/model_0/model.pt differ
diff --git a/models/mpnn/all_amine_split_for_LiON/cv_2/fold_0/model_0/model.pt b/models/mpnn/all_amine_split_for_LiON/cv_2/fold_0/model_0/model.pt
index f97a0c6..6d6dc04 100644
Binary files a/models/mpnn/all_amine_split_for_LiON/cv_2/fold_0/model_0/model.pt and b/models/mpnn/all_amine_split_for_LiON/cv_2/fold_0/model_0/model.pt differ
diff --git a/models/mpnn/all_amine_split_for_LiON/cv_3/fold_0/model_0/model.pt b/models/mpnn/all_amine_split_for_LiON/cv_3/fold_0/model_0/model.pt
index 3f54b88..95ae8c9 100644
Binary files a/models/mpnn/all_amine_split_for_LiON/cv_3/fold_0/model_0/model.pt and b/models/mpnn/all_amine_split_for_LiON/cv_3/fold_0/model_0/model.pt differ
diff --git a/models/mpnn/all_amine_split_for_LiON/cv_4/fold_0/model_0/model.pt b/models/mpnn/all_amine_split_for_LiON/cv_4/fold_0/model_0/model.pt
index 3bfbe0a..c6393be 100644
Binary files a/models/mpnn/all_amine_split_for_LiON/cv_4/fold_0/model_0/model.pt and b/models/mpnn/all_amine_split_for_LiON/cv_4/fold_0/model_0/model.pt differ
diff --git a/models/nested_cv/20260130_183653/outer_fold_0/best_params.json b/models/nested_cv/20260130_183653/outer_fold_0/best_params.json
new file mode 100644
index 0000000..4982ff7
--- /dev/null
+++ b/models/nested_cv/20260130_183653/outer_fold_0/best_params.json
@@ -0,0 +1,10 @@
+{ + "d_model": 512, + "num_heads": 8, + "n_attn_layers": 5, + "fusion_strategy": "attention", + "head_hidden_dim": 128, + "dropout": 0.14666736316838325, + "lr": 0.0001295888795454003, + "weight_decay": 7.732380983243132e-05 +} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_0/epoch_mean.json b/models/nested_cv/20260130_183653/outer_fold_0/epoch_mean.json new file mode 100644 index 0000000..0ac5b47 --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_0/epoch_mean.json @@ -0,0 +1 @@ +{"epoch_mean": 13} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_0/history.json b/models/nested_cv/20260130_183653/outer_fold_0/history.json new file mode 100644 index 0000000..01d655a --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_0/history.json @@ -0,0 +1,122 @@ +{ + "train": [ + { + "loss": 9.688567074862393, + "loss_size": 4.525775753638961, + "loss_pdi": 1.2179414359006016, + "loss_ee": 1.091066924008456, + "loss_delivery": 1.1003740294413134, + "loss_biodist": 1.045865768736059, + "loss_toxic": 0.7075429206544702 + }, + { + "loss": 4.477882081812078, + "loss_size": 0.30304074490612204, + "loss_pdi": 0.9405507716265592, + "loss_ee": 0.9982529336755926, + "loss_delivery": 1.1088530285791918, + "loss_biodist": 0.6725774082270536, + "loss_toxic": 0.45460730520161713 + }, + { + "loss": 3.7592489285902544, + "loss_size": 0.3136644837531177, + "loss_pdi": 0.7663570263169028, + "loss_ee": 0.9424018914049322, + "loss_delivery": 0.9587065902623263, + "loss_biodist": 0.5116219466382806, + "loss_toxic": 0.2664969251914458 + }, + { + "loss": 3.2862942435524682, + "loss_size": 0.25651484592394397, + "loss_pdi": 0.729118591005152, + "loss_ee": 0.8862769928845492, + "loss_delivery": 0.8572801744396036, + "loss_biodist": 0.42606663974848663, + "loss_toxic": 0.13103695366192947 + }, + { + "loss": 3.032014153220437, + "loss_size": 0.2605769471688704, + "loss_pdi": 0.6875471039251848, + "loss_ee": 0.8450987447391857, + "loss_delivery": 0.7942117723551664, + "loss_biodist": 0.3333369114182212, + "loss_toxic": 0.11124271154403687 + }, + { + "loss": 2.747604175047441, + "loss_size": 0.23051185499538074, + "loss_pdi": 0.6304752420295369, + "loss_ee": 0.8027611483227123, + "loss_delivery": 0.6954045072197914, + "loss_biodist": 0.31399699503725226, + "loss_toxic": 0.07445447621020404 + }, + { + "loss": 2.496532602743669, + "loss_size": 0.21782836656678806, + "loss_pdi": 0.6259297322143208, + "loss_ee": 0.7789664485237815, + "loss_delivery": 0.5751941861076788, + "loss_biodist": 0.25595076788555493, + "loss_toxic": 0.04266304531219331 + }, + { + "loss": 2.474623289975253, + "loss_size": 0.2284239645708691, + "loss_pdi": 0.6423451900482178, + "loss_ee": 0.7340371066873724, + "loss_delivery": 0.5827173766764727, + "loss_biodist": 0.24910570003769614, + "loss_toxic": 0.03799399394880642 + }, + { + "loss": 2.3557864644310693, + "loss_size": 0.197183218869296, + "loss_pdi": 0.5410721220753409, + "loss_ee": 0.7328030331568285, + "loss_delivery": 0.6107787185094573, + "loss_biodist": 0.2399107339707288, + "loss_toxic": 0.03403859864920378 + }, + { + "loss": 2.187597805803472, + "loss_size": 0.19923549348657782, + "loss_pdi": 0.514803715727546, + "loss_ee": 0.7146885395050049, + "loss_delivery": 0.49052428555759514, + "loss_biodist": 0.22473187744617462, + "loss_toxic": 0.043613933196121994 + }, + { + "loss": 2.1395224874669854, + "loss_size": 0.20194762064652008, + "loss_pdi": 0.5207281329415061, + "loss_ee": 0.6976469484242526, + 
"loss_delivery": 0.4825741987336766, + "loss_biodist": 0.21454968913034958, + "loss_toxic": 0.022075910629196602 + }, + { + "loss": 2.0708962462165137, + "loss_size": 0.18447093801064926, + "loss_pdi": 0.52463944662701, + "loss_ee": 0.6666911081834273, + "loss_delivery": 0.4482487521388314, + "loss_biodist": 0.21374160457741131, + "loss_toxic": 0.033104426227509975 + }, + { + "loss": 2.065277782353488, + "loss_size": 0.17790686813267795, + "loss_pdi": 0.5256167894059961, + "loss_ee": 0.6798818870024248, + "loss_delivery": 0.43902007693594153, + "loss_biodist": 0.21479454501108688, + "loss_toxic": 0.02805758648636666 + } + ], + "val": [] +} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_0/model.pt b/models/nested_cv/20260130_183653/outer_fold_0/model.pt new file mode 100644 index 0000000..ee66c7b --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_0/model.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04f3103a630de23aeb971629aa769e1a9cdb3c5247c193eefb808c6a4e17b9cc +size 133133308 diff --git a/models/nested_cv/20260130_183653/outer_fold_0/optuna_study.sqlite3 b/models/nested_cv/20260130_183653/outer_fold_0/optuna_study.sqlite3 new file mode 100644 index 0000000..e6c01ef Binary files /dev/null and b/models/nested_cv/20260130_183653/outer_fold_0/optuna_study.sqlite3 differ diff --git a/models/nested_cv/20260130_183653/outer_fold_0/splits.json b/models/nested_cv/20260130_183653/outer_fold_0/splits.json new file mode 100644 index 0000000..73d0024 --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_0/splits.json @@ -0,0 +1 @@ +{"outer_train_idx": [0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 32, 33, 34, 35, 37, 38, 39, 40, 41, 43, 44, 45, 46, 47, 48, 49, 50, 52, 53, 56, 57, 58, 59, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 73, 74, 75, 77, 78, 79, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 102, 103, 104, 105, 106, 107, 108, 109, 110, 112, 113, 114, 115, 117, 118, 119, 121, 123, 126, 127, 128, 129, 131, 132, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 146, 147, 148, 149, 150, 152, 153, 155, 156, 158, 159, 163, 167, 169, 170, 171, 172, 173, 175, 176, 177, 178, 179, 180, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 207, 208, 211, 214, 215, 216, 217, 218, 219, 221, 222, 223, 224, 227, 228, 229, 230, 231, 233, 234, 236, 237, 238, 239, 243, 244, 245, 247, 248, 249, 251, 252, 255, 256, 257, 258, 260, 261, 262, 263, 264, 265, 266, 268, 269, 270, 271, 272, 273, 274, 275, 276, 279, 282, 283, 284, 286, 287, 288, 290, 291, 292, 293, 294, 295, 297, 298, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 319, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 335, 337, 338, 339, 340, 341, 343, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 358, 359, 360, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 379, 380, 381, 382, 383, 384, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 398, 399, 400, 402, 403, 404, 408, 409, 410, 411, 412, 413, 415, 416, 417, 418, 419, 420, 421, 422, 423, 425, 426, 428, 429, 432, 433], "outer_test_idx": [7, 12, 13, 25, 31, 36, 42, 51, 54, 55, 60, 72, 76, 80, 101, 111, 116, 120, 122, 124, 125, 130, 133, 145, 151, 154, 157, 160, 161, 162, 164, 165, 166, 168, 174, 181, 206, 209, 210, 212, 213, 220, 225, 226, 232, 235, 240, 241, 242, 246, 
250, 253, 254, 259, 267, 277, 278, 280, 281, 285, 289, 296, 299, 300, 318, 320, 321, 334, 336, 342, 344, 356, 357, 361, 377, 378, 385, 397, 401, 405, 406, 407, 414, 424, 427, 430, 431]} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_0/test_metrics.json b/models/nested_cv/20260130_183653/outer_fold_0/test_metrics.json new file mode 100644 index 0000000..e818980 --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_0/test_metrics.json @@ -0,0 +1,42 @@ +{ + "size": { + "n_samples": 87, + "mse": 0.26366087871128757, + "rmse": 0.5134791901443403, + "mae": 0.25157783223294666, + "r2": 0.21208517006410577 + }, + "delivery": { + "n_samples": 61, + "mse": 0.40443344562739025, + "rmse": 0.63595082013265, + "mae": 0.3928790920429298, + "r2": 0.2300258531372983 + }, + "pdi": { + "n_samples": 87, + "accuracy": 0.7241379310344828, + "precision": 0.35141509433962265, + "recall": 0.35351966873706003, + "f1": 0.348405985686402 + }, + "ee": { + "n_samples": 87, + "accuracy": 0.6666666666666666, + "precision": 0.6188811188811189, + "recall": 0.6375291375291375, + "f1": 0.6217948717948718 + }, + "toxic": { + "n_samples": 62, + "accuracy": 0.967741935483871, + "precision": 0.8, + "recall": 0.9830508474576272, + "f1": 0.8663793103448275 + }, + "biodist": { + "n_samples": 61, + "kl_divergence": 0.14776465556036145, + "js_divergence": 0.03926150329301917 + } +} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_1/best_params.json b/models/nested_cv/20260130_183653/outer_fold_1/best_params.json new file mode 100644 index 0000000..5e792b1 --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_1/best_params.json @@ -0,0 +1,10 @@ +{ + "d_model": 512, + "num_heads": 4, + "n_attn_layers": 2, + "fusion_strategy": "avg", + "head_hidden_dim": 64, + "dropout": 0.05188345993471756, + "lr": 4.21188892865021e-05, + "weight_decay": 4.086499445232577e-05 +} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_1/epoch_mean.json b/models/nested_cv/20260130_183653/outer_fold_1/epoch_mean.json new file mode 100644 index 0000000..65f6bc1 --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_1/epoch_mean.json @@ -0,0 +1 @@ +{"epoch_mean": 23} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_1/history.json b/models/nested_cv/20260130_183653/outer_fold_1/history.json new file mode 100644 index 0000000..21b973a --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_1/history.json @@ -0,0 +1,212 @@ +{ + "train": [ + { + "loss": 19.564044865694914, + "loss_size": 14.219423033974387, + "loss_pdi": 1.3700515248558738, + "loss_ee": 1.0948690609498457, + "loss_delivery": 0.925118706443093, + "loss_biodist": 1.2897993542931296, + "loss_toxic": 0.6647828817367554 + }, + { + "loss": 9.569591695612127, + "loss_size": 4.443219217387113, + "loss_pdi": 1.3034031066027554, + "loss_ee": 1.0691871643066406, + "loss_delivery": 0.986666976050897, + "loss_biodist": 1.1807570999318904, + "loss_toxic": 0.5863581760363146 + }, + { + "loss": 5.09315148266879, + "loss_size": 0.4028705805540085, + "loss_pdi": 1.2596685236150569, + "loss_ee": 1.0527758598327637, + "loss_delivery": 0.8668525652451948, + "loss_biodist": 1.0210744521834634, + "loss_toxic": 0.48990951072085986 + }, + { + "loss": 4.664039200002497, + "loss_size": 0.26930826157331467, + "loss_pdi": 1.1607397361235186, + "loss_ee": 1.0082263296300715, + "loss_delivery": 0.9969787990505045, + "loss_biodist": 0.8467015407302163, 
+ "loss_toxic": 0.3820846405896274 + }, + { + "loss": 4.224615703929555, + "loss_size": 0.28240104629234836, + "loss_pdi": 1.048582375049591, + "loss_ee": 0.9691753116520968, + "loss_delivery": 0.8995198593898253, + "loss_biodist": 0.7173566547307101, + "loss_toxic": 0.3075804940678857 + }, + { + "loss": 3.810542041605169, + "loss_size": 0.23119159516963092, + "loss_pdi": 0.9792777408253063, + "loss_ee": 0.9330252029679038, + "loss_delivery": 0.8217983476140283, + "loss_biodist": 0.6248252906582572, + "loss_toxic": 0.22042379054156216 + }, + { + "loss": 3.5365114862268623, + "loss_size": 0.21511441333727402, + "loss_pdi": 0.9201341759074818, + "loss_ee": 0.9007513360543684, + "loss_delivery": 0.8033261312679811, + "loss_biodist": 0.5451555441726338, + "loss_toxic": 0.15202991325746884 + }, + { + "loss": 3.405807386745106, + "loss_size": 0.20161619240587408, + "loss_pdi": 0.8445044376633384, + "loss_ee": 0.8931786905635487, + "loss_delivery": 0.831847377798774, + "loss_biodist": 0.49607704173434863, + "loss_toxic": 0.13858359239318155 + }, + { + "loss": 3.103480577468872, + "loss_size": 0.18636523119427942, + "loss_pdi": 0.8065886064009233, + "loss_ee": 0.8574360067194159, + "loss_delivery": 0.6954484595493837, + "loss_biodist": 0.45561750639568677, + "loss_toxic": 0.10202483257109468 + }, + { + "loss": 3.0312788052992388, + "loss_size": 0.19699390774423425, + "loss_pdi": 0.7626179348338734, + "loss_ee": 0.8408515670082786, + "loss_delivery": 0.7123599113388495, + "loss_biodist": 0.41937256130305206, + "loss_toxic": 0.09908294406804172 + }, + { + "loss": 2.8196144104003906, + "loss_size": 0.1769183948636055, + "loss_pdi": 0.717121189290827, + "loss_ee": 0.8330178802663629, + "loss_delivery": 0.6163532761010256, + "loss_biodist": 0.3913883079182018, + "loss_toxic": 0.08481534672054378 + }, + { + "loss": 2.7494440295479516, + "loss_size": 0.17335935140197928, + "loss_pdi": 0.7235767028548501, + "loss_ee": 0.812805939804424, + "loss_delivery": 0.5859575948931954, + "loss_biodist": 0.38201813806187024, + "loss_toxic": 0.07172633301128041 + }, + { + "loss": 2.6807472705841064, + "loss_size": 0.20867897028272803, + "loss_pdi": 0.6913418119603937, + "loss_ee": 0.8103034550493414, + "loss_delivery": 0.5472286993807013, + "loss_biodist": 0.34767021103338763, + "loss_toxic": 0.07552417537028139 + }, + { + "loss": 2.5547089793465356, + "loss_size": 0.17158158326690848, + "loss_pdi": 0.6350900205698881, + "loss_ee": 0.7878076163205233, + "loss_delivery": 0.5484779531305487, + "loss_biodist": 0.34313417022878473, + "loss_toxic": 0.06861767376011069 + }, + { + "loss": 2.557967185974121, + "loss_size": 0.14963022822683508, + "loss_pdi": 0.6706653941761364, + "loss_ee": 0.7864666526967828, + "loss_delivery": 0.5519421967593107, + "loss_biodist": 0.3387457403269681, + "loss_toxic": 0.06051696803082119 + }, + { + "loss": 2.5365889939394863, + "loss_size": 0.17580546641891653, + "loss_pdi": 0.6637366847558455, + "loss_ee": 0.7831195484508168, + "loss_delivery": 0.5293790704824708, + "loss_biodist": 0.3297005011276765, + "loss_toxic": 0.054847732186317444 + }, + { + "loss": 2.51393402706493, + "loss_size": 0.1777868609536778, + "loss_pdi": 0.6462894515557722, + "loss_ee": 0.7651460604234175, + "loss_delivery": 0.5493429905988954, + "loss_biodist": 0.317649554122578, + "loss_toxic": 0.05771913768892938 + }, + { + "loss": 2.425347761674361, + "loss_size": 0.14695054224946283, + "loss_pdi": 0.6508165923031893, + "loss_ee": 0.7596850720318881, + "loss_delivery": 0.49562479284676636, + "loss_biodist": 
0.31901306862180884, + "loss_toxic": 0.053257653659040276 + }, + { + "loss": 2.450036742470481, + "loss_size": 0.15424400500275873, + "loss_pdi": 0.655444394458424, + "loss_ee": 0.7591585137627341, + "loss_delivery": 0.5164471390572462, + "loss_biodist": 0.3094088543545116, + "loss_toxic": 0.05533381686969237 + }, + { + "loss": 2.3860875476490366, + "loss_size": 0.1556895158507607, + "loss_pdi": 0.6400747786868702, + "loss_ee": 0.7600220929492604, + "loss_delivery": 0.46673478253863077, + "loss_biodist": 0.3111781979149038, + "loss_toxic": 0.05238820222968405 + }, + { + "loss": 2.3954180804165928, + "loss_size": 0.1462622725150802, + "loss_pdi": 0.619796259836717, + "loss_ee": 0.7609411152926359, + "loss_delivery": 0.5080444128675894, + "loss_biodist": 0.30634408511898736, + "loss_toxic": 0.05402998389168219 + }, + { + "loss": 2.3838103034279565, + "loss_size": 0.13035081394694067, + "loss_pdi": 0.6347457062114369, + "loss_ee": 0.7560921203006398, + "loss_delivery": 0.4973590536551042, + "loss_biodist": 0.3073451383547349, + "loss_toxic": 0.05791747265241363 + }, + { + "loss": 2.4303664727644487, + "loss_size": 0.1789045144211162, + "loss_pdi": 0.631924033164978, + "loss_ee": 0.7499593008648265, + "loss_delivery": 0.502868037332188, + "loss_biodist": 0.3101512383330952, + "loss_toxic": 0.05655933002179319 + } + ], + "val": [] +} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_1/model.pt b/models/nested_cv/20260130_183653/outer_fold_1/model.pt new file mode 100644 index 0000000..28e289a --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_1/model.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b1c9190b09f7ab586a7d5190b74452f4be67c5d2fb753aae202f9b723523ef00 +size 55606866 diff --git a/models/nested_cv/20260130_183653/outer_fold_1/optuna_study.sqlite3 b/models/nested_cv/20260130_183653/outer_fold_1/optuna_study.sqlite3 new file mode 100644 index 0000000..199b993 Binary files /dev/null and b/models/nested_cv/20260130_183653/outer_fold_1/optuna_study.sqlite3 differ diff --git a/models/nested_cv/20260130_183653/outer_fold_1/splits.json b/models/nested_cv/20260130_183653/outer_fold_1/splits.json new file mode 100644 index 0000000..d675eb8 --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_1/splits.json @@ -0,0 +1 @@ +{"outer_train_idx": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 30, 31, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 47, 49, 50, 51, 52, 54, 55, 56, 57, 58, 59, 60, 61, 62, 64, 65, 70, 71, 72, 73, 74, 75, 76, 77, 79, 80, 81, 82, 84, 87, 89, 90, 91, 92, 93, 96, 98, 99, 101, 102, 103, 105, 106, 108, 111, 112, 114, 115, 116, 119, 120, 121, 122, 123, 124, 125, 126, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 148, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 164, 165, 166, 167, 168, 171, 172, 173, 174, 175, 176, 177, 178, 179, 181, 185, 186, 187, 190, 191, 192, 196, 198, 203, 204, 206, 207, 208, 209, 210, 211, 212, 213, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 246, 248, 249, 250, 251, 253, 254, 255, 256, 257, 259, 260, 261, 262, 264, 265, 266, 267, 268, 269, 270, 271, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 295, 296, 298, 299, 300, 301, 302, 303, 304, 306, 307, 308, 309, 310, 312, 313, 314, 316, 317, 318, 319, 320, 321, 322, 
323, 324, 326, 327, 329, 330, 331, 332, 334, 335, 336, 337, 340, 341, 342, 344, 346, 347, 348, 349, 350, 351, 352, 353, 354, 356, 357, 359, 360, 361, 362, 363, 364, 365, 366, 369, 370, 372, 374, 376, 377, 378, 380, 381, 382, 384, 385, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 421, 422, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433], "outer_test_idx": [29, 32, 35, 46, 48, 53, 63, 66, 67, 68, 69, 78, 83, 85, 86, 88, 94, 95, 97, 100, 104, 107, 109, 110, 113, 117, 118, 127, 128, 146, 147, 149, 150, 163, 169, 170, 180, 182, 183, 184, 188, 189, 193, 194, 195, 197, 199, 200, 201, 202, 205, 214, 228, 229, 245, 247, 252, 258, 263, 272, 293, 294, 297, 305, 311, 315, 325, 328, 333, 338, 339, 343, 345, 355, 358, 367, 368, 371, 373, 375, 379, 383, 386, 408, 409, 420, 423]} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_1/test_metrics.json b/models/nested_cv/20260130_183653/outer_fold_1/test_metrics.json new file mode 100644 index 0000000..929b017 --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_1/test_metrics.json @@ -0,0 +1,42 @@ +{ + "size": { + "n_samples": 87, + "mse": 0.20102046206409732, + "rmse": 0.44835305515196094, + "mae": 0.2436335881551107, + "r2": 0.1900957049146058 + }, + "delivery": { + "n_samples": 62, + "mse": 0.5899936276993041, + "rmse": 0.7681104267612203, + "mae": 0.46896366539576484, + "r2": 0.4017743543115936 + }, + "pdi": { + "n_samples": 87, + "accuracy": 0.6666666666666666, + "precision": 0.41556437389770723, + "recall": 0.4042119565217391, + "f1": 0.40777777777777774 + }, + "ee": { + "n_samples": 87, + "accuracy": 0.6666666666666666, + "precision": 0.6414141414141414, + "recall": 0.6782661782661782, + "f1": 0.6387485970819304 + }, + "toxic": { + "n_samples": 62, + "accuracy": 1.0, + "precision": 1.0, + "recall": 1.0, + "f1": 1.0 + }, + "biodist": { + "n_samples": 62, + "kl_divergence": 0.30758161166563297, + "js_divergence": 0.08759221465023677 + } +} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_2/best_params.json b/models/nested_cv/20260130_183653/outer_fold_2/best_params.json new file mode 100644 index 0000000..c74742c --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_2/best_params.json @@ -0,0 +1,10 @@ +{ + "d_model": 512, + "num_heads": 8, + "n_attn_layers": 4, + "fusion_strategy": "attention", + "head_hidden_dim": 64, + "dropout": 0.11433976976282646, + "lr": 5.3015812445144804e-05, + "weight_decay": 6.704431817743382e-06 +} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_2/epoch_mean.json b/models/nested_cv/20260130_183653/outer_fold_2/epoch_mean.json new file mode 100644 index 0000000..c288387 --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_2/epoch_mean.json @@ -0,0 +1 @@ +{"epoch_mean": 18} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_2/history.json b/models/nested_cv/20260130_183653/outer_fold_2/history.json new file mode 100644 index 0000000..49f665e --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_2/history.json @@ -0,0 +1,167 @@ +{ + "train": [ + { + "loss": 16.328427488153633, + "loss_size": 10.961490392684937, + "loss_pdi": 1.3189676566557451, + "loss_ee": 1.0652603994716296, + "loss_delivery": 1.093609056012197, + "loss_biodist": 1.2020217722112483, + "loss_toxic": 0.6870784759521484 + }, + { + "loss": 5.557661620053378, + "loss_size": 
0.6549029702490027, + "loss_pdi": 1.2172227664427324, + "loss_ee": 1.0503052473068237, + "loss_delivery": 0.9982107078487222, + "loss_biodist": 1.0446247458457947, + "loss_toxic": 0.5923952026800676 + }, + { + "loss": 4.66534686088562, + "loss_size": 0.40050424906340515, + "loss_pdi": 0.9867187413302335, + "loss_ee": 0.9837820909240029, + "loss_delivery": 1.079543113708496, + "loss_biodist": 0.7991420409896157, + "loss_toxic": 0.41565650972453033 + }, + { + "loss": 4.042153250087392, + "loss_size": 0.39615295827388763, + "loss_pdi": 0.8758815581148321, + "loss_ee": 0.9381269162351434, + "loss_delivery": 0.9179368994452737, + "loss_biodist": 0.6232685229995034, + "loss_toxic": 0.2907863679257306 + }, + { + "loss": 3.5052005377682773, + "loss_size": 0.3407063511284915, + "loss_pdi": 0.8075345104390924, + "loss_ee": 0.8666415322910656, + "loss_delivery": 0.821596086025238, + "loss_biodist": 0.49901550195433875, + "loss_toxic": 0.16970657895911823 + }, + { + "loss": 3.227808865633878, + "loss_size": 0.3172220086509531, + "loss_pdi": 0.7554353800686923, + "loss_ee": 0.8366058685562827, + "loss_delivery": 0.7592829249121926, + "loss_biodist": 0.41754513708027924, + "loss_toxic": 0.14171750030734323 + }, + { + "loss": 3.015385866165161, + "loss_size": 0.3014552891254425, + "loss_pdi": 0.7061856659975919, + "loss_ee": 0.8024854118173773, + "loss_delivery": 0.6936024874448776, + "loss_biodist": 0.37025675448504364, + "loss_toxic": 0.1414001943035559 + }, + { + "loss": 2.8741718639026987, + "loss_size": 0.30145836960185657, + "loss_pdi": 0.6485596244985407, + "loss_ee": 0.7874462875452909, + "loss_delivery": 0.7069157646460966, + "loss_biodist": 0.31511187282475556, + "loss_toxic": 0.11467998136173595 + }, + { + "loss": 2.7043218179182573, + "loss_size": 0.30878364227034827, + "loss_pdi": 0.6446920850060203, + "loss_ee": 0.763309196992354, + "loss_delivery": 0.6044047027826309, + "loss_biodist": 0.2966968539086255, + "loss_toxic": 0.08643539994955063 + }, + { + "loss": 2.592982042919506, + "loss_size": 0.301474930210547, + "loss_pdi": 0.586654855446382, + "loss_ee": 0.7498630989681591, + "loss_delivery": 0.5886639884927056, + "loss_biodist": 0.2792905040762641, + "loss_toxic": 0.08703470907428047 + }, + { + "loss": 2.475934158671986, + "loss_size": 0.3227222941138528, + "loss_pdi": 0.5525102777914568, + "loss_ee": 0.7191228595646945, + "loss_delivery": 0.5444187271324071, + "loss_biodist": 0.254928475076502, + "loss_toxic": 0.08223151076923717 + }, + { + "loss": 2.4397390105507593, + "loss_size": 0.28487288003618066, + "loss_pdi": 0.5408996912566099, + "loss_ee": 0.7313735810193148, + "loss_delivery": 0.5359987101771615, + "loss_biodist": 0.2598937614397569, + "loss_toxic": 0.08670033531432803 + }, + { + "loss": 2.454329328103499, + "loss_size": 0.2546617639335719, + "loss_pdi": 0.549032907594334, + "loss_ee": 0.723411272872578, + "loss_delivery": 0.6058986783027649, + "loss_biodist": 0.2590403082695874, + "loss_toxic": 0.06228439848531376 + }, + { + "loss": 2.459202441302213, + "loss_size": 0.30333411490375345, + "loss_pdi": 0.5448255024173043, + "loss_ee": 0.7321870706298135, + "loss_delivery": 0.5674931427294557, + "loss_biodist": 0.2446274757385254, + "loss_toxic": 0.06673513505269181 + }, + { + "loss": 2.4072633764960547, + "loss_size": 0.3075956946069544, + "loss_pdi": 0.5372236939993772, + "loss_ee": 0.6942113583738153, + "loss_delivery": 0.541759736158631, + "loss_biodist": 0.2502128630876541, + "loss_toxic": 0.07625993527472019 + }, + { + "loss": 2.3814159631729126, + "loss_size": 
0.2884528948502107, + "loss_pdi": 0.5177338475530798, + "loss_ee": 0.7055766040628607, + "loss_delivery": 0.564139807766134, + "loss_biodist": 0.24335274642164056, + "loss_toxic": 0.062160092321309174 + }, + { + "loss": 2.288854566487399, + "loss_size": 0.2864968058737842, + "loss_pdi": 0.5151266103441065, + "loss_ee": 0.6858331777832725, + "loss_delivery": 0.492813152345744, + "loss_biodist": 0.23307398774407126, + "loss_toxic": 0.07551077617840334 + }, + { + "loss": 2.308527339588512, + "loss_size": 0.3206997500224547, + "loss_pdi": 0.48468891328031366, + "loss_ee": 0.6778702302412554, + "loss_delivery": 0.5191753757270899, + "loss_biodist": 0.2486932264132933, + "loss_toxic": 0.05739984508942474 + } + ], + "val": [] +} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_2/model.pt b/models/nested_cv/20260130_183653/outer_fold_2/model.pt new file mode 100644 index 0000000..e3af4ab --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_2/model.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6285ba67132f7fff122c780ba4a63a2f52ed2b8a8775342015c1ccf00e2351e9 +size 107113964 diff --git a/models/nested_cv/20260130_183653/outer_fold_2/optuna_study.sqlite3 b/models/nested_cv/20260130_183653/outer_fold_2/optuna_study.sqlite3 new file mode 100644 index 0000000..85c032a Binary files /dev/null and b/models/nested_cv/20260130_183653/outer_fold_2/optuna_study.sqlite3 differ diff --git a/models/nested_cv/20260130_183653/outer_fold_2/splits.json b/models/nested_cv/20260130_183653/outer_fold_2/splits.json new file mode 100644 index 0000000..0a00b87 --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_2/splits.json @@ -0,0 +1 @@ +{"outer_train_idx": [0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 16, 18, 20, 21, 22, 23, 24, 25, 27, 29, 31, 32, 35, 36, 39, 40, 41, 42, 43, 45, 46, 47, 48, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 63, 65, 66, 67, 68, 69, 70, 71, 72, 74, 75, 76, 77, 78, 80, 81, 83, 84, 85, 86, 87, 88, 89, 91, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 133, 134, 135, 137, 138, 139, 140, 141, 142, 143, 145, 146, 147, 148, 149, 150, 151, 153, 154, 155, 157, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 199, 200, 201, 202, 203, 204, 205, 206, 207, 209, 210, 212, 213, 214, 215, 216, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 234, 235, 239, 240, 241, 242, 243, 244, 245, 246, 247, 249, 250, 252, 253, 254, 255, 258, 259, 260, 262, 263, 264, 265, 267, 268, 272, 274, 275, 276, 277, 278, 280, 281, 283, 284, 285, 286, 287, 289, 290, 291, 293, 294, 296, 297, 298, 299, 300, 301, 305, 308, 309, 310, 311, 313, 314, 315, 317, 318, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 342, 343, 344, 345, 348, 349, 352, 353, 355, 356, 357, 358, 359, 360, 361, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 375, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 392, 395, 396, 397, 399, 400, 401, 403, 404, 405, 406, 407, 408, 409, 410, 412, 413, 414, 415, 416, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431], "outer_test_idx": [3, 14, 15, 17, 19, 26, 28, 30, 33, 34, 37, 38, 44, 49, 50, 62, 64, 73, 79, 82, 90, 92, 93, 105, 106, 121, 132, 136, 144, 152, 156, 158, 173, 
187, 198, 208, 211, 217, 218, 233, 236, 237, 238, 248, 251, 256, 257, 261, 266, 269, 270, 271, 273, 279, 282, 288, 292, 295, 302, 303, 304, 306, 307, 312, 316, 319, 340, 341, 346, 347, 350, 351, 354, 362, 374, 376, 390, 391, 393, 394, 398, 402, 411, 417, 418, 432, 433]} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_2/test_metrics.json b/models/nested_cv/20260130_183653/outer_fold_2/test_metrics.json new file mode 100644 index 0000000..b30b958 --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_2/test_metrics.json @@ -0,0 +1,42 @@ +{ + "size": { + "n_samples": 85, + "mse": 0.07855156229299931, + "rmse": 0.28027051627490057, + "mae": 0.2201253890991211, + "r2": 0.009407939412061861 + }, + "delivery": { + "n_samples": 62, + "mse": 0.4162270771403472, + "rmse": 0.645156629928227, + "mae": 0.41306305523856635, + "r2": 0.37384819758564136 + }, + "pdi": { + "n_samples": 87, + "accuracy": 0.7126436781609196, + "precision": 0.3963383838383838, + "recall": 0.5856060606060606, + "f1": 0.43013891331106036 + }, + "ee": { + "n_samples": 87, + "accuracy": 0.6781609195402298, + "precision": 0.6215366001209922, + "recall": 0.65781362712309, + "f1": 0.6235867752721687 + }, + "toxic": { + "n_samples": 62, + "accuracy": 0.967741935483871, + "precision": 0.8, + "recall": 0.9830508474576272, + "f1": 0.8663793103448275 + }, + "biodist": { + "n_samples": 62, + "kl_divergence": 0.3189032824921648, + "js_divergence": 0.07944133611635379 + } +} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_3/best_params.json b/models/nested_cv/20260130_183653/outer_fold_3/best_params.json new file mode 100644 index 0000000..2e72c2c --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_3/best_params.json @@ -0,0 +1,10 @@ +{ + "d_model": 512, + "num_heads": 4, + "n_attn_layers": 5, + "fusion_strategy": "attention", + "head_hidden_dim": 64, + "dropout": 0.11746271741188277, + "lr": 0.0001939220403760229, + "weight_decay": 2.4722550292920085e-06 +} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_3/epoch_mean.json b/models/nested_cv/20260130_183653/outer_fold_3/epoch_mean.json new file mode 100644 index 0000000..17fb17e --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_3/epoch_mean.json @@ -0,0 +1 @@ +{"epoch_mean": 10} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_3/history.json b/models/nested_cv/20260130_183653/outer_fold_3/history.json new file mode 100644 index 0000000..17ecbc4 --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_3/history.json @@ -0,0 +1,95 @@ +{ + "train": [ + { + "loss": 9.790851376273416, + "loss_size": 4.886490476402369, + "loss_pdi": 1.3631455573168667, + "loss_ee": 1.088216781616211, + "loss_delivery": 0.7040194340727546, + "loss_biodist": 1.1505104249173945, + "loss_toxic": 0.5984687263315375 + }, + { + "loss": 4.359956351193515, + "loss_size": 0.43619183789600025, + "loss_pdi": 1.1665386232462795, + "loss_ee": 0.9901693842627786, + "loss_delivery": 0.6073603630065918, + "loss_biodist": 0.7433811046860435, + "loss_toxic": 0.41631498661908234 + }, + { + "loss": 3.756383180618286, + "loss_size": 0.4402667825872248, + "loss_pdi": 1.0114682045849888, + "loss_ee": 0.9900745803659613, + "loss_delivery": 0.5448389879681848, + "loss_biodist": 0.5919120474295183, + "loss_toxic": 0.17782256481322375 + }, + { + "loss": 3.4846310182051226, + "loss_size": 0.46992606737396936, + "loss_pdi": 1.000607517632571, + "loss_ee": 
0.8866635181687095, + "loss_delivery": 0.5218790579925884, + "loss_biodist": 0.47943955388936127, + "loss_toxic": 0.12611534412611614 + }, + { + "loss": 3.0496469844471323, + "loss_size": 0.36307603120803833, + "loss_pdi": 0.8360648371956565, + "loss_ee": 0.8720264759930697, + "loss_delivery": 0.45130031081763183, + "loss_biodist": 0.4275714944709431, + "loss_toxic": 0.09960775822401047 + }, + { + "loss": 2.8820750496604224, + "loss_size": 0.35799529267983005, + "loss_pdi": 0.7969891158017245, + "loss_ee": 0.8362645967440172, + "loss_delivery": 0.4626781473105604, + "loss_biodist": 0.358451091430404, + "loss_toxic": 0.0696967878294262 + }, + { + "loss": 2.645532087846236, + "loss_size": 0.3313806341453032, + "loss_pdi": 0.7585093595764854, + "loss_ee": 0.797249972820282, + "loss_delivery": 0.38693068108775397, + "loss_biodist": 0.29899653656916186, + "loss_toxic": 0.07246483176607978 + }, + { + "loss": 2.5514606779271904, + "loss_size": 0.33785366063768213, + "loss_pdi": 0.710435076193376, + "loss_ee": 0.7790363105860624, + "loss_delivery": 0.3592266860333356, + "loss_biodist": 0.2742794535376809, + "loss_toxic": 0.09062948763709176 + }, + { + "loss": 2.4680945114655928, + "loss_size": 0.3564724339680238, + "loss_pdi": 0.6821299493312836, + "loss_ee": 0.7459230910647999, + "loss_delivery": 0.36073175072669983, + "loss_biodist": 0.26153207096186554, + "loss_toxic": 0.061305238187990406 + }, + { + "loss": 2.3807498541745273, + "loss_size": 0.29239929941567505, + "loss_pdi": 0.7071054875850677, + "loss_ee": 0.7487604834816672, + "loss_delivery": 0.34709482369097794, + "loss_biodist": 0.2424463846466758, + "loss_toxic": 0.04294342412190004 + } + ], + "val": [] +} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_3/model.pt b/models/nested_cv/20260130_183653/outer_fold_3/model.pt new file mode 100644 index 0000000..20ff867 --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_3/model.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:93cd6e31c539b5604fa03b14699fc985cbee10dd8cbbb05f6cebc43f65e394da +size 132340732 diff --git a/models/nested_cv/20260130_183653/outer_fold_3/optuna_study.sqlite3 b/models/nested_cv/20260130_183653/outer_fold_3/optuna_study.sqlite3 new file mode 100644 index 0000000..4152d36 Binary files /dev/null and b/models/nested_cv/20260130_183653/outer_fold_3/optuna_study.sqlite3 differ diff --git a/models/nested_cv/20260130_183653/outer_fold_3/splits.json b/models/nested_cv/20260130_183653/outer_fold_3/splits.json new file mode 100644 index 0000000..d771b6b --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_3/splits.json @@ -0,0 +1 @@ +{"outer_train_idx": [0, 1, 2, 3, 7, 8, 12, 13, 14, 15, 17, 19, 20, 21, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 41, 42, 43, 44, 46, 48, 49, 50, 51, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 71, 72, 73, 74, 76, 78, 79, 80, 82, 83, 85, 86, 88, 89, 90, 91, 92, 93, 94, 95, 97, 99, 100, 101, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 127, 128, 129, 130, 132, 133, 134, 136, 138, 140, 143, 144, 145, 146, 147, 149, 150, 151, 152, 153, 154, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 173, 174, 175, 177, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 197, 198, 199, 200, 201, 202, 203, 205, 206, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 220, 222, 225, 226, 227, 228, 229, 231, 232, 233, 
235, 236, 237, 238, 239, 240, 241, 242, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 265, 266, 267, 269, 270, 271, 272, 273, 276, 277, 278, 279, 280, 281, 282, 284, 285, 286, 288, 289, 292, 293, 294, 295, 296, 297, 299, 300, 301, 302, 303, 304, 305, 306, 307, 309, 310, 311, 312, 314, 315, 316, 318, 319, 320, 321, 324, 325, 328, 329, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 361, 362, 363, 365, 366, 367, 368, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 383, 384, 385, 386, 387, 389, 390, 391, 393, 394, 396, 397, 398, 401, 402, 405, 406, 407, 408, 409, 410, 411, 413, 414, 415, 416, 417, 418, 420, 421, 423, 424, 425, 427, 428, 430, 431, 432, 433], "outer_test_idx": [4, 5, 6, 9, 10, 11, 16, 18, 22, 23, 24, 39, 40, 45, 47, 52, 70, 75, 77, 81, 84, 87, 96, 98, 102, 114, 126, 131, 135, 137, 139, 141, 142, 148, 155, 172, 176, 178, 179, 196, 204, 207, 219, 221, 223, 224, 230, 234, 243, 244, 264, 268, 274, 275, 283, 287, 290, 291, 298, 308, 313, 317, 322, 323, 326, 327, 330, 331, 332, 349, 360, 364, 369, 370, 382, 388, 392, 395, 399, 400, 403, 404, 412, 419, 422, 426, 429]} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_3/test_metrics.json b/models/nested_cv/20260130_183653/outer_fold_3/test_metrics.json new file mode 100644 index 0000000..0086772 --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_3/test_metrics.json @@ -0,0 +1,42 @@ +{ + "size": { + "n_samples": 87, + "mse": 0.08391054707837464, + "rmse": 0.28967317286620564, + "mae": 0.22691357272794876, + "r2": 0.26719931627457305 + }, + "delivery": { + "n_samples": 62, + "mse": 2.0453160934060053, + "rmse": 1.4301454798047664, + "mae": 0.5443450972558029, + "r2": 0.0777919381248442 + }, + "pdi": { + "n_samples": 87, + "accuracy": 0.6896551724137931, + "precision": 0.3994245524296675, + "recall": 0.5978021978021978, + "f1": 0.417895167895168 + }, + "ee": { + "n_samples": 87, + "accuracy": 0.5862068965517241, + "precision": 0.5469462969462969, + "recall": 0.5874125874125874, + "f1": 0.5375569894616209 + }, + "toxic": { + "n_samples": 63, + "accuracy": 0.9682539682539683, + "precision": 0.8, + "recall": 0.9833333333333334, + "f1": 0.8665254237288135 + }, + "biodist": { + "n_samples": 63, + "kl_divergence": 0.28367776789683485, + "js_divergence": 0.07318286384043993 + } +} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_4/best_params.json b/models/nested_cv/20260130_183653/outer_fold_4/best_params.json new file mode 100644 index 0000000..bd058cf --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_4/best_params.json @@ -0,0 +1,10 @@ +{ + "d_model": 128, + "num_heads": 4, + "n_attn_layers": 6, + "fusion_strategy": "max", + "head_hidden_dim": 128, + "dropout": 0.15658744776638445, + "lr": 0.00031005155898680676, + "weight_decay": 5.422040924441196e-05 +} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_4/epoch_mean.json b/models/nested_cv/20260130_183653/outer_fold_4/epoch_mean.json new file mode 100644 index 0000000..82d35d4 --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_4/epoch_mean.json @@ -0,0 +1 @@ +{"epoch_mean": 14} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_4/history.json b/models/nested_cv/20260130_183653/outer_fold_4/history.json new file mode 100644 index 0000000..4223851 --- /dev/null +++ 
b/models/nested_cv/20260130_183653/outer_fold_4/history.json @@ -0,0 +1,131 @@ +{ + "train": [ + { + "loss": 15.068541526794434, + "loss_size": 9.661331350153143, + "loss_pdi": 1.3961028402501887, + "loss_ee": 1.1041727607900447, + "loss_delivery": 1.131193908778104, + "loss_biodist": 1.1299291415648027, + "loss_toxic": 0.6458113518628207 + }, + { + "loss": 5.692212278192693, + "loss_size": 0.9680004715919495, + "loss_pdi": 1.2202669165351174, + "loss_ee": 1.0751874880357222, + "loss_delivery": 1.075642000545155, + "loss_biodist": 0.8836204951459711, + "loss_toxic": 0.46949480880390515 + }, + { + "loss": 4.2831949970938945, + "loss_size": 0.4278720129619945, + "loss_pdi": 1.0048133405772122, + "loss_ee": 0.9688023762269453, + "loss_delivery": 0.9665126570246436, + "loss_biodist": 0.649872666055506, + "loss_toxic": 0.2653219618580558 + }, + { + "loss": 3.620699340646917, + "loss_size": 0.34346921877427533, + "loss_pdi": 0.909794731573625, + "loss_ee": 0.8586091019890525, + "loss_delivery": 0.8243018510666761, + "loss_biodist": 0.5371287275444377, + "loss_toxic": 0.14739569039507347 + }, + { + "loss": 3.2791861187327993, + "loss_size": 0.3260936520316384, + "loss_pdi": 0.8282042199915106, + "loss_ee": 0.8434794707731768, + "loss_delivery": 0.7491147694262591, + "loss_biodist": 0.4268476908857172, + "loss_toxic": 0.10544633001766422 + }, + { + "loss": 3.000173742120916, + "loss_size": 0.3082492568276145, + "loss_pdi": 0.6911327242851257, + "loss_ee": 0.7926767305894331, + "loss_delivery": 0.7263242087580941, + "loss_biodist": 0.3646425713192333, + "loss_toxic": 0.11714824627746236 + }, + { + "loss": 2.8449384082447398, + "loss_size": 0.3491906361146407, + "loss_pdi": 0.7018052475018934, + "loss_ee": 0.7752535722472451, + "loss_delivery": 0.6207486174323342, + "loss_biodist": 0.3087055358019742, + "loss_toxic": 0.08923475528982552 + }, + { + "loss": 2.762920791452581, + "loss_size": 0.27234369922767987, + "loss_pdi": 0.6848637001080946, + "loss_ee": 0.7332382093776356, + "loss_delivery": 0.7119220888072794, + "loss_biodist": 0.27469146658073773, + "loss_toxic": 0.08586158154701645 + }, + { + "loss": 2.6228579716248945, + "loss_size": 0.28220029175281525, + "loss_pdi": 0.6514047113331881, + "loss_ee": 0.7080714106559753, + "loss_delivery": 0.6849517280405218, + "loss_biodist": 0.23796464096416126, + "loss_toxic": 0.05826510641385208 + }, + { + "loss": 2.53473145311529, + "loss_size": 0.24789857864379883, + "loss_pdi": 0.6621171263131228, + "loss_ee": 0.7045791528441689, + "loss_delivery": 0.6228310723196376, + "loss_biodist": 0.24999592927369205, + "loss_toxic": 0.047309636138379574 + }, + { + "loss": 2.439385262402621, + "loss_size": 0.2714946825395931, + "loss_pdi": 0.6569395525888964, + "loss_ee": 0.679518764669245, + "loss_delivery": 0.5314246158708226, + "loss_biodist": 0.23756521533836017, + "loss_toxic": 0.062442497265609825 + }, + { + "loss": 2.4029004248705776, + "loss_size": 0.25774127651344647, + "loss_pdi": 0.5974260785362937, + "loss_ee": 0.6959952928803184, + "loss_delivery": 0.5820360441099514, + "loss_biodist": 0.22288588908585635, + "loss_toxic": 0.04681587845764377 + }, + { + "loss": 2.390383416956121, + "loss_size": 0.24576269225640732, + "loss_pdi": 0.6070916679772463, + "loss_ee": 0.6997986544262279, + "loss_delivery": 0.5448949479243972, + "loss_biodist": 0.2274868596683849, + "loss_toxic": 0.065348598936742 + }, + { + "loss": 2.450988466089422, + "loss_size": 0.21830374882979828, + "loss_pdi": 0.6097230477766558, + "loss_ee": 0.6919564171270891, + "loss_delivery": 
0.6387598026882518, + "loss_biodist": 0.22940932078794998, + "loss_toxic": 0.06283610728992657 + } + ], + "val": [] +} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_4/model.pt b/models/nested_cv/20260130_183653/outer_fold_4/model.pt new file mode 100644 index 0000000..29c46be --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_4/model.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cd138dcfe3697ec06e8caa6ec6c35a16635b1eb5a790b07e5c49eaf638dd9379 +size 11112658 diff --git a/models/nested_cv/20260130_183653/outer_fold_4/optuna_study.sqlite3 b/models/nested_cv/20260130_183653/outer_fold_4/optuna_study.sqlite3 new file mode 100644 index 0000000..5596a95 Binary files /dev/null and b/models/nested_cv/20260130_183653/outer_fold_4/optuna_study.sqlite3 differ diff --git a/models/nested_cv/20260130_183653/outer_fold_4/splits.json b/models/nested_cv/20260130_183653/outer_fold_4/splits.json new file mode 100644 index 0000000..b788f25 --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_4/splits.json @@ -0,0 +1 @@ +{"outer_train_idx": [3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 42, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 60, 62, 63, 64, 66, 67, 68, 69, 70, 72, 73, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 90, 92, 93, 94, 95, 96, 97, 98, 100, 101, 102, 104, 105, 106, 107, 109, 110, 111, 113, 114, 116, 117, 118, 120, 121, 122, 124, 125, 126, 127, 128, 130, 131, 132, 133, 135, 136, 137, 139, 141, 142, 144, 145, 146, 147, 148, 149, 150, 151, 152, 154, 155, 156, 157, 158, 160, 161, 162, 163, 164, 165, 166, 168, 169, 170, 172, 173, 174, 176, 178, 179, 180, 181, 182, 183, 184, 187, 188, 189, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 217, 218, 219, 220, 221, 223, 224, 225, 226, 228, 229, 230, 232, 233, 234, 235, 236, 237, 238, 240, 241, 242, 243, 244, 245, 246, 247, 248, 250, 251, 252, 253, 254, 256, 257, 258, 259, 261, 263, 264, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 277, 278, 279, 280, 281, 282, 283, 285, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 302, 303, 304, 305, 306, 307, 308, 311, 312, 313, 315, 316, 317, 318, 319, 320, 321, 322, 323, 325, 326, 327, 328, 330, 331, 332, 333, 334, 336, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 349, 350, 351, 354, 355, 356, 357, 358, 360, 361, 362, 364, 367, 368, 369, 370, 371, 373, 374, 375, 376, 377, 378, 379, 382, 383, 385, 386, 388, 390, 391, 392, 393, 394, 395, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 411, 412, 414, 417, 418, 419, 420, 422, 423, 424, 426, 427, 429, 430, 431, 432, 433], "outer_test_idx": [0, 1, 2, 8, 20, 21, 27, 41, 43, 56, 57, 58, 59, 61, 65, 71, 74, 89, 91, 99, 103, 108, 112, 115, 119, 123, 129, 134, 138, 140, 143, 153, 159, 167, 171, 175, 177, 185, 186, 190, 191, 192, 203, 215, 216, 222, 227, 231, 239, 249, 255, 260, 262, 265, 276, 284, 286, 301, 309, 310, 314, 324, 329, 335, 337, 348, 352, 353, 359, 363, 365, 366, 372, 380, 381, 384, 387, 389, 396, 410, 413, 415, 416, 421, 425, 428]} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/outer_fold_4/test_metrics.json b/models/nested_cv/20260130_183653/outer_fold_4/test_metrics.json new file mode 100644 index 0000000..c747ee2 --- /dev/null +++ b/models/nested_cv/20260130_183653/outer_fold_4/test_metrics.json @@ -0,0 +1,42 @@ +{ + "size": { + "n_samples": 
86, + "mse": 0.17175675509369362, + "rmse": 0.41443546553558086, + "mae": 0.2695355332174966, + "r2": -0.28030669239092476 + }, + "delivery": { + "n_samples": 63, + "mse": 0.35801504662479033, + "rmse": 0.5983435857638906, + "mae": 0.4112293718545328, + "r2": 0.3391331751101827 + }, + "pdi": { + "n_samples": 86, + "accuracy": 0.7209302325581395, + "precision": 0.5444444444444444, + "recall": 0.746031746031746, + "f1": 0.5894308943089431 + }, + "ee": { + "n_samples": 86, + "accuracy": 0.5930232558139535, + "precision": 0.5120650953984287, + "recall": 0.5383076043453402, + "f1": 0.5184148203294006 + }, + "toxic": { + "n_samples": 64, + "accuracy": 0.9375, + "precision": 0.6666666666666666, + "recall": 0.967741935483871, + "f1": 0.7333333333333333 + }, + "biodist": { + "n_samples": 63, + "kl_divergence": 0.2828013605689281, + "js_divergence": 0.07105864428768947 + } +} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/strata_info.json b/models/nested_cv/20260130_183653/strata_info.json new file mode 100644 index 0000000..283f890 --- /dev/null +++ b/models/nested_cv/20260130_183653/strata_info.json @@ -0,0 +1,60 @@ +{ + "original_strata_counts": { + "T0|P0|E0": "5", + "T0|P0|E1": "63", + "T0|P0|E2": "175", + "T0|P1|E0": "1", + "T0|P1|E1": "19", + "T0|P1|E2": "33", + "T0|P2|E2": "3", + "T1|P0|E2": "9", + "T1|P1|E2": "5", + "TNA|P0|E0": "28", + "TNA|P0|E1": "21", + "TNA|P0|E2": "20", + "TNA|P1|E0": "29", + "TNA|P1|E1": "7", + "TNA|P1|E2": "14", + "TNA|P2|E2": "1", + "TNA|P3|E0": "1" + }, + "rare_strata": [ + "T0|P1|E0", + "T0|P2|E2", + "TNA|P2|E2", + "TNA|P3|E0" + ], + "final_strata": [ + "RARE", + "T0|P0|E0", + "T0|P0|E1", + "T0|P0|E2", + "T0|P1|E1", + "T0|P1|E2", + "T1|P0|E2", + "T1|P1|E2", + "TNA|P0|E0", + "TNA|P0|E1", + "TNA|P0|E2", + "TNA|P1|E0", + "TNA|P1|E1", + "TNA|P1|E2" + ], + "final_strata_counts": { + "RARE": "6", + "T0|P0|E0": "5", + "T0|P0|E1": "63", + "T0|P0|E2": "175", + "T0|P1|E1": "19", + "T0|P1|E2": "33", + "T1|P0|E2": "9", + "T1|P1|E2": "5", + "TNA|P0|E0": "28", + "TNA|P0|E1": "21", + "TNA|P0|E2": "20", + "TNA|P1|E0": "29", + "TNA|P1|E1": "7", + "TNA|P1|E2": "14" + }, + "n_rare_merged": "6" +} \ No newline at end of file diff --git a/models/nested_cv/20260130_183653/summary.json b/models/nested_cv/20260130_183653/summary.json new file mode 100644 index 0000000..042552b --- /dev/null +++ b/models/nested_cv/20260130_183653/summary.json @@ -0,0 +1,342 @@ +{ + "fold_results": [ + { + "fold": 0, + "best_params": { + "d_model": 512, + "num_heads": 8, + "n_attn_layers": 5, + "fusion_strategy": "attention", + "head_hidden_dim": 128, + "dropout": 0.14666736316838325, + "lr": 0.0001295888795454003, + "weight_decay": 7.732380983243132e-05 + }, + "epoch_mean": 13, + "test_metrics": { + "size": { + "n_samples": 87, + "mse": 0.26366087871128757, + "rmse": 0.5134791901443403, + "mae": 0.25157783223294666, + "r2": 0.21208517006410577 + }, + "delivery": { + "n_samples": 61, + "mse": 0.40443344562739025, + "rmse": 0.63595082013265, + "mae": 0.3928790920429298, + "r2": 0.2300258531372983 + }, + "pdi": { + "n_samples": 87, + "accuracy": 0.7241379310344828, + "precision": 0.35141509433962265, + "recall": 0.35351966873706003, + "f1": 0.348405985686402 + }, + "ee": { + "n_samples": 87, + "accuracy": 0.6666666666666666, + "precision": 0.6188811188811189, + "recall": 0.6375291375291375, + "f1": 0.6217948717948718 + }, + "toxic": { + "n_samples": 62, + "accuracy": 0.967741935483871, + "precision": 0.8, + "recall": 0.9830508474576272, + "f1": 0.8663793103448275 + }, + "biodist": 
{ + "n_samples": 61, + "kl_divergence": 0.14776465556036145, + "js_divergence": 0.03926150329301917 + } + } + }, + { + "fold": 1, + "best_params": { + "d_model": 512, + "num_heads": 4, + "n_attn_layers": 2, + "fusion_strategy": "avg", + "head_hidden_dim": 64, + "dropout": 0.05188345993471756, + "lr": 4.21188892865021e-05, + "weight_decay": 4.086499445232577e-05 + }, + "epoch_mean": 23, + "test_metrics": { + "size": { + "n_samples": 87, + "mse": 0.20102046206409732, + "rmse": 0.44835305515196094, + "mae": 0.2436335881551107, + "r2": 0.1900957049146058 + }, + "delivery": { + "n_samples": 62, + "mse": 0.5899936276993041, + "rmse": 0.7681104267612203, + "mae": 0.46896366539576484, + "r2": 0.4017743543115936 + }, + "pdi": { + "n_samples": 87, + "accuracy": 0.6666666666666666, + "precision": 0.41556437389770723, + "recall": 0.4042119565217391, + "f1": 0.40777777777777774 + }, + "ee": { + "n_samples": 87, + "accuracy": 0.6666666666666666, + "precision": 0.6414141414141414, + "recall": 0.6782661782661782, + "f1": 0.6387485970819304 + }, + "toxic": { + "n_samples": 62, + "accuracy": 1.0, + "precision": 1.0, + "recall": 1.0, + "f1": 1.0 + }, + "biodist": { + "n_samples": 62, + "kl_divergence": 0.30758161166563297, + "js_divergence": 0.08759221465023677 + } + } + }, + { + "fold": 2, + "best_params": { + "d_model": 512, + "num_heads": 8, + "n_attn_layers": 4, + "fusion_strategy": "attention", + "head_hidden_dim": 64, + "dropout": 0.11433976976282646, + "lr": 5.3015812445144804e-05, + "weight_decay": 6.704431817743382e-06 + }, + "epoch_mean": 18, + "test_metrics": { + "size": { + "n_samples": 85, + "mse": 0.07855156229299931, + "rmse": 0.28027051627490057, + "mae": 0.2201253890991211, + "r2": 0.009407939412061861 + }, + "delivery": { + "n_samples": 62, + "mse": 0.4162270771403472, + "rmse": 0.645156629928227, + "mae": 0.41306305523856635, + "r2": 0.37384819758564136 + }, + "pdi": { + "n_samples": 87, + "accuracy": 0.7126436781609196, + "precision": 0.3963383838383838, + "recall": 0.5856060606060606, + "f1": 0.43013891331106036 + }, + "ee": { + "n_samples": 87, + "accuracy": 0.6781609195402298, + "precision": 0.6215366001209922, + "recall": 0.65781362712309, + "f1": 0.6235867752721687 + }, + "toxic": { + "n_samples": 62, + "accuracy": 0.967741935483871, + "precision": 0.8, + "recall": 0.9830508474576272, + "f1": 0.8663793103448275 + }, + "biodist": { + "n_samples": 62, + "kl_divergence": 0.3189032824921648, + "js_divergence": 0.07944133611635379 + } + } + }, + { + "fold": 3, + "best_params": { + "d_model": 512, + "num_heads": 4, + "n_attn_layers": 5, + "fusion_strategy": "attention", + "head_hidden_dim": 64, + "dropout": 0.11746271741188277, + "lr": 0.0001939220403760229, + "weight_decay": 2.4722550292920085e-06 + }, + "epoch_mean": 10, + "test_metrics": { + "size": { + "n_samples": 87, + "mse": 0.08391054707837464, + "rmse": 0.28967317286620564, + "mae": 0.22691357272794876, + "r2": 0.26719931627457305 + }, + "delivery": { + "n_samples": 62, + "mse": 2.0453160934060053, + "rmse": 1.4301454798047664, + "mae": 0.5443450972558029, + "r2": 0.0777919381248442 + }, + "pdi": { + "n_samples": 87, + "accuracy": 0.6896551724137931, + "precision": 0.3994245524296675, + "recall": 0.5978021978021978, + "f1": 0.417895167895168 + }, + "ee": { + "n_samples": 87, + "accuracy": 0.5862068965517241, + "precision": 0.5469462969462969, + "recall": 0.5874125874125874, + "f1": 0.5375569894616209 + }, + "toxic": { + "n_samples": 63, + "accuracy": 0.9682539682539683, + "precision": 0.8, + "recall": 0.9833333333333334, + "f1": 
0.8665254237288135 + }, + "biodist": { + "n_samples": 63, + "kl_divergence": 0.28367776789683485, + "js_divergence": 0.07318286384043993 + } + } + }, + { + "fold": 4, + "best_params": { + "d_model": 128, + "num_heads": 4, + "n_attn_layers": 6, + "fusion_strategy": "max", + "head_hidden_dim": 128, + "dropout": 0.15658744776638445, + "lr": 0.00031005155898680676, + "weight_decay": 5.422040924441196e-05 + }, + "epoch_mean": 14, + "test_metrics": { + "size": { + "n_samples": 86, + "mse": 0.17175675509369362, + "rmse": 0.41443546553558086, + "mae": 0.2695355332174966, + "r2": -0.28030669239092476 + }, + "delivery": { + "n_samples": 63, + "mse": 0.35801504662479033, + "rmse": 0.5983435857638906, + "mae": 0.4112293718545328, + "r2": 0.3391331751101827 + }, + "pdi": { + "n_samples": 86, + "accuracy": 0.7209302325581395, + "precision": 0.5444444444444444, + "recall": 0.746031746031746, + "f1": 0.5894308943089431 + }, + "ee": { + "n_samples": 86, + "accuracy": 0.5930232558139535, + "precision": 0.5120650953984287, + "recall": 0.5383076043453402, + "f1": 0.5184148203294006 + }, + "toxic": { + "n_samples": 64, + "accuracy": 0.9375, + "precision": 0.6666666666666666, + "recall": 0.967741935483871, + "f1": 0.7333333333333333 + }, + "biodist": { + "n_samples": 63, + "kl_divergence": 0.2828013605689281, + "js_divergence": 0.07105864428768947 + } + } + } + ], + "summary_stats": { + "size": { + "mse_mean": 0.1597800410480905, + "mse_std": 0.07069609368925868, + "rmse_mean": 0.3892422799945977, + "rmse_std": 0.09094222623565874, + "mae_mean": 0.24235718308652476, + "mae_std": 0.017652592224532318, + "r2_mean": 0.07969628765488435, + "r2_std": 0.19970720107145462 + }, + "delivery": { + "mse_mean": 0.7627970580995674, + "mse_std": 0.6460804607386303, + "rmse_mean": 0.8155413884781508, + "rmse_std": 0.31255287837212004, + "mae_mean": 0.44609605635751937, + "mae_std": 0.055343855526290356, + "r2_mean": 0.28451470365391207, + "r2_std": 0.11867334397685028 + }, + "pdi": { + "accuracy_mean": 0.7028067361668003, + "accuracy_std": 0.021722406295571848, + "precision_mean": 0.4214373697899651, + "precision_std": 0.06508897722192637, + "recall_mean": 0.5374343259397607, + "recall_std": 0.1421622175476529, + "f1_mean": 0.4387297477958702, + "f1_std": 0.0804178141542625 + }, + "ee": { + "accuracy_mean": 0.6381448810478482, + "accuracy_std": 0.039904343482039924, + "precision_mean": 0.5881686505521957, + "precision_std": 0.049765031116764336, + "recall_mean": 0.6198658269352666, + "recall_std": 0.05072984435881663, + "f1_mean": 0.5880204107879985, + "f1_std": 0.049740374945613494 + }, + "toxic": { + "accuracy_mean": 0.9682475678443421, + "accuracy_std": 0.01976937654694873, + "precision_mean": 0.8133333333333335, + "precision_std": 0.10666666666666666, + "recall_mean": 0.9834353927464917, + "recall_std": 0.010207614614579376, + "f1_mean": 0.8665234755503602, + "f1_std": 0.08432750219703594 + }, + "biodist": { + "kl_divergence_mean": 0.26814573563678445, + "kl_divergence_std": 0.06177240919631341, + "js_divergence_mean": 0.07010731243754784, + "js_divergence_std": 0.016460095953094674 + } + } +} \ No newline at end of file diff --git a/models/pretrain_cv/fold_0/model.pt b/models/pretrain_cv/fold_0/model.pt index a829bad..8b084af 100644 Binary files a/models/pretrain_cv/fold_0/model.pt and b/models/pretrain_cv/fold_0/model.pt differ diff --git a/models/pretrain_cv/fold_1/model.pt b/models/pretrain_cv/fold_1/model.pt index 3b5c236..622f9dd 100644 Binary files a/models/pretrain_cv/fold_1/model.pt and 
b/models/pretrain_cv/fold_1/model.pt differ diff --git a/models/pretrain_cv/fold_2/model.pt b/models/pretrain_cv/fold_2/model.pt index 742dbf4..8e5bd94 100644 Binary files a/models/pretrain_cv/fold_2/model.pt and b/models/pretrain_cv/fold_2/model.pt differ diff --git a/models/pretrain_cv/fold_3/model.pt b/models/pretrain_cv/fold_3/model.pt index 57ecbfb..d7fc87f 100644 Binary files a/models/pretrain_cv/fold_3/model.pt and b/models/pretrain_cv/fold_3/model.pt differ diff --git a/models/pretrain_cv/fold_4/model.pt b/models/pretrain_cv/fold_4/model.pt index 5953298..f56f568 100644 Binary files a/models/pretrain_cv/fold_4/model.pt and b/models/pretrain_cv/fold_4/model.pt differ diff --git a/models/pretrain_delivery.pt b/models/pretrain_delivery.pt index c4c0a04..b63a703 100644 Binary files a/models/pretrain_delivery.pt and b/models/pretrain_delivery.pt differ diff --git a/pixi.lock b/pixi.lock index e3f6800..8280a28 100644 --- a/pixi.lock +++ b/pixi.lock @@ -53,6 +53,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/xz-tools-5.8.1-hb9d3cd8_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/64/88/c7083fc61120ab661c5d0b82cb77079fc1429d3f913a456c1c82cf4658f7/alabaster-0.7.13-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/54/7e/ac0991d1745f7d755fc1cd381b3990a45b404b4d008fc75e2a983516fbfe/alembic-1.14.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/9b/52/4a86a4fa1cc2aae79137cc9510b7080c3e5aede2310d14fae5486feec7f7/altair-5.4.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl @@ -65,6 +66,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/bf/fa/cf5bb2409a385f78750e78c8d2e24780964976acdaaed65dbd6083ae5b40/charset_normalizer-3.4.4-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/a0/ad/72361203906e2dbe9baa776c64e9246d555b516808dd0cce385f07f4cf71/chemprop-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/6d/c1/e419ef3723a074172b68aaa89c9f3de486ed4c2399e2dbd8113a4fdcaf9e/colorlog-6.10.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8e/71/7f20855592cc929bc206810432b991ec4c702dc26b0567b132e52c85536f/contourpy-1.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/15/1a/c6eae628480aa1fc5f6f85437c7d8ec0d1172597acd1c61182202a902c0f/cramjam-2.11.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl @@ -82,6 +84,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/da/71/ae30dadffc90b9006d77af76b393cb9dfbfc9629f339fc1574a1c52e6806/future-1.0.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/6a/09/e21df6aef1e1ffc0c816f0522ddc3f6dcded766c3261813131c78a704470/gitpython-3.1.46-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d8/88/0ce16c0afb2d71d85562a7bcd9b092fec80a7767ab5b5f7e1bbbca8200f8/greenlet-3.1.1-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl @@ -96,6 +99,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/69/4a/4f9dbeb84e8850557c02365a0eee0649abe5eb1d84af92a25731c6c0f922/jsonschema-4.23.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ee/07/44bd408781594c4d0a027666ef27fab1e441b109dc3b76b4f836f8fd04fe/jsonschema_specifications-2023.12.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/76/36/ae40d7a3171e06f55ac77fe5536079e7be1d8be2a8210e08975c7f9b4d54/kiwisolver-1.4.7-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c7/bd/50319665ce81bb10e90d1cf76f9e1aa269ea6f7fa30ab4521f14d122a3df/MarkupSafe-2.1.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/30/33/cc27211d2ffeee4fd7402dca137b6e8a83f6dcae3d4be8d0ad5068555561/matplotlib-3.7.5-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl - pypi: https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl @@ -116,6 +120,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/46/0c/c75bbfb967457a0b7670b8ad267bfc4fffdf341c074e0a80db06c24ccfd4/nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl - pypi: https://files.pythonhosted.org/packages/da/d3/8057f0587683ed2fcd4dbfbdfdfa807b9160b809976099d36b8f60d08f03/nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl - pypi: https://files.pythonhosted.org/packages/c0/da/977ded879c29cbd04de313843e76868e6e13408a94ed6b987245dc7c8506/openpyxl-3.1.5-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/7f/12/cba81286cbaf0f0c3f0473846cfd992cb240bdcea816bf2ef7de8ed0f744/optuna-4.5.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f8/7f/5b047effafbdd34e52c9e2d7e44f729a0655efafb22198c45cf692cdc157/pandas-2.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/5d/e6/71ed4d95676098159b533c4a4c424cf453fec9614edaff1a0633fe228eef/pandas_flavor-0.7.0-py3-none-any.whl @@ -131,6 +136,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/25/a2/b725b61ac76a75583ae7104b3209f75ea44b13cfd026aa535ece22b7f22e/PyYAML-6.0.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/3d/84/63b2e66f5c7cb97ce994769afbbef85a1ac364fedbcb7d4a3c0f15d318a5/rdkit-2024.3.5-cp38-cp38-manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b7/59/2056f61236782a2c86b33906c025d4f4a0b17be0161b63b70fd9e8775d36/referencing-0.35.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl @@ -150,6 +156,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/c2/42/4c8646762ee83602e3fb3fbe774c2fac12f317deb0b5dbeeedd2d3ba4b77/sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2b/14/05f9206cf4e9cfca1afb5fd224c7cd434dcc3a433d6d9e4e0264d29c6cdb/sphinxcontrib_qthelp-1.0.3-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c6/77/5464ec50dd0f1c1037e3c93249b040c8fc8078fdda97530eeb02424b6eea/sphinxcontrib_serializinghtml-1.1.5-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/8a/7b/b9e0eb7f9a15f2e82856603c728edf14c54a07c6738ab228e4f2de049338/sqlalchemy-2.0.46-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b6/c5/7ae467eeddb57260c8ce17a3a09f9f5edba35820fc022d7c55b7decd5d3a/starlette-0.44.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/9a/14/857d0734989f3d26f2f965b2e3f67568ea7a6e8a60cb9c1ed7f774b6d606/streamlit-1.40.1-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/99/ff/c87e0622b1dadea79d2fb0b25ade9ed98954c9033722eb707053d310d4f3/sympy-1.13.3-py3-none-any.whl @@ -206,6 +213,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/xz-gpl-tools-5.8.1-h9a6d368_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/xz-tools-5.8.1-h39f12f2_2.conda - pypi: https://files.pythonhosted.org/packages/64/88/c7083fc61120ab661c5d0b82cb77079fc1429d3f913a456c1c82cf4658f7/alabaster-0.7.13-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/54/7e/ac0991d1745f7d755fc1cd381b3990a45b404b4d008fc75e2a983516fbfe/alembic-1.14.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/9b/52/4a86a4fa1cc2aae79137cc9510b7080c3e5aede2310d14fae5486feec7f7/altair-5.4.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl @@ -218,6 +226,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/0a/4e/3926a1c11f0433791985727965263f788af00db3482d89a7545ca5ecc921/charset_normalizer-3.4.4-cp38-cp38-macosx_10_9_universal2.whl - pypi: https://files.pythonhosted.org/packages/a0/ad/72361203906e2dbe9baa776c64e9246d555b516808dd0cce385f07f4cf71/chemprop-1.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl + - pypi: 
https://files.pythonhosted.org/packages/6d/c1/e419ef3723a074172b68aaa89c9f3de486ed4c2399e2dbd8113a4fdcaf9e/colorlog-6.10.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a6/82/29f5ff4ae074c3230e266bc9efef449ebde43721a727b989dd8ef8f97d73/contourpy-1.1.1-cp38-cp38-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/65/64/e34ee535519fd14cde3a7f3f8cd3b4ef54483b9df655e4180437eb884aab/cramjam-2.11.0-cp38-cp38-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl - pypi: https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl @@ -249,6 +258,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/69/4a/4f9dbeb84e8850557c02365a0eee0649abe5eb1d84af92a25731c6c0f922/jsonschema-4.23.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ee/07/44bd408781594c4d0a027666ef27fab1e441b109dc3b76b4f836f8fd04fe/jsonschema_specifications-2023.12.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/14/a7/bb8ab10e12cc8764f4da0245d72dee4731cc720bdec0f085d5e9c6005b98/kiwisolver-1.4.7-cp38-cp38-macosx_11_0_arm64.whl + - pypi: https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f8/ff/2c942a82c35a49df5de3a630ce0a8456ac2969691b230e530ac12314364c/MarkupSafe-2.1.5-cp38-cp38-macosx_10_9_universal2.whl - pypi: https://files.pythonhosted.org/packages/aa/59/4d13e5b6298b1ca5525eea8c68d3806ae93ab6d0bb17ca9846aa3156b92b/matplotlib-3.7.5-cp38-cp38-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl @@ -257,6 +267,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/a8/05/9d4f9b78ead6b2661d6e8ea772e111fc4a9fbd866ad0c81906c11206b55e/networkx-3.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a7/ae/f53b7b265fdc701e663fbb322a8e9d4b14d9cb7b2385f45ddfabfc4327e4/numpy-1.24.4-cp38-cp38-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/c0/da/977ded879c29cbd04de313843e76868e6e13408a94ed6b987245dc7c8506/openpyxl-3.1.5-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/7f/12/cba81286cbaf0f0c3f0473846cfd992cb240bdcea816bf2ef7de8ed0f744/optuna-4.5.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/53/c3/f8e87361f7fdf42012def602bfa2a593423c729f5cb7c97aed7f51be66ac/pandas-2.0.3-cp38-cp38-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/5d/e6/71ed4d95676098159b533c4a4c424cf453fec9614edaff1a0633fe228eef/pandas_flavor-0.7.0-py3-none-any.whl @@ -272,6 +283,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/05/8e/961c0007c59b8dd7729d542c61a4d537767a59645b82a0b521206e1e25c2/pyyaml-6.0.3.tar.gz - pypi: 
https://files.pythonhosted.org/packages/bf/cb/c709b60f4815e18c00e1e8639204bdba04cb158e6278791d82f94f51a988/rdkit-2024.3.5-cp38-cp38-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/b7/59/2056f61236782a2c86b33906c025d4f4a0b17be0161b63b70fd9e8775d36/referencing-0.35.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl @@ -291,6 +303,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/c2/42/4c8646762ee83602e3fb3fbe774c2fac12f317deb0b5dbeeedd2d3ba4b77/sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2b/14/05f9206cf4e9cfca1afb5fd224c7cd434dcc3a433d6d9e4e0264d29c6cdb/sphinxcontrib_qthelp-1.0.3-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c6/77/5464ec50dd0f1c1037e3c93249b040c8fc8078fdda97530eeb02424b6eea/sphinxcontrib_serializinghtml-1.1.5-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/07/d4/76b9618d7eb1e6a3c26734e1186f8ad7869e4426b1ea7dc425bc4c832e67/sqlalchemy-2.0.46-cp38-cp38-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/b6/c5/7ae467eeddb57260c8ce17a3a09f9f5edba35820fc022d7c55b7decd5d3a/starlette-0.44.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/9a/14/857d0734989f3d26f2f965b2e3f67568ea7a6e8a60cb9c1ed7f774b6d606/streamlit-1.40.1-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/99/ff/c87e0622b1dadea79d2fb0b25ade9ed98954c9033722eb707053d310d4f3/sympy-1.13.3-py3-none-any.whl @@ -336,6 +349,19 @@ packages: version: 0.7.13 sha256: 1ee19aca801bbabb5ba3f5f258e4422dfa86f82f3e9cefb0859b283cdd7f62a3 requires_python: '>=3.6' +- pypi: https://files.pythonhosted.org/packages/54/7e/ac0991d1745f7d755fc1cd381b3990a45b404b4d008fc75e2a983516fbfe/alembic-1.14.1-py3-none-any.whl + name: alembic + version: 1.14.1 + sha256: 1acdd7a3a478e208b0503cd73614d5e4c6efafa4e73518bb60e4f2846a37b1c5 + requires_dist: + - sqlalchemy>=1.3.0 + - mako + - importlib-metadata ; python_full_version < '3.9' + - importlib-resources ; python_full_version < '3.9' + - typing-extensions>=4 + - backports-zoneinfo ; python_full_version < '3.9' and extra == 'tz' + - tzdata ; extra == 'tz' + requires_python: '>=3.8' - pypi: https://files.pythonhosted.org/packages/9b/52/4a86a4fa1cc2aae79137cc9510b7080c3e5aede2310d14fae5486feec7f7/altair-5.4.1-py3-none-any.whl name: altair version: 5.4.1 @@ -589,6 +615,18 @@ packages: - pkg:pypi/colorama?source=hash-mapping size: 25170 timestamp: 1666700778190 +- pypi: https://files.pythonhosted.org/packages/6d/c1/e419ef3723a074172b68aaa89c9f3de486ed4c2399e2dbd8113a4fdcaf9e/colorlog-6.10.1-py3-none-any.whl + name: colorlog + version: 6.10.1 + sha256: 2d7e8348291948af66122cff006c9f8da6255d224e7cf8e37d8de2df3bad8c9c + requires_dist: + - colorama ; sys_platform == 'win32' + - black ; extra == 'development' + - flake8 ; extra == 'development' + - mypy ; extra == 'development' + - pytest ; extra == 'development' + - types-colorama ; extra == 'development' + requires_python: '>=3.6' - pypi: https://files.pythonhosted.org/packages/8e/71/7f20855592cc929bc206810432b991ec4c702dc26b0567b132e52c85536f/contourpy-1.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl name: contourpy version: 1.1.1 @@ -1013,6 +1051,16 @@ packages: - sphinx-rtd-theme ; extra == 'doc' - sphinx-autodoc-typehints ; extra == 'doc' requires_python: '>=3.7' +- pypi: 
https://files.pythonhosted.org/packages/d8/88/0ce16c0afb2d71d85562a7bcd9b092fec80a7767ab5b5f7e1bbbca8200f8/greenlet-3.1.1-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl + name: greenlet + version: 3.1.1 + sha256: 85f3ff71e2e60bd4b4932a043fbbe0f499e263c628390b285cb599154a3b03b1 + requires_dist: + - sphinx ; extra == 'docs' + - furo ; extra == 'docs' + - objgraph ; extra == 'test' + - psutil ; extra == 'test' + requires_python: '>=3.7' - pypi: https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl name: h11 version: 0.16.0 @@ -1453,6 +1501,16 @@ packages: - pkg:pypi/loguru?source=hash-mapping size: 97617 timestamp: 1695547715271 +- pypi: https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl + name: mako + version: 1.3.10 + sha256: baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59 + requires_dist: + - markupsafe>=0.9.2 + - pytest ; extra == 'testing' + - babel ; extra == 'babel' + - lingua ; extra == 'lingua' + requires_python: '>=3.8' - conda: https://conda.anaconda.org/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_0.conda sha256: c041b0eaf7a6af3344d5dd452815cdc148d6284fec25a4fa3f4263b3a021e962 md5: 93a8e71256479c62074356ef6ebf501b @@ -1709,6 +1767,73 @@ packages: purls: [] size: 3108371 timestamp: 1762839712322 +- pypi: https://files.pythonhosted.org/packages/7f/12/cba81286cbaf0f0c3f0473846cfd992cb240bdcea816bf2ef7de8ed0f744/optuna-4.5.0-py3-none-any.whl + name: optuna + version: 4.5.0 + sha256: 5b8a783e84e448b0742501bc27195344a28d2c77bd2feef5b558544d954851b0 + requires_dist: + - alembic>=1.5.0 + - colorlog + - numpy + - packaging>=20.0 + - sqlalchemy>=1.4.2 + - tqdm + - pyyaml + - asv>=0.5.0 ; extra == 'benchmark' + - cma ; extra == 'benchmark' + - virtualenv ; extra == 'benchmark' + - black ; extra == 'checking' + - blackdoc ; extra == 'checking' + - flake8 ; extra == 'checking' + - isort ; extra == 'checking' + - mypy ; extra == 'checking' + - mypy-boto3-s3 ; extra == 'checking' + - scipy-stubs ; python_full_version >= '3.10' and extra == 'checking' + - types-pyyaml ; extra == 'checking' + - types-redis ; extra == 'checking' + - types-setuptools ; extra == 'checking' + - types-tqdm ; extra == 'checking' + - typing-extensions>=3.10.0.0 ; extra == 'checking' + - ase ; extra == 'document' + - cmaes>=0.12.0 ; extra == 'document' + - fvcore ; extra == 'document' + - kaleido<0.4 ; extra == 'document' + - lightgbm ; extra == 'document' + - matplotlib!=3.6.0 ; extra == 'document' + - pandas ; extra == 'document' + - pillow ; extra == 'document' + - plotly>=4.9.0 ; extra == 'document' + - scikit-learn ; extra == 'document' + - sphinx ; extra == 'document' + - sphinx-copybutton ; extra == 'document' + - sphinx-gallery ; extra == 'document' + - sphinx-notfound-page ; extra == 'document' + - sphinx-rtd-theme>=1.2.0 ; extra == 'document' + - torch ; extra == 'document' + - torchvision ; extra == 'document' + - boto3 ; extra == 'optional' + - cmaes>=0.12.0 ; extra == 'optional' + - google-cloud-storage ; extra == 'optional' + - matplotlib!=3.6.0 ; extra == 'optional' + - pandas ; extra == 'optional' + - plotly>=4.9.0 ; extra == 'optional' + - redis ; extra == 'optional' + - scikit-learn>=0.24.2 ; extra == 'optional' + - scipy ; extra == 'optional' + - torch ; extra == 'optional' + - grpcio ; extra == 'optional' + - protobuf>=5.28.1 ; extra == 'optional' + - coverage ; extra == 'test' + - fakeredis[lua] ; extra 
== 'test' + - kaleido<0.4 ; extra == 'test' + - moto ; extra == 'test' + - pytest ; extra == 'test' + - pytest-xdist ; extra == 'test' + - scipy>=1.9.2 ; extra == 'test' + - torch ; extra == 'test' + - grpcio ; extra == 'test' + - protobuf>=5.28.1 ; extra == 'test' + requires_python: '>=3.8' - pypi: https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl name: packaging version: '24.2' @@ -2136,6 +2261,16 @@ packages: name: pytz version: '2025.2' sha256: 5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00 +- pypi: https://files.pythonhosted.org/packages/05/8e/961c0007c59b8dd7729d542c61a4d537767a59645b82a0b521206e1e25c2/pyyaml-6.0.3.tar.gz + name: pyyaml + version: 6.0.3 + sha256: d76623373421df22fb4cf8817020cbb7ef15c725b9d5e45f17e189bfc384190f + requires_python: '>=3.8' +- pypi: https://files.pythonhosted.org/packages/25/a2/b725b61ac76a75583ae7104b3209f75ea44b13cfd026aa535ece22b7f22e/PyYAML-6.0.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl + name: pyyaml + version: 6.0.3 + sha256: 22ba7cfcad58ef3ecddc7ed1db3409af68d023b7f940da23c6c2a1890976eda6 + requires_python: '>=3.8' - pypi: https://files.pythonhosted.org/packages/3d/84/63b2e66f5c7cb97ce994769afbbef85a1ac364fedbcb7d4a3c0f15d318a5/rdkit-2024.3.5-cp38-cp38-manylinux_2_28_x86_64.whl name: rdkit version: 2024.3.5 @@ -2554,6 +2689,82 @@ packages: - docutils-stubs ; extra == 'lint' - pytest ; extra == 'test' requires_python: '>=3.5' +- pypi: https://files.pythonhosted.org/packages/07/d4/76b9618d7eb1e6a3c26734e1186f8ad7869e4426b1ea7dc425bc4c832e67/sqlalchemy-2.0.46-cp38-cp38-macosx_11_0_arm64.whl + name: sqlalchemy + version: 2.0.46 + sha256: 6ac245604295b521de49b465bab845e3afe6916bcb2147e5929c8041b4ec0545 + requires_dist: + - typing-extensions>=4.6.0 + - greenlet>=1 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64' + - importlib-metadata ; python_full_version < '3.8' + - greenlet>=1 ; extra == 'aiomysql' + - aiomysql>=0.2.0 ; extra == 'aiomysql' + - greenlet>=1 ; extra == 'aioodbc' + - aioodbc ; extra == 'aioodbc' + - greenlet>=1 ; extra == 'aiosqlite' + - aiosqlite ; extra == 'aiosqlite' + - typing-extensions!=3.10.0.1 ; extra == 'aiosqlite' + - greenlet>=1 ; extra == 'asyncio' + - greenlet>=1 ; extra == 'asyncmy' + - asyncmy>=0.2.3,!=0.2.4,!=0.2.6 ; extra == 'asyncmy' + - mariadb>=1.0.1,!=1.1.2,!=1.1.5,!=1.1.10 ; extra == 'mariadb-connector' + - pyodbc ; extra == 'mssql' + - pymssql ; extra == 'mssql-pymssql' + - pyodbc ; extra == 'mssql-pyodbc' + - mypy>=0.910 ; extra == 'mypy' + - mysqlclient>=1.4.0 ; extra == 'mysql' + - mysql-connector-python ; extra == 'mysql-connector' + - cx-oracle>=8 ; extra == 'oracle' + - oracledb>=1.0.1 ; extra == 'oracle-oracledb' + - psycopg2>=2.7 ; extra == 'postgresql' + - greenlet>=1 ; extra == 'postgresql-asyncpg' + - asyncpg ; extra == 'postgresql-asyncpg' + - pg8000>=1.29.1 ; extra == 'postgresql-pg8000' + - psycopg>=3.0.7 ; extra == 'postgresql-psycopg' + - psycopg2-binary ; extra == 'postgresql-psycopg2binary' + - psycopg2cffi ; extra == 'postgresql-psycopg2cffi' + - psycopg[binary]>=3.0.7 ; extra == 'postgresql-psycopgbinary' + - pymysql ; extra == 'pymysql' + - sqlcipher3-binary ; extra == 'sqlcipher' + requires_python: '>=3.7' +- pypi: 
https://files.pythonhosted.org/packages/8a/7b/b9e0eb7f9a15f2e82856603c728edf14c54a07c6738ab228e4f2de049338/sqlalchemy-2.0.46-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl + name: sqlalchemy + version: 2.0.46 + sha256: 716be5bcabf327b6d5d265dbdc6213a01199be587224eb991ad0d37e83d728fd + requires_dist: + - typing-extensions>=4.6.0 + - greenlet>=1 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64' + - importlib-metadata ; python_full_version < '3.8' + - greenlet>=1 ; extra == 'aiomysql' + - aiomysql>=0.2.0 ; extra == 'aiomysql' + - greenlet>=1 ; extra == 'aioodbc' + - aioodbc ; extra == 'aioodbc' + - greenlet>=1 ; extra == 'aiosqlite' + - aiosqlite ; extra == 'aiosqlite' + - typing-extensions!=3.10.0.1 ; extra == 'aiosqlite' + - greenlet>=1 ; extra == 'asyncio' + - greenlet>=1 ; extra == 'asyncmy' + - asyncmy>=0.2.3,!=0.2.4,!=0.2.6 ; extra == 'asyncmy' + - mariadb>=1.0.1,!=1.1.2,!=1.1.5,!=1.1.10 ; extra == 'mariadb-connector' + - pyodbc ; extra == 'mssql' + - pymssql ; extra == 'mssql-pymssql' + - pyodbc ; extra == 'mssql-pyodbc' + - mypy>=0.910 ; extra == 'mypy' + - mysqlclient>=1.4.0 ; extra == 'mysql' + - mysql-connector-python ; extra == 'mysql-connector' + - cx-oracle>=8 ; extra == 'oracle' + - oracledb>=1.0.1 ; extra == 'oracle-oracledb' + - psycopg2>=2.7 ; extra == 'postgresql' + - greenlet>=1 ; extra == 'postgresql-asyncpg' + - asyncpg ; extra == 'postgresql-asyncpg' + - pg8000>=1.29.1 ; extra == 'postgresql-pg8000' + - psycopg>=3.0.7 ; extra == 'postgresql-psycopg' + - psycopg2-binary ; extra == 'postgresql-psycopg2binary' + - psycopg2cffi ; extra == 'postgresql-psycopg2cffi' + - psycopg[binary]>=3.0.7 ; extra == 'postgresql-psycopgbinary' + - pymysql ; extra == 'pymysql' + - sqlcipher3-binary ; extra == 'sqlcipher' + requires_python: '>=3.7' - pypi: https://files.pythonhosted.org/packages/b6/c5/7ae467eeddb57260c8ce17a3a09f9f5edba35820fc022d7c55b7decd5d3a/starlette-0.44.0-py3-none-any.whl name: starlette version: 0.44.0 diff --git a/pixi.toml b/pixi.toml index 6a0a0a5..e42a3cf 100644 --- a/pixi.toml +++ b/pixi.toml @@ -28,3 +28,4 @@ fastapi = ">=0.124.4, <0.125" streamlit = ">=1.40.1, <2" httpx = ">=0.28.1, <0.29" uvicorn = ">=0.33.0, <0.34" +optuna = ">=4.5.0, <5" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..130ef26 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,21 @@ +# Strictly follows the dependency versions in pixi.toml +# Note: the local lnp_ml package is installed separately in the Dockerfile + +# conda dependencies (in pixi.toml [dependencies]) +loguru +tqdm +typer + +# pypi dependencies (in pixi.toml [pypi-dependencies]) +chemprop==1.7.0 +setuptools +pandas>=2.0.3,<3 +openpyxl>=3.1.5,<4 +python-dotenv>=1.0.1,<2 +pyarrow>=17.0.0,<18 +fastparquet>=2024.2.0,<2025 +fastapi>=0.124.4,<0.125 +streamlit>=1.40.1,<2 +httpx>=0.28.1,<0.29 +uvicorn>=0.33.0,<0.34 +optuna>=4.5.0,<5 \ No newline at end of file