mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 01:27:00 +08:00
Update models and UI
This commit is contained in:
parent
3f33f9d233
commit
a9392aa780
63
Dockerfile
Normal file
63
Dockerfile
Normal file
@ -0,0 +1,63 @@
|
||||
# LNP-ML Docker Image
|
||||
# 多阶段构建,支持 API 和 Streamlit 两种服务
|
||||
|
||||
FROM python:3.8-slim AS base
|
||||
|
||||
# 设置环境变量
|
||||
ENV PYTHONUNBUFFERED=1 \
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
PIP_NO_CACHE_DIR=1 \
|
||||
PIP_DISABLE_PIP_VERSION_CHECK=1
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 安装系统依赖
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
libxrender1 \
|
||||
libxext6 \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 复制依赖文件
|
||||
COPY requirements.txt .
|
||||
|
||||
# 安装 Python 依赖
|
||||
RUN pip install --upgrade pip && \
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 复制项目代码
|
||||
COPY pyproject.toml .
|
||||
COPY README.md .
|
||||
COPY LICENSE .
|
||||
COPY lnp_ml/ ./lnp_ml/
|
||||
COPY app/ ./app/
|
||||
|
||||
# 安装项目包
|
||||
RUN pip install -e .
|
||||
|
||||
# 复制模型文件
|
||||
COPY models/final/ ./models/final/
|
||||
|
||||
# ============ API 服务 ============
|
||||
FROM base AS api
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
ENV MODEL_PATH=/app/models/final/model.pt
|
||||
|
||||
CMD ["uvicorn", "app.api:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
|
||||
# ============ Streamlit 服务 ============
|
||||
FROM base AS streamlit
|
||||
|
||||
EXPOSE 8501
|
||||
|
||||
# Streamlit 配置
|
||||
ENV STREAMLIT_SERVER_PORT=8501 \
|
||||
STREAMLIT_SERVER_ADDRESS=0.0.0.0 \
|
||||
STREAMLIT_SERVER_HEADLESS=true \
|
||||
STREAMLIT_BROWSER_GATHER_USAGE_STATS=false
|
||||
|
||||
CMD ["streamlit", "run", "app/app.py"]
|
||||
|
||||
80
Makefile
80
Makefile
@ -164,6 +164,44 @@ test_cv: requirements
|
||||
tune: requirements
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --tune $(MPNN_FLAG) $(DEVICE_FLAG)
|
||||
|
||||
# ============ 嵌套 CV + Optuna 调参(StratifiedKFold + 类权重) ============
|
||||
# 通用参数:
|
||||
# SEED: 随机种子 (默认: 42)
|
||||
# N_TRIALS: Optuna 试验数 (默认: 20)
|
||||
# EPOCHS_PER_TRIAL: 每个试验的最大 epoch (默认: 30)
|
||||
# MIN_STRATUM_COUNT: 复合分层标签的最小样本数 (默认: 5)
|
||||
# OUTPUT_DIR: 输出目录 (根据命令有不同默认值)
|
||||
# INIT_PRETRAIN: 预训练权重路径 (默认: models/pretrain_delivery.pt)
|
||||
|
||||
SEED_FLAG = $(if $(SEED),--seed $(SEED),)
|
||||
N_TRIALS_FLAG = $(if $(N_TRIALS),--n-trials $(N_TRIALS),)
|
||||
EPOCHS_PER_TRIAL_FLAG = $(if $(EPOCHS_PER_TRIAL),--epochs-per-trial $(EPOCHS_PER_TRIAL),)
|
||||
MIN_STRATUM_FLAG = $(if $(MIN_STRATUM_COUNT),--min-stratum-count $(MIN_STRATUM_COUNT),)
|
||||
OUTPUT_DIR_FLAG = $(if $(OUTPUT_DIR),--output-dir $(OUTPUT_DIR),)
|
||||
USE_SWA_FLAG = $(if $(USE_SWA),--use-swa,)
|
||||
# 默认使用预训练权重,设置 NO_PRETRAIN=1 可禁用
|
||||
INIT_PRETRAIN_FLAG = $(if $(NO_PRETRAIN),,--init-from-pretrain $(or $(INIT_PRETRAIN),models/pretrain_delivery.pt))
|
||||
|
||||
## Nested CV with Optuna: outer 5-fold (test) + inner 3-fold (tune)
|
||||
## 用于模型评估:外层 5-fold 产生无偏性能估计,内层 3-fold 做超参搜索
|
||||
## 默认加载 models/pretrain_delivery.pt 预训练权重,使用 NO_PRETRAIN=1 禁用
|
||||
## 使用示例: make nested_cv_tune DEVICE=cuda N_TRIALS=30
|
||||
.PHONY: nested_cv_tune
|
||||
nested_cv_tune: requirements
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.nested_cv_optuna \
|
||||
$(DEVICE_FLAG) $(MPNN_FLAG) $(SEED_FLAG) $(INIT_PRETRAIN_FLAG) \
|
||||
$(N_TRIALS_FLAG) $(EPOCHS_PER_TRIAL_FLAG) $(MIN_STRATUM_FLAG) $(OUTPUT_DIR_FLAG)
|
||||
|
||||
## Final training with Optuna: 3-fold CV tune + full data train
|
||||
## 用于最终模型训练:3-fold 调参后用全量数据训练(无 early-stop)
|
||||
## 默认加载 models/pretrain_delivery.pt 预训练权重,使用 NO_PRETRAIN=1 禁用
|
||||
## 使用示例: make final_optuna DEVICE=cuda N_TRIALS=30 USE_SWA=1
|
||||
.PHONY: final_optuna
|
||||
final_optuna: requirements
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.final_train_optuna_cv \
|
||||
$(DEVICE_FLAG) $(MPNN_FLAG) $(SEED_FLAG) $(INIT_PRETRAIN_FLAG) \
|
||||
$(N_TRIALS_FLAG) $(EPOCHS_PER_TRIAL_FLAG) $(MIN_STRATUM_FLAG) $(OUTPUT_DIR_FLAG) $(USE_SWA_FLAG)
|
||||
|
||||
## Run predictions
|
||||
.PHONY: predict
|
||||
predict: requirements
|
||||
@ -200,6 +238,48 @@ serve:
|
||||
@echo "然后访问: http://localhost:8501"
|
||||
|
||||
|
||||
#################################################################################
|
||||
# DOCKER COMMANDS #
|
||||
#################################################################################
|
||||
|
||||
## Build Docker images
|
||||
.PHONY: docker-build
|
||||
docker-build:
|
||||
docker compose build
|
||||
|
||||
## Start all services with Docker Compose
|
||||
.PHONY: docker-up
|
||||
docker-up:
|
||||
docker compose up -d
|
||||
|
||||
## Stop all Docker services
|
||||
.PHONY: docker-down
|
||||
docker-down:
|
||||
docker compose down
|
||||
|
||||
## View Docker logs
|
||||
.PHONY: docker-logs
|
||||
docker-logs:
|
||||
docker compose logs -f
|
||||
|
||||
## Build and start all services
|
||||
.PHONY: docker-serve
|
||||
docker-serve: docker-build docker-up
|
||||
@echo ""
|
||||
@echo "🚀 服务已启动!"
|
||||
@echo " - API: http://localhost:8000"
|
||||
@echo " - Web 应用: http://localhost:8501"
|
||||
@echo ""
|
||||
@echo "查看日志: make docker-logs"
|
||||
@echo "停止服务: make docker-down"
|
||||
|
||||
## Clean Docker resources (images, volumes, etc.)
|
||||
.PHONY: docker-clean
|
||||
docker-clean:
|
||||
docker compose down -v --rmi local
|
||||
docker system prune -f
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Self Documenting Commands #
|
||||
#################################################################################
|
||||
|
||||
15
app/SCORE.md
Normal file
15
app/SCORE.md
Normal file
@ -0,0 +1,15 @@
|
||||
## regression
|
||||
biodistribution(selected organ only): score = y * weight, where weight=0.3
|
||||
quantified_delivery: score = (y-min)/(max-min)*weight, where weight=0.25, (min=-0.798559291, max=4.497814051056962) when route_of_administration=intravenous, (min=-0.794912427, max=10.220042980012716) when route_of_administration=intramuscular
|
||||
size: score = 0 * weight if y<60, 1 * weight if 60<=y<=150, 0 * weight if y>150, where weight=0.05
|
||||
|
||||
## classification
|
||||
encapsulation_efficiency_0: score = weight, where weight=0
|
||||
encapsulation_efficiency_1: score = weight, where weight=0.02
|
||||
encapsulation_efficiency_2: score = weight, where weight=0.08
|
||||
pdi_0: score = weight, where weight=0.08
|
||||
pdi_1: score = weight, where weight=0.02
|
||||
pdi_2: score = weight, where weight=0
|
||||
pdi_3: score = weight, where weight=0
|
||||
toxicity_0: score=weight, where weight=0.2
|
||||
toxicity_1: score=weight, where weight=0
|
||||
125
app/api.py
125
app/api.py
@ -23,23 +23,87 @@ from app.optimize import (
|
||||
format_results,
|
||||
AVAILABLE_ORGANS,
|
||||
TARGET_BIODIST,
|
||||
CompRanges,
|
||||
ScoringWeights,
|
||||
)
|
||||
|
||||
|
||||
# ============ Pydantic Models ============
|
||||
|
||||
class CompRangesRequest(BaseModel):
|
||||
"""组分范围配置"""
|
||||
weight_ratio_min: float = Field(default=0.05, ge=0.01, le=0.50, description="阳离子脂质/mRNA 重量比最小值")
|
||||
weight_ratio_max: float = Field(default=0.30, ge=0.01, le=0.50, description="阳离子脂质/mRNA 重量比最大值")
|
||||
cationic_mol_min: float = Field(default=0.05, ge=0.00, le=1.00, description="阳离子脂质 mol 比例最小值")
|
||||
cationic_mol_max: float = Field(default=0.80, ge=0.00, le=1.00, description="阳离子脂质 mol 比例最大值")
|
||||
phospholipid_mol_min: float = Field(default=0.00, ge=0.00, le=1.00, description="磷脂 mol 比例最小值")
|
||||
phospholipid_mol_max: float = Field(default=0.80, ge=0.00, le=1.00, description="磷脂 mol 比例最大值")
|
||||
cholesterol_mol_min: float = Field(default=0.00, ge=0.00, le=1.00, description="胆固醇 mol 比例最小值")
|
||||
cholesterol_mol_max: float = Field(default=0.80, ge=0.00, le=1.00, description="胆固醇 mol 比例最大值")
|
||||
peg_mol_min: float = Field(default=0.00, ge=0.00, le=0.20, description="PEG 脂质 mol 比例最小值")
|
||||
peg_mol_max: float = Field(default=0.05, ge=0.00, le=0.20, description="PEG 脂质 mol 比例最大值")
|
||||
|
||||
def to_comp_ranges(self) -> CompRanges:
|
||||
"""转换为 CompRanges 对象"""
|
||||
return CompRanges(
|
||||
weight_ratio_min=self.weight_ratio_min,
|
||||
weight_ratio_max=self.weight_ratio_max,
|
||||
cationic_mol_min=self.cationic_mol_min,
|
||||
cationic_mol_max=self.cationic_mol_max,
|
||||
phospholipid_mol_min=self.phospholipid_mol_min,
|
||||
phospholipid_mol_max=self.phospholipid_mol_max,
|
||||
cholesterol_mol_min=self.cholesterol_mol_min,
|
||||
cholesterol_mol_max=self.cholesterol_mol_max,
|
||||
peg_mol_min=self.peg_mol_min,
|
||||
peg_mol_max=self.peg_mol_max,
|
||||
)
|
||||
|
||||
|
||||
class ScoringWeightsRequest(BaseModel):
|
||||
"""评分权重配置"""
|
||||
biodist_weight: float = Field(default=1.0, ge=0.0, description="目标器官分布权重")
|
||||
delivery_weight: float = Field(default=0.0, ge=0.0, description="量化递送权重")
|
||||
size_weight: float = Field(default=0.0, ge=0.0, description="粒径权重 (80-150nm)")
|
||||
ee_class_weights: List[float] = Field(default=[0.0, 0.0, 0.0], description="EE 分类权重 [class0, class1, class2]")
|
||||
pdi_class_weights: List[float] = Field(default=[0.0, 0.0, 0.0, 0.0], description="PDI 分类权重 [class0, class1, class2, class3]")
|
||||
toxic_class_weights: List[float] = Field(default=[0.0, 0.0], description="毒性分类权重 [无毒, 有毒]")
|
||||
|
||||
def to_scoring_weights(self) -> ScoringWeights:
|
||||
"""转换为 ScoringWeights 对象"""
|
||||
return ScoringWeights(
|
||||
biodist_weight=self.biodist_weight,
|
||||
delivery_weight=self.delivery_weight,
|
||||
size_weight=self.size_weight,
|
||||
ee_class_weights=self.ee_class_weights,
|
||||
pdi_class_weights=self.pdi_class_weights,
|
||||
toxic_class_weights=self.toxic_class_weights,
|
||||
)
|
||||
|
||||
|
||||
class OptimizeRequest(BaseModel):
|
||||
"""优化请求"""
|
||||
smiles: str = Field(..., description="Cationic lipid SMILES string")
|
||||
organ: str = Field(..., description="Target organ for optimization")
|
||||
top_k: int = Field(default=20, ge=1, le=100, description="Number of top formulations")
|
||||
top_k: int = Field(default=20, ge=1, le=100, description="Number of top formulations to return")
|
||||
num_seeds: Optional[int] = Field(default=None, ge=1, le=500, description="Number of seed points from first iteration (default: top_k * 5)")
|
||||
top_per_seed: int = Field(default=1, ge=1, le=10, description="Number of local best to keep per seed in refinement")
|
||||
step_sizes: Optional[List[float]] = Field(default=None, description="Step sizes for each iteration (default: [0.10, 0.02, 0.01])")
|
||||
comp_ranges: Optional[CompRangesRequest] = Field(default=None, description="组分范围配置(默认使用标准范围)")
|
||||
routes: Optional[List[str]] = Field(default=None, description="给药途径列表 (default: ['intravenous', 'intramuscular'])")
|
||||
scoring_weights: Optional[ScoringWeightsRequest] = Field(default=None, description="评分权重配置(默认仅按 biodist 排序)")
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"smiles": "CC(C)NCCNC(C)C",
|
||||
"organ": "liver",
|
||||
"top_k": 20
|
||||
"top_k": 20,
|
||||
"num_seeds": None,
|
||||
"top_per_seed": 1,
|
||||
"step_sizes": None,
|
||||
"comp_ranges": None,
|
||||
"routes": None,
|
||||
"scoring_weights": None
|
||||
}
|
||||
}
|
||||
|
||||
@ -48,6 +112,7 @@ class FormulationResult(BaseModel):
|
||||
"""单个配方结果"""
|
||||
rank: int
|
||||
target_biodist: float
|
||||
composite_score: Optional[float] = None # 综合评分
|
||||
cationic_lipid_to_mrna_ratio: float
|
||||
cationic_lipid_mol_ratio: float
|
||||
phospholipid_mol_ratio: float
|
||||
@ -56,6 +121,12 @@ class FormulationResult(BaseModel):
|
||||
helper_lipid: str
|
||||
route: str
|
||||
all_biodist: Dict[str, float]
|
||||
# 额外预测值
|
||||
quantified_delivery: Optional[float] = None
|
||||
size: Optional[float] = None
|
||||
pdi_class: Optional[int] = None # PDI 分类 (0: <0.2, 1: 0.2-0.3, 2: 0.3-0.4, 3: >0.4)
|
||||
ee_class: Optional[int] = None # EE 分类 (0: <80%, 1: 80-90%, 2: >90%)
|
||||
toxic_class: Optional[int] = None # 毒性分类 (0: 无毒, 1: 有毒)
|
||||
|
||||
|
||||
class OptimizeResponse(BaseModel):
|
||||
@ -187,25 +258,65 @@ async def optimize_formulation(request: OptimizeRequest):
|
||||
if not request.smiles or len(request.smiles.strip()) == 0:
|
||||
raise HTTPException(status_code=400, detail="SMILES string cannot be empty")
|
||||
|
||||
logger.info(f"Optimization request: organ={request.organ}, smiles={request.smiles[:50]}...")
|
||||
# 验证 routes
|
||||
valid_routes = ["intravenous", "intramuscular"]
|
||||
if request.routes is not None:
|
||||
for r in request.routes:
|
||||
if r not in valid_routes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid route: {r}. Available: {valid_routes}"
|
||||
)
|
||||
if len(request.routes) == 0:
|
||||
raise HTTPException(status_code=400, detail="At least one route must be specified")
|
||||
|
||||
logger.info(f"Optimization request: organ={request.organ}, routes={request.routes}, smiles={request.smiles[:50]}...")
|
||||
|
||||
# 构建组分范围配置(在 try 块外验证,确保返回 400 而非 500)
|
||||
comp_ranges = None
|
||||
if request.comp_ranges is not None:
|
||||
comp_ranges = request.comp_ranges.to_comp_ranges()
|
||||
# 验证范围是否合理
|
||||
validation_error = comp_ranges.get_validation_error()
|
||||
if validation_error:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"组分范围配置无效: {validation_error}"
|
||||
)
|
||||
|
||||
# 构建评分权重配置
|
||||
scoring_weights = None
|
||||
if request.scoring_weights is not None:
|
||||
scoring_weights = request.scoring_weights.to_scoring_weights()
|
||||
|
||||
try:
|
||||
# 执行优化
|
||||
# 执行优化(层级搜索策略)
|
||||
results = optimize(
|
||||
smiles=request.smiles,
|
||||
organ=request.organ,
|
||||
model=state.model,
|
||||
device=state.device,
|
||||
top_k=request.top_k,
|
||||
num_seeds=request.num_seeds,
|
||||
top_per_seed=request.top_per_seed,
|
||||
step_sizes=request.step_sizes,
|
||||
comp_ranges=comp_ranges,
|
||||
routes=request.routes,
|
||||
scoring_weights=scoring_weights,
|
||||
batch_size=256,
|
||||
)
|
||||
|
||||
# 用于计算综合评分的权重
|
||||
from app.optimize import compute_formulation_score, DEFAULT_SCORING_WEIGHTS
|
||||
actual_scoring_weights = scoring_weights if scoring_weights is not None else DEFAULT_SCORING_WEIGHTS
|
||||
|
||||
# 转换结果
|
||||
formulations = []
|
||||
for i, f in enumerate(results):
|
||||
formulations.append(FormulationResult(
|
||||
rank=i + 1,
|
||||
target_biodist=f.get_biodist(request.organ),
|
||||
composite_score=compute_formulation_score(f, request.organ, actual_scoring_weights),
|
||||
cationic_lipid_to_mrna_ratio=f.cationic_lipid_to_mrna_ratio,
|
||||
cationic_lipid_mol_ratio=f.cationic_lipid_mol_ratio,
|
||||
phospholipid_mol_ratio=f.phospholipid_mol_ratio,
|
||||
@ -217,6 +328,12 @@ async def optimize_formulation(request: OptimizeRequest):
|
||||
col.replace("Biodistribution_", ""): f.biodist_predictions.get(col, 0.0)
|
||||
for col in TARGET_BIODIST
|
||||
},
|
||||
# 额外预测值
|
||||
quantified_delivery=f.quantified_delivery,
|
||||
size=f.size,
|
||||
pdi_class=f.pdi_class,
|
||||
ee_class=f.ee_class,
|
||||
toxic_class=f.toxic_class,
|
||||
))
|
||||
|
||||
logger.success(f"Optimization completed: {len(formulations)} formulations")
|
||||
|
||||
621
app/app.py
621
app/app.py
@ -3,9 +3,13 @@ Streamlit 配方优化交互界面
|
||||
|
||||
启动应用:
|
||||
streamlit run app/app.py
|
||||
|
||||
Docker 环境变量:
|
||||
API_URL: API 服务地址 (默认: http://localhost:8000)
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import httpx
|
||||
@ -14,7 +18,8 @@ import streamlit as st
|
||||
|
||||
# ============ 配置 ============
|
||||
|
||||
API_URL = "http://localhost:8000"
|
||||
# 从环境变量读取 API 地址,支持 Docker 环境
|
||||
API_URL = os.environ.get("API_URL", "http://localhost:8000")
|
||||
|
||||
AVAILABLE_ORGANS = [
|
||||
"liver",
|
||||
@ -27,13 +32,23 @@ AVAILABLE_ORGANS = [
|
||||
]
|
||||
|
||||
ORGAN_LABELS = {
|
||||
"liver": "🫀 肝脏 (Liver)",
|
||||
"spleen": "🟣 脾脏 (Spleen)",
|
||||
"lung": "🫁 肺 (Lung)",
|
||||
"heart": "❤️ 心脏 (Heart)",
|
||||
"kidney": "🫘 肾脏 (Kidney)",
|
||||
"muscle": "💪 肌肉 (Muscle)",
|
||||
"lymph_nodes": "🔵 淋巴结 (Lymph Nodes)",
|
||||
"liver": "肝脏 (Liver)",
|
||||
"spleen": "脾脏 (Spleen)",
|
||||
"lung": "肺 (Lung)",
|
||||
"heart": "心脏 (Heart)",
|
||||
"kidney": "肾脏 (Kidney)",
|
||||
"muscle": "肌肉 (Muscle)",
|
||||
"lymph_nodes": "淋巴结 (Lymph Nodes)",
|
||||
}
|
||||
|
||||
AVAILABLE_ROUTES = [
|
||||
"intravenous",
|
||||
"intramuscular",
|
||||
]
|
||||
|
||||
ROUTE_LABELS = {
|
||||
"intravenous": "静脉注射 (Intravenous)",
|
||||
"intramuscular": "肌肉注射 (Intramuscular)",
|
||||
}
|
||||
|
||||
# ============ 页面配置 ============
|
||||
@ -122,43 +137,107 @@ def check_api_status() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def call_optimize_api(smiles: str, organ: str, top_k: int = 20) -> dict:
|
||||
def call_optimize_api(
|
||||
smiles: str,
|
||||
organ: str,
|
||||
top_k: int = 20,
|
||||
num_seeds: int = None,
|
||||
top_per_seed: int = 1,
|
||||
step_sizes: list = None,
|
||||
comp_ranges: dict = None,
|
||||
routes: list = None,
|
||||
scoring_weights: dict = None,
|
||||
) -> dict:
|
||||
"""调用优化 API"""
|
||||
with httpx.Client(timeout=300) as client: # 5 分钟超时
|
||||
payload = {
|
||||
"smiles": smiles,
|
||||
"organ": organ,
|
||||
"top_k": top_k,
|
||||
"num_seeds": num_seeds,
|
||||
"top_per_seed": top_per_seed,
|
||||
"step_sizes": step_sizes,
|
||||
"comp_ranges": comp_ranges,
|
||||
"routes": routes,
|
||||
"scoring_weights": scoring_weights,
|
||||
}
|
||||
|
||||
with httpx.Client(timeout=600) as client: # 10 分钟超时(自定义参数可能需要更长时间)
|
||||
response = client.post(
|
||||
f"{API_URL}/optimize",
|
||||
json={
|
||||
"smiles": smiles,
|
||||
"organ": organ,
|
||||
"top_k": top_k,
|
||||
},
|
||||
json=payload,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def format_results_dataframe(results: dict) -> pd.DataFrame:
|
||||
# PDI 分类标签
|
||||
PDI_CLASS_LABELS = {
|
||||
0: "<0.2 (优)",
|
||||
1: "0.2-0.3 (良)",
|
||||
2: "0.3-0.4 (中)",
|
||||
3: ">0.4 (差)",
|
||||
}
|
||||
|
||||
# EE 分类标签
|
||||
EE_CLASS_LABELS = {
|
||||
0: "<50% (低)",
|
||||
1: "50-80% (中)",
|
||||
2: ">80% (高)",
|
||||
}
|
||||
|
||||
# 毒性分类标签
|
||||
TOXIC_CLASS_LABELS = {
|
||||
0: "无毒 ✓",
|
||||
1: "有毒 ⚠",
|
||||
}
|
||||
|
||||
|
||||
def format_results_dataframe(results: dict, smiles_label: str = None) -> pd.DataFrame:
|
||||
"""将 API 结果转换为 DataFrame"""
|
||||
formulations = results["formulations"]
|
||||
target_organ = results["target_organ"]
|
||||
|
||||
rows = []
|
||||
for f in formulations:
|
||||
row = {
|
||||
row = {}
|
||||
|
||||
# 如果有 SMILES 标签,添加到首列
|
||||
if smiles_label:
|
||||
row["SMILES"] = smiles_label
|
||||
|
||||
row.update({
|
||||
"排名": f["rank"],
|
||||
f"Biodist_{target_organ}": f"{f['target_biodist']:.4f}",
|
||||
"阳离子脂质/mRNA": f["cationic_lipid_to_mrna_ratio"],
|
||||
"阳离子脂质(mol)": f["cationic_lipid_mol_ratio"],
|
||||
"磷脂(mol)": f["phospholipid_mol_ratio"],
|
||||
"胆固醇(mol)": f["cholesterol_mol_ratio"],
|
||||
"PEG脂质(mol)": f["peg_lipid_mol_ratio"],
|
||||
})
|
||||
# 如果有综合评分,显示在排名后面
|
||||
if f.get("composite_score") is not None:
|
||||
row["综合评分"] = f"{f['composite_score']:.4f}"
|
||||
row.update({
|
||||
f"{target_organ}分布": f"{f['target_biodist']*100:.8f}%",
|
||||
"阳离子脂质/mRNA比例": f["cationic_lipid_to_mrna_ratio"],
|
||||
"阳离子脂质(mol)比例": f["cationic_lipid_mol_ratio"],
|
||||
"磷脂(mol)比例": f["phospholipid_mol_ratio"],
|
||||
"胆固醇(mol)比例": f["cholesterol_mol_ratio"],
|
||||
"PEG脂质(mol)比例": f["peg_lipid_mol_ratio"],
|
||||
"辅助脂质": f["helper_lipid"],
|
||||
"给药途径": f["route"],
|
||||
}
|
||||
})
|
||||
|
||||
# 添加额外预测值
|
||||
if f.get("quantified_delivery") is not None:
|
||||
row["量化递送"] = f"{f['quantified_delivery']:.4f}"
|
||||
if f.get("size") is not None:
|
||||
row["粒径(nm)"] = f"{f['size']:.1f}"
|
||||
if f.get("pdi_class") is not None:
|
||||
row["PDI"] = PDI_CLASS_LABELS.get(f["pdi_class"], str(f["pdi_class"]))
|
||||
if f.get("ee_class") is not None:
|
||||
row["包封率"] = EE_CLASS_LABELS.get(f["ee_class"], str(f["ee_class"]))
|
||||
if f.get("toxic_class") is not None:
|
||||
row["毒性"] = TOXIC_CLASS_LABELS.get(f["toxic_class"], str(f["toxic_class"]))
|
||||
|
||||
# 添加其他器官的 biodist
|
||||
for organ, value in f["all_biodist"].items():
|
||||
if organ != target_organ:
|
||||
row[f"Biodist_{organ}"] = f"{value:.4f}"
|
||||
row[f"{organ}分布"] = f"{value*100:.2f}%"
|
||||
rows.append(row)
|
||||
|
||||
return pd.DataFrame(rows)
|
||||
@ -184,7 +263,7 @@ def main():
|
||||
|
||||
# ========== 侧边栏 ==========
|
||||
with st.sidebar:
|
||||
st.header("⚙️ 参数设置")
|
||||
# st.header("⚙️ 参数设置")
|
||||
|
||||
# API 状态
|
||||
if api_online:
|
||||
@ -193,7 +272,7 @@ def main():
|
||||
st.error("🔴 API 服务离线")
|
||||
st.info("请先启动 API 服务:\n```\nuvicorn app.api:app --port 8000\n```")
|
||||
|
||||
st.divider()
|
||||
# st.divider()
|
||||
|
||||
# SMILES 输入
|
||||
st.subheader("🔬 分子结构")
|
||||
@ -201,23 +280,23 @@ def main():
|
||||
"输入阳离子脂质 SMILES",
|
||||
value="",
|
||||
height=100,
|
||||
placeholder="例如: CC(C)NCCNC(C)C",
|
||||
help="输入阳离子脂质的 SMILES 字符串",
|
||||
placeholder="例如: CC(C)NCCNC(C)C\n多条SMILES用英文逗号分隔: SMI1,SMI2,SMI3",
|
||||
help="输入阳离子脂质的 SMILES 字符串。支持多条 SMILES,用英文逗号 (,) 分隔",
|
||||
)
|
||||
|
||||
# 示例 SMILES
|
||||
with st.expander("📋 示例 SMILES"):
|
||||
example_smiles = {
|
||||
"DLin-MC3-DMA": "CC(C)=CCCC(C)=CCCC(C)=CCN(C)CCCCCCCCOC(=O)CCCCCCC/C=C\\CCCCCCCC",
|
||||
"简单胺": "CC(C)NCCNC(C)C",
|
||||
"长链胺": "CCCCCCCCCCCCNCCNCCCCCCCCCCCC",
|
||||
}
|
||||
for name, smi in example_smiles.items():
|
||||
if st.button(f"使用 {name}", key=f"example_{name}"):
|
||||
st.session_state["smiles_input"] = smi
|
||||
st.rerun()
|
||||
# with st.expander("📋 示例 SMILES"):
|
||||
# example_smiles = {
|
||||
# "DLin-MC3-DMA": "CC(C)=CCCC(C)=CCCC(C)=CCN(C)CCCCCCCCOC(=O)CCCCCCC/C=C\\CCCCCCCC",
|
||||
# "简单胺": "CC(C)NCCNC(C)C",
|
||||
# "长链胺": "CCCCCCCCCCCCNCCNCCCCCCCCCCCC",
|
||||
# }
|
||||
# for name, smi in example_smiles.items():
|
||||
# if st.button(f"使用 {name}", key=f"example_{name}"):
|
||||
# st.session_state["smiles_input"] = smi
|
||||
# st.rerun()
|
||||
|
||||
st.divider()
|
||||
# st.divider()
|
||||
|
||||
# 目标器官选择
|
||||
st.subheader("🎯 目标器官")
|
||||
@ -228,17 +307,226 @@ def main():
|
||||
index=0,
|
||||
)
|
||||
|
||||
st.divider()
|
||||
# 给药途径选择
|
||||
st.subheader("💉 给药途径")
|
||||
selected_routes = st.multiselect(
|
||||
"选择给药途径",
|
||||
options=AVAILABLE_ROUTES,
|
||||
default=AVAILABLE_ROUTES,
|
||||
format_func=lambda x: ROUTE_LABELS.get(x, x),
|
||||
help="选择要搜索的给药途径,可多选。至少选择一种。",
|
||||
)
|
||||
if not selected_routes:
|
||||
st.warning("⚠️ 请至少选择一种给药途径")
|
||||
|
||||
# 高级选项
|
||||
with st.expander("🔧 高级选项"):
|
||||
st.markdown("**输出设置**")
|
||||
top_k = st.slider(
|
||||
"返回配方数量",
|
||||
"返回配方数量 (top_k)",
|
||||
min_value=5,
|
||||
max_value=50,
|
||||
max_value=100,
|
||||
value=20,
|
||||
step=5,
|
||||
help="最终返回的最优配方数量",
|
||||
)
|
||||
|
||||
st.markdown("**搜索策略**")
|
||||
num_seeds = st.slider(
|
||||
"种子点数量 (num_seeds)",
|
||||
min_value=10,
|
||||
max_value=200,
|
||||
value=top_k * 5,
|
||||
step=10,
|
||||
help="第一轮迭代后保留的种子点数量,更多种子点意味着更广泛的搜索",
|
||||
)
|
||||
|
||||
top_per_seed = st.slider(
|
||||
"每个种子的局部最优数 (top_per_seed)",
|
||||
min_value=1,
|
||||
max_value=5,
|
||||
value=1,
|
||||
step=1,
|
||||
help="后续迭代中,每个种子点邻域保留的局部最优数量",
|
||||
)
|
||||
|
||||
st.markdown("**迭代步长与轮数**")
|
||||
use_custom_steps = st.checkbox(
|
||||
"自定义迭代步长",
|
||||
value=False,
|
||||
help="默认步长为 [0.10, 0.02, 0.01],共3轮逐步精细化搜索。将某轮步长设为0可减少迭代轮数。",
|
||||
)
|
||||
|
||||
if use_custom_steps:
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
step1 = st.number_input(
|
||||
"第1轮步长",
|
||||
min_value=0.01, max_value=0.20, value=0.10,
|
||||
step=0.01, format="%.2f",
|
||||
help="第1轮为全局粗搜索,步长必须大于0",
|
||||
)
|
||||
with col2:
|
||||
step2 = st.number_input(
|
||||
"第2轮步长",
|
||||
min_value=0.00, max_value=0.10, value=0.02,
|
||||
step=0.01, format="%.2f",
|
||||
help="设为0则只进行1轮搜索",
|
||||
)
|
||||
with col3:
|
||||
step3 = st.number_input(
|
||||
"第3轮步长",
|
||||
min_value=0.00, max_value=0.05, value=0.01,
|
||||
step=0.01, format="%.2f",
|
||||
help="设为0则只进行2轮搜索",
|
||||
)
|
||||
|
||||
# 根据步长值构建实际的 step_sizes 列表
|
||||
# step2 为 0 → 只保留 [step1](1轮)
|
||||
# step3 为 0 → 只保留 [step1, step2](2轮)
|
||||
# 都不为 0 → [step1, step2, step3](3轮)
|
||||
if step2 == 0.0:
|
||||
step_sizes = [step1]
|
||||
elif step3 == 0.0:
|
||||
step_sizes = [step1, step2]
|
||||
else:
|
||||
step_sizes = [step1, step2, step3]
|
||||
|
||||
# 显示实际迭代轮数提示
|
||||
st.caption(f"📌 实际迭代轮数: {len(step_sizes)} 轮,步长: {step_sizes}")
|
||||
else:
|
||||
step_sizes = None # 使用默认值
|
||||
|
||||
st.markdown("**组分范围限制**")
|
||||
use_custom_ranges = st.checkbox(
|
||||
"自定义组分取值范围",
|
||||
value=False,
|
||||
help="限制各组分的取值范围(mol 比例加起来仍为 100%)",
|
||||
)
|
||||
|
||||
if use_custom_ranges:
|
||||
st.caption("阳离子脂质/mRNA 重量比")
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
weight_ratio_min = st.number_input("最小", min_value=0.01, max_value=0.50, value=0.05, step=0.01, format="%.2f", key="wr_min")
|
||||
with col2:
|
||||
weight_ratio_max = st.number_input("最大", min_value=0.01, max_value=0.50, value=0.30, step=0.01, format="%.2f", key="wr_max")
|
||||
|
||||
st.caption("阳离子脂质 mol 比例")
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
cationic_mol_min = st.number_input("最小", min_value=0.00, max_value=1.00, value=0.05, step=0.05, format="%.2f", key="cat_min")
|
||||
with col2:
|
||||
cationic_mol_max = st.number_input("最大", min_value=0.00, max_value=1.00, value=0.80, step=0.05, format="%.2f", key="cat_max")
|
||||
|
||||
st.caption("磷脂 mol 比例")
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
phospholipid_mol_min = st.number_input("最小", min_value=0.00, max_value=1.00, value=0.00, step=0.05, format="%.2f", key="phos_min")
|
||||
with col2:
|
||||
phospholipid_mol_max = st.number_input("最大", min_value=0.00, max_value=1.00, value=0.80, step=0.05, format="%.2f", key="phos_max")
|
||||
|
||||
st.caption("胆固醇 mol 比例")
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
cholesterol_mol_min = st.number_input("最小", min_value=0.00, max_value=1.00, value=0.00, step=0.05, format="%.2f", key="chol_min")
|
||||
with col2:
|
||||
cholesterol_mol_max = st.number_input("最大", min_value=0.00, max_value=1.00, value=0.80, step=0.05, format="%.2f", key="chol_max")
|
||||
|
||||
st.caption("PEG 脂质 mol 比例")
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
peg_mol_min = st.number_input("最小", min_value=0.00, max_value=0.20, value=0.00, step=0.01, format="%.2f", key="peg_min")
|
||||
with col2:
|
||||
peg_mol_max = st.number_input("最大", min_value=0.00, max_value=0.20, value=0.05, step=0.01, format="%.2f", key="peg_max")
|
||||
|
||||
comp_ranges = {
|
||||
"weight_ratio_min": weight_ratio_min,
|
||||
"weight_ratio_max": weight_ratio_max,
|
||||
"cationic_mol_min": cationic_mol_min,
|
||||
"cationic_mol_max": cationic_mol_max,
|
||||
"phospholipid_mol_min": phospholipid_mol_min,
|
||||
"phospholipid_mol_max": phospholipid_mol_max,
|
||||
"cholesterol_mol_min": cholesterol_mol_min,
|
||||
"cholesterol_mol_max": cholesterol_mol_max,
|
||||
"peg_mol_min": peg_mol_min,
|
||||
"peg_mol_max": peg_mol_max,
|
||||
}
|
||||
|
||||
# 简单验证
|
||||
min_sum = cationic_mol_min + phospholipid_mol_min + cholesterol_mol_min + peg_mol_min
|
||||
max_sum = cationic_mol_max + phospholipid_mol_max + cholesterol_mol_max + peg_mol_max
|
||||
if min_sum > 1.0 or max_sum < 1.0:
|
||||
st.warning("⚠️ 当前范围设置可能无法生成有效配方(mol 比例需加起来为 100%)")
|
||||
else:
|
||||
comp_ranges = None # 使用默认值
|
||||
|
||||
st.markdown("**评分/排序权重**")
|
||||
use_custom_scoring = st.checkbox(
|
||||
"自定义评分权重",
|
||||
value=False,
|
||||
help="默认仅按目标器官分布排序。开启后可自定义多目标加权评分,总分 = 各项score之和。",
|
||||
)
|
||||
|
||||
if use_custom_scoring:
|
||||
st.caption("**回归任务权重**")
|
||||
|
||||
sw_biodist = st.number_input(
|
||||
"器官分布 (Biodistribution)",
|
||||
min_value=0.00, max_value=10.00, value=0.30,
|
||||
step=0.05, format="%.2f", key="sw_biodist",
|
||||
help="score = biodist_value × weight",
|
||||
)
|
||||
sw_delivery = st.number_input(
|
||||
"量化递送 (Quantified Delivery)",
|
||||
min_value=0.00, max_value=10.00, value=0.25,
|
||||
step=0.05, format="%.2f", key="sw_delivery",
|
||||
help="score = normalize(delivery, route) × weight",
|
||||
)
|
||||
sw_size = st.number_input(
|
||||
"粒径 (Size, 80-150nm)",
|
||||
min_value=0.00, max_value=10.00, value=0.05,
|
||||
step=0.05, format="%.2f", key="sw_size",
|
||||
help="score = (1 if 60≤size≤150 else 0) × weight",
|
||||
)
|
||||
|
||||
st.caption("**包封率 (EE) 分类权重**")
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
sw_ee0 = st.number_input("<50% (低)", min_value=0.00, max_value=1.00, value=0.00, step=0.01, format="%.2f", key="sw_ee0")
|
||||
with col2:
|
||||
sw_ee1 = st.number_input("50-80% (中)", min_value=0.00, max_value=1.00, value=0.02, step=0.01, format="%.2f", key="sw_ee1")
|
||||
with col3:
|
||||
sw_ee2 = st.number_input(">80% (高)", min_value=0.00, max_value=1.00, value=0.08, step=0.01, format="%.2f", key="sw_ee2")
|
||||
|
||||
st.caption("**PDI 分类权重**")
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
with col1:
|
||||
sw_pdi0 = st.number_input("<0.2 (优)", min_value=0.00, max_value=1.00, value=0.08, step=0.01, format="%.2f", key="sw_pdi0")
|
||||
with col2:
|
||||
sw_pdi1 = st.number_input("0.2-0.3 (良)", min_value=0.00, max_value=1.00, value=0.02, step=0.01, format="%.2f", key="sw_pdi1")
|
||||
with col3:
|
||||
sw_pdi2 = st.number_input("0.3-0.4 (中)", min_value=0.00, max_value=1.00, value=0.00, step=0.01, format="%.2f", key="sw_pdi2")
|
||||
with col4:
|
||||
sw_pdi3 = st.number_input(">0.4 (差)", min_value=0.00, max_value=1.00, value=0.00, step=0.01, format="%.2f", key="sw_pdi3")
|
||||
|
||||
st.caption("**毒性分类权重**")
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
sw_toxic0 = st.number_input("无毒", min_value=0.00, max_value=1.00, value=0.20, step=0.05, format="%.2f", key="sw_toxic0")
|
||||
with col2:
|
||||
sw_toxic1 = st.number_input("有毒", min_value=0.00, max_value=1.00, value=0.00, step=0.05, format="%.2f", key="sw_toxic1")
|
||||
|
||||
scoring_weights = {
|
||||
"biodist_weight": sw_biodist,
|
||||
"delivery_weight": sw_delivery,
|
||||
"size_weight": sw_size,
|
||||
"ee_class_weights": [sw_ee0, sw_ee1, sw_ee2],
|
||||
"pdi_class_weights": [sw_pdi0, sw_pdi1, sw_pdi2, sw_pdi3],
|
||||
"toxic_class_weights": [sw_toxic0, sw_toxic1],
|
||||
}
|
||||
else:
|
||||
scoring_weights = None # 使用默认值(仅按 biodist 排序)
|
||||
|
||||
st.divider()
|
||||
|
||||
@ -247,7 +535,7 @@ def main():
|
||||
"🚀 开始配方优选",
|
||||
type="primary",
|
||||
use_container_width=True,
|
||||
disabled=not api_online or not smiles_input.strip(),
|
||||
disabled=not api_online or not smiles_input.strip() or not selected_routes,
|
||||
)
|
||||
|
||||
# ========== 主内容区 ==========
|
||||
@ -260,49 +548,125 @@ def main():
|
||||
|
||||
# 执行优化
|
||||
if optimize_button and smiles_input.strip():
|
||||
with st.spinner("🔄 正在优化配方,请稍候..."):
|
||||
try:
|
||||
results = call_optimize_api(
|
||||
smiles=smiles_input.strip(),
|
||||
organ=selected_organ,
|
||||
top_k=top_k,
|
||||
)
|
||||
st.session_state["results"] = results
|
||||
st.session_state["results_df"] = format_results_dataframe(results)
|
||||
st.session_state["smiles_used"] = smiles_input.strip()
|
||||
# 解析多条 SMILES(用逗号分隔)
|
||||
smiles_list = [s.strip() for s in smiles_input.split(",") if s.strip()]
|
||||
|
||||
if not smiles_list:
|
||||
st.error("❌ 请输入有效的 SMILES 字符串")
|
||||
else:
|
||||
is_multi_smiles = len(smiles_list) > 1
|
||||
all_results = []
|
||||
all_dfs = []
|
||||
errors = []
|
||||
|
||||
# 进度条
|
||||
progress_bar = st.progress(0)
|
||||
status_text = st.empty()
|
||||
|
||||
for idx, smiles in enumerate(smiles_list):
|
||||
status_text.text(f"🔄 正在优化 SMILES {idx + 1}/{len(smiles_list)}...")
|
||||
progress_bar.progress((idx) / len(smiles_list))
|
||||
|
||||
try:
|
||||
results = call_optimize_api(
|
||||
smiles=smiles,
|
||||
organ=selected_organ,
|
||||
top_k=top_k,
|
||||
num_seeds=num_seeds,
|
||||
top_per_seed=top_per_seed,
|
||||
step_sizes=step_sizes,
|
||||
comp_ranges=comp_ranges,
|
||||
routes=selected_routes,
|
||||
scoring_weights=scoring_weights,
|
||||
)
|
||||
all_results.append({"smiles": smiles, "results": results})
|
||||
|
||||
# 为多 SMILES 模式添加 SMILES 标签
|
||||
smiles_label = smiles[:30] + "..." if len(smiles) > 30 else smiles
|
||||
df = format_results_dataframe(results, smiles_label if is_multi_smiles else None)
|
||||
all_dfs.append(df)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
try:
|
||||
error_detail = e.response.json().get("detail", str(e))
|
||||
except:
|
||||
error_detail = str(e)
|
||||
errors.append(f"SMILES {idx + 1}: {error_detail}")
|
||||
except httpx.RequestError as e:
|
||||
errors.append(f"SMILES {idx + 1}: API 连接失败 - {e}")
|
||||
except Exception as e:
|
||||
errors.append(f"SMILES {idx + 1}: {e}")
|
||||
|
||||
progress_bar.progress(1.0)
|
||||
status_text.empty()
|
||||
progress_bar.empty()
|
||||
|
||||
# 显示错误
|
||||
for err in errors:
|
||||
st.error(f"❌ {err}")
|
||||
|
||||
# 保存结果
|
||||
if all_results:
|
||||
st.session_state["results"] = all_results[0]["results"] if len(all_results) == 1 else all_results
|
||||
st.session_state["results_df"] = pd.concat(all_dfs, ignore_index=True) if all_dfs else None
|
||||
st.session_state["smiles_used"] = smiles_list
|
||||
st.session_state["organ_used"] = selected_organ
|
||||
st.success("✅ 优化完成!")
|
||||
except httpx.RequestError as e:
|
||||
st.error(f"❌ API 请求失败: {e}")
|
||||
except Exception as e:
|
||||
st.error(f"❌ 发生错误: {e}")
|
||||
st.session_state["is_multi_smiles"] = is_multi_smiles
|
||||
st.success(f"✅ 优化完成!成功处理 {len(all_results)}/{len(smiles_list)} 条 SMILES")
|
||||
|
||||
# 显示结果
|
||||
if st.session_state["results"] is not None:
|
||||
if st.session_state["results"] is not None and st.session_state["results_df"] is not None:
|
||||
results = st.session_state["results"]
|
||||
df = st.session_state["results_df"]
|
||||
is_multi_smiles = st.session_state.get("is_multi_smiles", False)
|
||||
|
||||
# 结果概览
|
||||
col1, col2, col3 = st.columns(3)
|
||||
|
||||
with col1:
|
||||
st.metric(
|
||||
"目标器官",
|
||||
ORGAN_LABELS.get(results["target_organ"], results["target_organ"]).split(" ")[0],
|
||||
)
|
||||
|
||||
with col2:
|
||||
best_score = results["formulations"][0]["target_biodist"]
|
||||
st.metric(
|
||||
"最优 Biodistribution",
|
||||
f"{best_score:.4f}",
|
||||
)
|
||||
|
||||
with col3:
|
||||
st.metric(
|
||||
"优选配方数",
|
||||
len(results["formulations"]),
|
||||
)
|
||||
if is_multi_smiles:
|
||||
# 多 SMILES 模式
|
||||
col1, col2, col3 = st.columns(3)
|
||||
|
||||
with col1:
|
||||
# 获取 target_organ(从第一个结果)
|
||||
first_result = results[0]["results"] if isinstance(results, list) else results
|
||||
target_organ = first_result["target_organ"]
|
||||
st.metric(
|
||||
"目标器官",
|
||||
ORGAN_LABELS.get(target_organ, target_organ).split(" ")[0],
|
||||
)
|
||||
|
||||
with col2:
|
||||
st.metric(
|
||||
"SMILES 数量",
|
||||
len(results) if isinstance(results, list) else 1,
|
||||
)
|
||||
|
||||
with col3:
|
||||
st.metric(
|
||||
"总配方数",
|
||||
len(df),
|
||||
)
|
||||
else:
|
||||
# 单 SMILES 模式
|
||||
col1, col2, col3 = st.columns(3)
|
||||
|
||||
with col1:
|
||||
st.metric(
|
||||
"目标器官",
|
||||
ORGAN_LABELS.get(results["target_organ"], results["target_organ"]).split(" ")[0],
|
||||
)
|
||||
|
||||
with col2:
|
||||
best_score = results["formulations"][0]["target_biodist"]
|
||||
st.metric(
|
||||
"最优分布",
|
||||
f"{best_score*100:.2f}%",
|
||||
)
|
||||
|
||||
with col3:
|
||||
st.metric(
|
||||
"优选配方数",
|
||||
len(results["formulations"]),
|
||||
)
|
||||
|
||||
st.divider()
|
||||
|
||||
@ -312,15 +676,26 @@ def main():
|
||||
# 导出按钮行
|
||||
col_export, col_spacer = st.columns([1, 4])
|
||||
with col_export:
|
||||
smiles_used = st.session_state.get("smiles_used", "")
|
||||
if isinstance(smiles_used, list):
|
||||
smiles_used = ",".join(smiles_used)
|
||||
|
||||
csv_content = create_export_csv(
|
||||
df,
|
||||
st.session_state.get("smiles_used", ""),
|
||||
smiles_used,
|
||||
st.session_state.get("organ_used", ""),
|
||||
)
|
||||
|
||||
# 获取 target_organ
|
||||
if is_multi_smiles:
|
||||
target_organ = results[0]["results"]["target_organ"] if isinstance(results, list) else results["target_organ"]
|
||||
else:
|
||||
target_organ = results["target_organ"]
|
||||
|
||||
st.download_button(
|
||||
label="📥 导出 CSV",
|
||||
data=csv_content,
|
||||
file_name=f"lnp_optimization_{results['target_organ']}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
|
||||
file_name=f"lnp_optimization_{target_organ}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
|
||||
mime="text/csv",
|
||||
)
|
||||
|
||||
@ -333,61 +708,61 @@ def main():
|
||||
)
|
||||
|
||||
# 详细信息
|
||||
with st.expander("🔍 查看最优配方详情"):
|
||||
best = results["formulations"][0]
|
||||
# with st.expander("🔍 查看最优配方详情"):
|
||||
# best = results["formulations"][0]
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
# col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
st.markdown("**配方参数**")
|
||||
st.json({
|
||||
"阳离子脂质/mRNA 比例": best["cationic_lipid_to_mrna_ratio"],
|
||||
"阳离子脂质 (mol%)": best["cationic_lipid_mol_ratio"],
|
||||
"磷脂 (mol%)": best["phospholipid_mol_ratio"],
|
||||
"胆固醇 (mol%)": best["cholesterol_mol_ratio"],
|
||||
"PEG 脂质 (mol%)": best["peg_lipid_mol_ratio"],
|
||||
"辅助脂质": best["helper_lipid"],
|
||||
"给药途径": best["route"],
|
||||
})
|
||||
# with col1:
|
||||
# st.markdown("**配方参数**")
|
||||
# st.json({
|
||||
# "阳离子脂质/mRNA 比例": best["cationic_lipid_to_mrna_ratio"],
|
||||
# "阳离子脂质 (mol%)": best["cationic_lipid_mol_ratio"],
|
||||
# "磷脂 (mol%)": best["phospholipid_mol_ratio"],
|
||||
# "胆固醇 (mol%)": best["cholesterol_mol_ratio"],
|
||||
# "PEG 脂质 (mol%)": best["peg_lipid_mol_ratio"],
|
||||
# "辅助脂质": best["helper_lipid"],
|
||||
# "给药途径": best["route"],
|
||||
# })
|
||||
|
||||
with col2:
|
||||
st.markdown("**各器官 Biodistribution 预测**")
|
||||
biodist_df = pd.DataFrame([
|
||||
{"器官": ORGAN_LABELS.get(k, k), "Biodistribution": f"{v:.4f}"}
|
||||
for k, v in best["all_biodist"].items()
|
||||
])
|
||||
st.dataframe(biodist_df, hide_index=True, use_container_width=True)
|
||||
# with col2:
|
||||
# st.markdown("**各器官 Biodistribution 预测**")
|
||||
# biodist_df = pd.DataFrame([
|
||||
# {"器官": ORGAN_LABELS.get(k, k), "Biodistribution": f"{v:.4f}"}
|
||||
# for k, v in best["all_biodist"].items()
|
||||
# ])
|
||||
# st.dataframe(biodist_df, hide_index=True, use_container_width=True)
|
||||
|
||||
else:
|
||||
# 欢迎信息
|
||||
st.info("👈 请在左侧输入 SMILES 并选择目标器官,然后点击「开始配方优选」")
|
||||
|
||||
# 使用说明
|
||||
with st.expander("📖 使用说明"):
|
||||
st.markdown("""
|
||||
### 如何使用
|
||||
# with st.expander("📖 使用说明"):
|
||||
# st.markdown("""
|
||||
# ### 如何使用
|
||||
|
||||
1. **输入 SMILES**: 在左侧输入框中输入阳离子脂质的 SMILES 字符串
|
||||
2. **选择目标器官**: 选择您希望优化的器官靶向
|
||||
3. **点击优选**: 系统将自动搜索最优配方组合
|
||||
4. **查看结果**: 右侧将显示 Top-20 优选配方
|
||||
5. **导出数据**: 点击导出按钮将结果保存为 CSV 文件
|
||||
# 1. **输入 SMILES**: 在左侧输入框中输入阳离子脂质的 SMILES 字符串
|
||||
# 2. **选择目标器官**: 选择您希望优化的器官靶向
|
||||
# 3. **点击优选**: 系统将自动搜索最优配方组合
|
||||
# 4. **查看结果**: 右侧将显示 Top-20 优选配方
|
||||
# 5. **导出数据**: 点击导出按钮将结果保存为 CSV 文件
|
||||
|
||||
### 优化参数
|
||||
# ### 优化参数
|
||||
|
||||
系统会优化以下配方参数:
|
||||
- **阳离子脂质/mRNA 比例**: 0.05 - 0.30
|
||||
- **阳离子脂质 mol 比例**: 0.05 - 0.80
|
||||
- **磷脂 mol 比例**: 0.00 - 0.80
|
||||
- **胆固醇 mol 比例**: 0.00 - 0.80
|
||||
- **PEG 脂质 mol 比例**: 0.00 - 0.05
|
||||
- **辅助脂质**: DOPE / DSPC / DOTAP
|
||||
- **给药途径**: 静脉注射 / 肌肉注射
|
||||
# 系统会优化以下配方参数:
|
||||
# - **阳离子脂质/mRNA 比例**: 0.05 - 0.30
|
||||
# - **阳离子脂质 mol 比例**: 0.05 - 0.80
|
||||
# - **磷脂 mol 比例**: 0.00 - 0.80
|
||||
# - **胆固醇 mol 比例**: 0.00 - 0.80
|
||||
# - **PEG 脂质 mol 比例**: 0.00 - 0.05
|
||||
# - **辅助脂质**: DOPE / DSPC / DOTAP
|
||||
# - **给药途径**: 静脉注射 / 肌肉注射
|
||||
|
||||
### 约束条件
|
||||
# ### 约束条件
|
||||
|
||||
mol 比例之和 = 1 (阳离子脂质 + 磷脂 + 胆固醇 + PEG 脂质)
|
||||
""")
|
||||
# mol 比例之和 = 1 (阳离子脂质 + 磷脂 + 胆固醇 + PEG 脂质)
|
||||
# """)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
616
app/optimize.py
616
app/optimize.py
@ -41,14 +41,73 @@ app = typer.Typer()
|
||||
# 可用的目标器官
|
||||
AVAILABLE_ORGANS = ["lymph_nodes", "heart", "liver", "spleen", "lung", "kidney", "muscle"]
|
||||
|
||||
# comp token 参数范围
|
||||
COMP_PARAM_RANGES = {
|
||||
"Cationic_Lipid_to_mRNA_weight_ratio": (0.05, 0.30),
|
||||
"Cationic_Lipid_Mol_Ratio": (0.05, 0.80),
|
||||
"Phospholipid_Mol_Ratio": (0.00, 0.80),
|
||||
"Cholesterol_Mol_Ratio": (0.00, 0.80),
|
||||
"PEG_Lipid_Mol_Ratio": (0.00, 0.05),
|
||||
}
|
||||
|
||||
@dataclass
|
||||
class CompRanges:
|
||||
"""组分参数范围配置"""
|
||||
# 阳离子脂质/mRNA 重量比
|
||||
weight_ratio_min: float = 0.05
|
||||
weight_ratio_max: float = 0.30
|
||||
# 阳离子脂质 mol 比例
|
||||
cationic_mol_min: float = 0.05
|
||||
cationic_mol_max: float = 0.80
|
||||
# 磷脂 mol 比例
|
||||
phospholipid_mol_min: float = 0.00
|
||||
phospholipid_mol_max: float = 0.80
|
||||
# 胆固醇 mol 比例
|
||||
cholesterol_mol_min: float = 0.00
|
||||
cholesterol_mol_max: float = 0.80
|
||||
# PEG 脂质 mol 比例
|
||||
peg_mol_min: float = 0.00
|
||||
peg_mol_max: float = 0.05
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"weight_ratio": (self.weight_ratio_min, self.weight_ratio_max),
|
||||
"cationic_mol": (self.cationic_mol_min, self.cationic_mol_max),
|
||||
"phospholipid_mol": (self.phospholipid_mol_min, self.phospholipid_mol_max),
|
||||
"cholesterol_mol": (self.cholesterol_mol_min, self.cholesterol_mol_max),
|
||||
"peg_mol": (self.peg_mol_min, self.peg_mol_max),
|
||||
}
|
||||
|
||||
def get_validation_error(self) -> Optional[str]:
|
||||
"""
|
||||
验证范围是否合理,返回错误信息(如果有)。
|
||||
|
||||
Returns:
|
||||
错误信息字符串,如果验证通过则返回 None
|
||||
"""
|
||||
# 检查各范围是否有效(最小值不能大于最大值)
|
||||
if self.weight_ratio_min > self.weight_ratio_max:
|
||||
return f"阳离子脂质/mRNA重量比:最小值({self.weight_ratio_min})不能大于最大值({self.weight_ratio_max})"
|
||||
if self.cationic_mol_min > self.cationic_mol_max:
|
||||
return f"阳离子脂质mol比例:最小值({self.cationic_mol_min})不能大于最大值({self.cationic_mol_max})"
|
||||
if self.phospholipid_mol_min > self.phospholipid_mol_max:
|
||||
return f"磷脂mol比例:最小值({self.phospholipid_mol_min})不能大于最大值({self.phospholipid_mol_max})"
|
||||
if self.cholesterol_mol_min > self.cholesterol_mol_max:
|
||||
return f"胆固醇mol比例:最小值({self.cholesterol_mol_min})不能大于最大值({self.cholesterol_mol_max})"
|
||||
if self.peg_mol_min > self.peg_mol_max:
|
||||
return f"PEG脂质mol比例:最小值({self.peg_mol_min})不能大于最大值({self.peg_mol_max})"
|
||||
|
||||
# 检查 mol ratio 是否可能加起来为 1
|
||||
min_sum = self.cationic_mol_min + self.phospholipid_mol_min + self.cholesterol_mol_min + self.peg_mol_min
|
||||
max_sum = self.cationic_mol_max + self.phospholipid_mol_max + self.cholesterol_mol_max + self.peg_mol_max
|
||||
|
||||
if min_sum > 1.0:
|
||||
return f"mol比例最小值之和({min_sum:.2f})超过100%,无法生成有效配方"
|
||||
if max_sum < 1.0:
|
||||
return f"mol比例最大值之和({max_sum:.2f})不足100%,无法生成有效配方"
|
||||
|
||||
return None
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""验证范围是否合理(至少存在一个可行解)"""
|
||||
return self.get_validation_error() is None
|
||||
|
||||
|
||||
# 默认组分范围
|
||||
DEFAULT_COMP_RANGES = CompRanges()
|
||||
|
||||
# 最小 step size
|
||||
MIN_STEP_SIZE = 0.01
|
||||
@ -56,12 +115,153 @@ MIN_STEP_SIZE = 0.01
|
||||
# 迭代策略:每个迭代的 step_size
|
||||
ITERATION_STEP_SIZES = [0.10, 0.02, 0.01]
|
||||
|
||||
# Helper lipid 选项
|
||||
HELPER_LIPID_OPTIONS = ["DOPE", "DSPC", "DOTAP"]
|
||||
# Helper lipid 选项(不包含 DOTAP)
|
||||
HELPER_LIPID_OPTIONS = ["DOPE", "DSPC"]
|
||||
|
||||
# Route of administration 选项
|
||||
ROUTE_OPTIONS = ["intravenous", "intramuscular"]
|
||||
|
||||
# quantified_delivery 归一化常量(按给药途径)
|
||||
DELIVERY_NORM = {
|
||||
"intravenous": {"min": -0.798559291, "max": 4.497814051056962},
|
||||
"intramuscular": {"min": -0.794912427, "max": 10.220042980012716},
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScoringWeights:
|
||||
"""
|
||||
评分权重配置。
|
||||
|
||||
总分 = biodist_score + delivery_score + size_score + ee_score + pdi_score + toxic_score
|
||||
各项计算方式参见 SCORE.md。
|
||||
默认值:仅按目标器官 biodistribution 排序(向后兼容)。
|
||||
"""
|
||||
# 回归任务权重
|
||||
biodist_weight: float = 1.0 # score = biodist_value * weight
|
||||
delivery_weight: float = 0.0 # score = normalize(delivery, route) * weight
|
||||
size_weight: float = 0.0 # score = (1 if 80<=size<=150 else 0) * weight
|
||||
# 分类任务:per-class 权重(预测为该类时,得分 = 对应权重)
|
||||
ee_class_weights: List[float] = field(default_factory=lambda: [0.0, 0.0, 0.0]) # EE class 0, 1, 2
|
||||
pdi_class_weights: List[float] = field(default_factory=lambda: [0.0, 0.0, 0.0, 0.0]) # PDI class 0, 1, 2, 3
|
||||
toxic_class_weights: List[float] = field(default_factory=lambda: [0.0, 0.0]) # Toxic class 0, 1
|
||||
|
||||
|
||||
# 默认评分权重(仅按 biodist 排序)
|
||||
DEFAULT_SCORING_WEIGHTS = ScoringWeights()
|
||||
|
||||
|
||||
def compute_formulation_score(
|
||||
f: 'Formulation',
|
||||
organ: str,
|
||||
weights: ScoringWeights,
|
||||
) -> float:
|
||||
"""
|
||||
计算单个 Formulation 的综合评分。
|
||||
|
||||
Args:
|
||||
f: 配方对象
|
||||
organ: 目标器官
|
||||
weights: 评分权重
|
||||
|
||||
Returns:
|
||||
综合评分
|
||||
"""
|
||||
score = 0.0
|
||||
|
||||
# 1. biodistribution
|
||||
score += f.get_biodist(organ) * weights.biodist_weight
|
||||
|
||||
# 2. quantified_delivery(按给药途径归一化到 [0, 1])
|
||||
if f.quantified_delivery is not None and weights.delivery_weight != 0:
|
||||
norm = DELIVERY_NORM.get(f.route, DELIVERY_NORM["intravenous"])
|
||||
d_range = norm["max"] - norm["min"]
|
||||
if d_range > 0:
|
||||
delivery_normalized = (f.quantified_delivery - norm["min"]) / d_range
|
||||
delivery_normalized = max(0.0, min(1.0, delivery_normalized))
|
||||
else:
|
||||
delivery_normalized = 0.0
|
||||
score += delivery_normalized * weights.delivery_weight
|
||||
|
||||
# 3. size(60-150nm 为理想范围)
|
||||
if f.size is not None and weights.size_weight != 0:
|
||||
if 60 <= f.size <= 150:
|
||||
score += 1.0 * weights.size_weight
|
||||
|
||||
# 4. EE 分类
|
||||
if f.ee_class is not None and 0 <= f.ee_class < len(weights.ee_class_weights):
|
||||
score += weights.ee_class_weights[f.ee_class]
|
||||
|
||||
# 5. PDI 分类
|
||||
if f.pdi_class is not None and 0 <= f.pdi_class < len(weights.pdi_class_weights):
|
||||
score += weights.pdi_class_weights[f.pdi_class]
|
||||
|
||||
# 6. 毒性分类
|
||||
if f.toxic_class is not None and 0 <= f.toxic_class < len(weights.toxic_class_weights):
|
||||
score += weights.toxic_class_weights[f.toxic_class]
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def compute_df_score(
|
||||
df: pd.DataFrame,
|
||||
organ: str,
|
||||
weights: ScoringWeights,
|
||||
) -> pd.Series:
|
||||
"""
|
||||
为 DataFrame 中的所有行计算综合评分。
|
||||
|
||||
Args:
|
||||
df: 包含预测结果的 DataFrame
|
||||
organ: 目标器官
|
||||
weights: 评分权重
|
||||
|
||||
Returns:
|
||||
评分 Series
|
||||
"""
|
||||
score = pd.Series(0.0, index=df.index)
|
||||
|
||||
# 1. biodistribution
|
||||
pred_col = f"pred_Biodistribution_{organ}"
|
||||
if pred_col in df.columns:
|
||||
score += df[pred_col] * weights.biodist_weight
|
||||
|
||||
# 2. quantified_delivery(按给药途径归一化)
|
||||
if weights.delivery_weight != 0 and "pred_delivery" in df.columns:
|
||||
for route_name, norm in DELIVERY_NORM.items():
|
||||
mask = df["_route"] == route_name
|
||||
if mask.any():
|
||||
d_range = norm["max"] - norm["min"]
|
||||
if d_range > 0:
|
||||
delivery_normalized = (df.loc[mask, "pred_delivery"] - norm["min"]) / d_range
|
||||
delivery_normalized = delivery_normalized.clip(0.0, 1.0)
|
||||
score.loc[mask] += delivery_normalized * weights.delivery_weight
|
||||
|
||||
# 3. size
|
||||
if weights.size_weight != 0 and "pred_size" in df.columns:
|
||||
size_ok = (df["pred_size"] >= 60) & (df["pred_size"] <= 150)
|
||||
score += size_ok.astype(float) * weights.size_weight
|
||||
|
||||
# 4. EE 分类
|
||||
if "pred_ee_class" in df.columns:
|
||||
for cls, w in enumerate(weights.ee_class_weights):
|
||||
if w != 0:
|
||||
score += (df["pred_ee_class"] == cls).astype(float) * w
|
||||
|
||||
# 5. PDI 分类
|
||||
if "pred_pdi_class" in df.columns:
|
||||
for cls, w in enumerate(weights.pdi_class_weights):
|
||||
if w != 0:
|
||||
score += (df["pred_pdi_class"] == cls).astype(float) * w
|
||||
|
||||
# 6. 毒性分类
|
||||
if "pred_toxic_class" in df.columns:
|
||||
for cls, w in enumerate(weights.toxic_class_weights):
|
||||
if w != 0:
|
||||
score += (df["pred_toxic_class"] == cls).astype(float) * w
|
||||
|
||||
return score
|
||||
|
||||
|
||||
@dataclass
|
||||
class Formulation:
|
||||
@ -77,6 +277,12 @@ class Formulation:
|
||||
route: str = "intravenous"
|
||||
# 预测结果(填充后)
|
||||
biodist_predictions: Dict[str, float] = field(default_factory=dict)
|
||||
# 额外预测值
|
||||
quantified_delivery: Optional[float] = None
|
||||
size: Optional[float] = None
|
||||
pdi_class: Optional[int] = None # PDI 分类 (0-3)
|
||||
ee_class: Optional[int] = None # EE 分类 (0-2)
|
||||
toxic_class: Optional[int] = None # 毒性分类 (0: 无毒, 1: 有毒)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""转换为字典"""
|
||||
@ -94,6 +300,18 @@ class Formulation:
|
||||
"""获取指定器官的 biodistribution 预测值"""
|
||||
col = f"Biodistribution_{organ}"
|
||||
return self.biodist_predictions.get(col, 0.0)
|
||||
|
||||
def unique_key(self) -> tuple:
|
||||
"""生成唯一标识键,用于去重"""
|
||||
return (
|
||||
round(self.cationic_lipid_to_mrna_ratio, 4),
|
||||
round(self.cationic_lipid_mol_ratio, 4),
|
||||
round(self.phospholipid_mol_ratio, 4),
|
||||
round(self.cholesterol_mol_ratio, 4),
|
||||
round(self.peg_lipid_mol_ratio, 4),
|
||||
self.helper_lipid,
|
||||
self.route,
|
||||
)
|
||||
|
||||
|
||||
def generate_grid_values(
|
||||
@ -124,20 +342,38 @@ def generate_grid_values(
|
||||
return sorted(set(values))
|
||||
|
||||
|
||||
def generate_initial_grid(step_size: float) -> List[Tuple[float, float, float, float, float]]:
|
||||
def generate_initial_grid(
|
||||
step_size: float,
|
||||
comp_ranges: CompRanges = None,
|
||||
) -> List[Tuple[float, float, float, float, float]]:
|
||||
"""
|
||||
生成初始搜索网格(满足 mol ratio 和为 1 的约束)。
|
||||
|
||||
Args:
|
||||
step_size: 搜索步长
|
||||
comp_ranges: 组分范围配置(默认使用 DEFAULT_COMP_RANGES)
|
||||
|
||||
Returns:
|
||||
List of (cationic_ratio, cationic_mol, phospholipid_mol, cholesterol_mol, peg_mol)
|
||||
"""
|
||||
if comp_ranges is None:
|
||||
comp_ranges = DEFAULT_COMP_RANGES
|
||||
|
||||
grid = []
|
||||
|
||||
# Cationic_Lipid_to_mRNA_weight_ratio
|
||||
weight_ratios = np.arange(0.05, 0.31, step_size)
|
||||
weight_ratios = np.arange(
|
||||
comp_ranges.weight_ratio_min,
|
||||
comp_ranges.weight_ratio_max + 0.001,
|
||||
step_size
|
||||
)
|
||||
|
||||
# PEG: 单独处理,范围很小
|
||||
peg_values = np.arange(0.00, 0.06, MIN_STEP_SIZE) # PEG 始终用 0.01 步长
|
||||
# PEG: 单独处理,范围很小,始终用最小步长
|
||||
peg_values = np.arange(
|
||||
comp_ranges.peg_mol_min,
|
||||
comp_ranges.peg_mol_max + 0.001,
|
||||
MIN_STEP_SIZE
|
||||
)
|
||||
|
||||
# 其他三个 mol ratio 需要满足和为 1 - PEG
|
||||
mol_step = step_size
|
||||
@ -146,11 +382,13 @@ def generate_initial_grid(step_size: float) -> List[Tuple[float, float, float, f
|
||||
for peg in peg_values:
|
||||
remaining = 1.0 - peg
|
||||
# 生成满足约束的组合
|
||||
for cationic_mol in np.arange(0.05, min(0.81, remaining + 0.001), mol_step):
|
||||
for phospholipid_mol in np.arange(0.00, min(0.81, remaining - cationic_mol + 0.001), mol_step):
|
||||
cationic_max = min(comp_ranges.cationic_mol_max, remaining) + 0.001
|
||||
for cationic_mol in np.arange(comp_ranges.cationic_mol_min, cationic_max, mol_step):
|
||||
phospholipid_max = min(comp_ranges.phospholipid_mol_max, remaining - cationic_mol) + 0.001
|
||||
for phospholipid_mol in np.arange(comp_ranges.phospholipid_mol_min, phospholipid_max, mol_step):
|
||||
cholesterol_mol = remaining - cationic_mol - phospholipid_mol
|
||||
# 检查约束
|
||||
if 0.00 <= cholesterol_mol <= 0.80:
|
||||
if (comp_ranges.cholesterol_mol_min <= cholesterol_mol <= comp_ranges.cholesterol_mol_max):
|
||||
grid.append((
|
||||
round(weight_ratio, 4),
|
||||
round(cationic_mol, 4),
|
||||
@ -166,6 +404,7 @@ def generate_refined_grid(
|
||||
seeds: List[Formulation],
|
||||
step_size: float,
|
||||
radius: int = 2,
|
||||
comp_ranges: CompRanges = None,
|
||||
) -> List[Tuple[float, float, float, float, float]]:
|
||||
"""
|
||||
围绕种子点生成精细化网格。
|
||||
@ -174,29 +413,37 @@ def generate_refined_grid(
|
||||
seeds: 种子配方列表
|
||||
step_size: 步长
|
||||
radius: 扩展半径
|
||||
comp_ranges: 组分范围配置(默认使用 DEFAULT_COMP_RANGES)
|
||||
|
||||
Returns:
|
||||
新的网格点列表
|
||||
"""
|
||||
if comp_ranges is None:
|
||||
comp_ranges = DEFAULT_COMP_RANGES
|
||||
|
||||
grid_set = set()
|
||||
|
||||
for seed in seeds:
|
||||
# Weight ratio
|
||||
weight_ratios = generate_grid_values(
|
||||
seed.cationic_lipid_to_mrna_ratio, step_size, 0.05, 0.30, radius
|
||||
seed.cationic_lipid_to_mrna_ratio, step_size,
|
||||
comp_ranges.weight_ratio_min, comp_ranges.weight_ratio_max, radius
|
||||
)
|
||||
|
||||
# PEG (始终用最小步长)
|
||||
peg_values = generate_grid_values(
|
||||
seed.peg_lipid_mol_ratio, MIN_STEP_SIZE, 0.00, 0.05, radius
|
||||
seed.peg_lipid_mol_ratio, MIN_STEP_SIZE,
|
||||
comp_ranges.peg_mol_min, comp_ranges.peg_mol_max, radius
|
||||
)
|
||||
|
||||
# Mol ratios
|
||||
cationic_mols = generate_grid_values(
|
||||
seed.cationic_lipid_mol_ratio, step_size, 0.05, 0.80, radius
|
||||
seed.cationic_lipid_mol_ratio, step_size,
|
||||
comp_ranges.cationic_mol_min, comp_ranges.cationic_mol_max, radius
|
||||
)
|
||||
phospholipid_mols = generate_grid_values(
|
||||
seed.phospholipid_mol_ratio, step_size, 0.00, 0.80, radius
|
||||
seed.phospholipid_mol_ratio, step_size,
|
||||
comp_ranges.phospholipid_mol_min, comp_ranges.phospholipid_mol_max, radius
|
||||
)
|
||||
|
||||
for weight_ratio in weight_ratios:
|
||||
@ -206,10 +453,10 @@ def generate_refined_grid(
|
||||
for phospholipid_mol in phospholipid_mols:
|
||||
cholesterol_mol = remaining - cationic_mol - phospholipid_mol
|
||||
# 检查约束
|
||||
if (0.05 <= cationic_mol <= 0.80 and
|
||||
0.00 <= phospholipid_mol <= 0.80 and
|
||||
0.00 <= cholesterol_mol <= 0.80 and
|
||||
0.00 <= peg <= 0.05):
|
||||
if (comp_ranges.cationic_mol_min <= cationic_mol <= comp_ranges.cationic_mol_max and
|
||||
comp_ranges.phospholipid_mol_min <= phospholipid_mol <= comp_ranges.phospholipid_mol_max and
|
||||
comp_ranges.cholesterol_mol_min <= cholesterol_mol <= comp_ranges.cholesterol_mol_max and
|
||||
comp_ranges.peg_mol_min <= peg <= comp_ranges.peg_mol_max):
|
||||
grid_set.add((
|
||||
round(weight_ratio, 4),
|
||||
round(cationic_mol, 4),
|
||||
@ -283,14 +530,14 @@ def create_dataframe_from_formulations(
|
||||
return pd.DataFrame(rows)
|
||||
|
||||
|
||||
def predict_biodist(
|
||||
def predict_all(
|
||||
model: torch.nn.Module,
|
||||
df: pd.DataFrame,
|
||||
device: torch.device,
|
||||
batch_size: int = 256,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
使用模型预测 biodistribution。
|
||||
使用模型预测所有输出(biodistribution、size、delivery、PDI、EE)。
|
||||
|
||||
Returns:
|
||||
添加了预测列的 DataFrame
|
||||
@ -301,6 +548,11 @@ def predict_biodist(
|
||||
)
|
||||
|
||||
all_biodist_preds = []
|
||||
all_size_preds = []
|
||||
all_delivery_preds = []
|
||||
all_pdi_preds = []
|
||||
all_ee_preds = []
|
||||
all_toxic_preds = []
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
@ -312,20 +564,53 @@ def predict_biodist(
|
||||
# biodist 输出是 softmax 后的概率分布 [B, 7]
|
||||
biodist_pred = outputs["biodist"].cpu().numpy()
|
||||
all_biodist_preds.append(biodist_pred)
|
||||
|
||||
# size 和 delivery 是回归值
|
||||
all_size_preds.append(outputs["size"].squeeze(-1).cpu().numpy())
|
||||
all_delivery_preds.append(outputs["delivery"].squeeze(-1).cpu().numpy())
|
||||
|
||||
# PDI、EE 和 toxic 是分类,取 argmax
|
||||
all_pdi_preds.append(outputs["pdi"].argmax(dim=-1).cpu().numpy())
|
||||
all_ee_preds.append(outputs["ee"].argmax(dim=-1).cpu().numpy())
|
||||
all_toxic_preds.append(outputs["toxic"].argmax(dim=-1).cpu().numpy())
|
||||
|
||||
biodist_preds = np.concatenate(all_biodist_preds, axis=0)
|
||||
size_preds = np.concatenate(all_size_preds, axis=0)
|
||||
delivery_preds = np.concatenate(all_delivery_preds, axis=0)
|
||||
pdi_preds = np.concatenate(all_pdi_preds, axis=0)
|
||||
ee_preds = np.concatenate(all_ee_preds, axis=0)
|
||||
toxic_preds = np.concatenate(all_toxic_preds, axis=0)
|
||||
|
||||
# 添加到 DataFrame
|
||||
for i, col in enumerate(TARGET_BIODIST):
|
||||
df[f"pred_{col}"] = biodist_preds[:, i]
|
||||
|
||||
# size 模型输出为 log(size),转换回真实粒径 (nm)
|
||||
df["pred_size"] = np.exp(size_preds)
|
||||
df["pred_delivery"] = delivery_preds
|
||||
df["pred_pdi_class"] = pdi_preds
|
||||
df["pred_ee_class"] = ee_preds
|
||||
df["pred_toxic_class"] = toxic_preds
|
||||
|
||||
return df
|
||||
|
||||
|
||||
# 保持向后兼容
|
||||
def predict_biodist(
|
||||
model: torch.nn.Module,
|
||||
df: pd.DataFrame,
|
||||
device: torch.device,
|
||||
batch_size: int = 256,
|
||||
) -> pd.DataFrame:
|
||||
"""向后兼容的别名"""
|
||||
return predict_all(model, df, device, batch_size)
|
||||
|
||||
|
||||
def select_top_k(
|
||||
df: pd.DataFrame,
|
||||
organ: str,
|
||||
k: int = 20,
|
||||
scoring_weights: Optional[ScoringWeights] = None,
|
||||
) -> List[Formulation]:
|
||||
"""
|
||||
选择 top-k 配方。
|
||||
@ -334,16 +619,18 @@ def select_top_k(
|
||||
df: 包含预测结果的 DataFrame
|
||||
organ: 目标器官
|
||||
k: 选择数量
|
||||
scoring_weights: 评分权重(默认仅按 biodist 排序)
|
||||
|
||||
Returns:
|
||||
Top-k 配方列表
|
||||
"""
|
||||
pred_col = f"pred_Biodistribution_{organ}"
|
||||
if pred_col not in df.columns:
|
||||
raise ValueError(f"Prediction column {pred_col} not found")
|
||||
if scoring_weights is None:
|
||||
scoring_weights = DEFAULT_SCORING_WEIGHTS
|
||||
|
||||
# 排序并去重
|
||||
df_sorted = df.sort_values(pred_col, ascending=False)
|
||||
# 计算综合评分并排序
|
||||
df = df.copy()
|
||||
df["_composite_score"] = compute_df_score(df, organ, scoring_weights)
|
||||
df_sorted = df.sort_values("_composite_score", ascending=False)
|
||||
|
||||
# 创建配方对象
|
||||
formulations = []
|
||||
@ -373,6 +660,12 @@ def select_top_k(
|
||||
biodist_predictions={
|
||||
col: row[f"pred_{col}"] for col in TARGET_BIODIST
|
||||
},
|
||||
# 额外预测值
|
||||
quantified_delivery=row.get("pred_delivery"),
|
||||
size=row.get("pred_size"),
|
||||
pdi_class=int(row.get("pred_pdi_class")) if row.get("pred_pdi_class") is not None else None,
|
||||
ee_class=int(row.get("pred_ee_class")) if row.get("pred_ee_class") is not None else None,
|
||||
toxic_class=int(row.get("pred_toxic_class")) if row.get("pred_toxic_class") is not None else None,
|
||||
)
|
||||
formulations.append(formulation)
|
||||
|
||||
@ -382,72 +675,242 @@ def select_top_k(
|
||||
return formulations
|
||||
|
||||
|
||||
def generate_single_seed_grid(
|
||||
seed: Formulation,
|
||||
step_size: float,
|
||||
radius: int = 2,
|
||||
comp_ranges: CompRanges = None,
|
||||
) -> List[Tuple[float, float, float, float, float]]:
|
||||
"""
|
||||
为单个种子点生成邻域网格。
|
||||
|
||||
Args:
|
||||
seed: 种子配方
|
||||
step_size: 步长
|
||||
radius: 扩展半径
|
||||
comp_ranges: 组分范围配置(默认使用 DEFAULT_COMP_RANGES)
|
||||
|
||||
Returns:
|
||||
网格点列表
|
||||
"""
|
||||
if comp_ranges is None:
|
||||
comp_ranges = DEFAULT_COMP_RANGES
|
||||
|
||||
grid_set = set()
|
||||
|
||||
# Weight ratio
|
||||
weight_ratios = generate_grid_values(
|
||||
seed.cationic_lipid_to_mrna_ratio, step_size,
|
||||
comp_ranges.weight_ratio_min, comp_ranges.weight_ratio_max, radius
|
||||
)
|
||||
|
||||
# PEG (始终用最小步长)
|
||||
peg_values = generate_grid_values(
|
||||
seed.peg_lipid_mol_ratio, MIN_STEP_SIZE,
|
||||
comp_ranges.peg_mol_min, comp_ranges.peg_mol_max, radius
|
||||
)
|
||||
|
||||
# Mol ratios
|
||||
cationic_mols = generate_grid_values(
|
||||
seed.cationic_lipid_mol_ratio, step_size,
|
||||
comp_ranges.cationic_mol_min, comp_ranges.cationic_mol_max, radius
|
||||
)
|
||||
phospholipid_mols = generate_grid_values(
|
||||
seed.phospholipid_mol_ratio, step_size,
|
||||
comp_ranges.phospholipid_mol_min, comp_ranges.phospholipid_mol_max, radius
|
||||
)
|
||||
|
||||
for weight_ratio in weight_ratios:
|
||||
for peg in peg_values:
|
||||
remaining = 1.0 - peg
|
||||
for cationic_mol in cationic_mols:
|
||||
for phospholipid_mol in phospholipid_mols:
|
||||
cholesterol_mol = remaining - cationic_mol - phospholipid_mol
|
||||
# 检查约束
|
||||
if (comp_ranges.cationic_mol_min <= cationic_mol <= comp_ranges.cationic_mol_max and
|
||||
comp_ranges.phospholipid_mol_min <= phospholipid_mol <= comp_ranges.phospholipid_mol_max and
|
||||
comp_ranges.cholesterol_mol_min <= cholesterol_mol <= comp_ranges.cholesterol_mol_max and
|
||||
comp_ranges.peg_mol_min <= peg <= comp_ranges.peg_mol_max):
|
||||
grid_set.add((
|
||||
round(weight_ratio, 4),
|
||||
round(cationic_mol, 4),
|
||||
round(phospholipid_mol, 4),
|
||||
round(cholesterol_mol, 4),
|
||||
round(peg, 4),
|
||||
))
|
||||
|
||||
return list(grid_set)
|
||||
|
||||
|
||||
def optimize(
|
||||
smiles: str,
|
||||
organ: str,
|
||||
model: torch.nn.Module,
|
||||
device: torch.device,
|
||||
top_k: int = 20,
|
||||
num_seeds: Optional[int] = None,
|
||||
top_per_seed: int = 1,
|
||||
step_sizes: Optional[List[float]] = None,
|
||||
comp_ranges: Optional[CompRanges] = None,
|
||||
routes: Optional[List[str]] = None,
|
||||
scoring_weights: Optional[ScoringWeights] = None,
|
||||
batch_size: int = 256,
|
||||
) -> List[Formulation]:
|
||||
"""
|
||||
执行配方优化。
|
||||
执行配方优化(层级搜索策略)。
|
||||
|
||||
采用层级搜索策略:
|
||||
1. 第一次迭代:全局稀疏搜索,选择 top num_seeds 个分散的种子点
|
||||
2. 后续迭代:对每个种子点分别在其邻域内搜索,各自保留 top_per_seed 个局部最优
|
||||
3. 这样可以保持搜索的多样性,避免结果集中在单一区域
|
||||
|
||||
Args:
|
||||
smiles: SMILES 字符串
|
||||
organ: 目标器官
|
||||
model: 训练好的模型
|
||||
device: 计算设备
|
||||
top_k: 每轮保留的最优配方数
|
||||
top_k: 最终返回的最优配方数
|
||||
num_seeds: 第一次迭代后保留的种子点数量(默认为 top_k * 5)
|
||||
top_per_seed: 每个种子点的邻域搜索后保留的局部最优点数量
|
||||
step_sizes: 每轮迭代的步长列表(默认为 [0.10, 0.02, 0.01])
|
||||
comp_ranges: 组分范围配置(默认使用 DEFAULT_COMP_RANGES)
|
||||
routes: 给药途径列表(默认使用 ROUTE_OPTIONS)
|
||||
scoring_weights: 评分权重配置(默认仅按 biodist 排序)
|
||||
batch_size: 预测批次大小
|
||||
|
||||
Returns:
|
||||
最终 top-k 配方列表
|
||||
"""
|
||||
# 默认 num_seeds 为 top_k * 5
|
||||
if num_seeds is None:
|
||||
num_seeds = top_k * 5
|
||||
|
||||
# 默认步长
|
||||
if step_sizes is None:
|
||||
step_sizes = ITERATION_STEP_SIZES
|
||||
|
||||
# 默认组分范围
|
||||
if comp_ranges is None:
|
||||
comp_ranges = DEFAULT_COMP_RANGES
|
||||
|
||||
# 默认给药途径
|
||||
if routes is None:
|
||||
routes = ROUTE_OPTIONS
|
||||
|
||||
# 默认评分权重
|
||||
if scoring_weights is None:
|
||||
scoring_weights = DEFAULT_SCORING_WEIGHTS
|
||||
|
||||
# 评分函数(用于 Formulation 对象排序)
|
||||
def _score(f: Formulation) -> float:
|
||||
return compute_formulation_score(f, organ, scoring_weights)
|
||||
|
||||
logger.info(f"Starting optimization for organ: {organ}")
|
||||
logger.info(f"SMILES: {smiles}")
|
||||
logger.info(f"Strategy: num_seeds={num_seeds}, top_per_seed={top_per_seed}, top_k={top_k}")
|
||||
logger.info(f"Step sizes: {step_sizes}")
|
||||
logger.info(f"Routes: {routes}")
|
||||
logger.info(f"Scoring weights: biodist={scoring_weights.biodist_weight}, delivery={scoring_weights.delivery_weight}, size={scoring_weights.size_weight}")
|
||||
logger.info(f"Comp ranges: {comp_ranges.to_dict()}")
|
||||
|
||||
seeds = None
|
||||
|
||||
for iteration, step_size in enumerate(ITERATION_STEP_SIZES):
|
||||
for iteration, step_size in enumerate(step_sizes):
|
||||
logger.info(f"\n{'='*60}")
|
||||
logger.info(f"Iteration {iteration + 1}/{len(ITERATION_STEP_SIZES)}, step_size={step_size}")
|
||||
logger.info(f"Iteration {iteration + 1}/{len(step_sizes)}, step_size={step_size}")
|
||||
logger.info(f"{'='*60}")
|
||||
|
||||
# 生成网格
|
||||
if seeds is None:
|
||||
# 第一次迭代:生成完整初始网格
|
||||
logger.info("Generating initial grid...")
|
||||
grid = generate_initial_grid(step_size)
|
||||
# ==================== 第一次迭代:全局稀疏搜索 ====================
|
||||
logger.info("Generating initial grid (global sparse search)...")
|
||||
grid = generate_initial_grid(step_size, comp_ranges)
|
||||
|
||||
logger.info(f"Grid size: {len(grid)} comp combinations")
|
||||
|
||||
# 扩展到所有 helper lipid 和 route 组合
|
||||
total_combinations = len(grid) * len(HELPER_LIPID_OPTIONS) * len(routes)
|
||||
logger.info(f"Total combinations: {total_combinations}")
|
||||
|
||||
# 创建 DataFrame
|
||||
df = create_dataframe_from_formulations(
|
||||
smiles, grid, HELPER_LIPID_OPTIONS, routes
|
||||
)
|
||||
|
||||
# 预测
|
||||
logger.info("Running predictions...")
|
||||
df = predict_biodist(model, df, device, batch_size)
|
||||
|
||||
# 选择 top num_seeds 个种子点
|
||||
seeds = select_top_k(df, organ, num_seeds, scoring_weights)
|
||||
|
||||
logger.info(f"Selected {len(seeds)} seeds for next iteration")
|
||||
|
||||
else:
|
||||
# 后续迭代:围绕种子点精细化
|
||||
logger.info(f"Generating refined grid around {len(seeds)} seeds...")
|
||||
grid = generate_refined_grid(seeds, step_size, radius=2)
|
||||
|
||||
logger.info(f"Grid size: {len(grid)} comp combinations")
|
||||
|
||||
# 扩展到所有 helper lipid 和 route 组合
|
||||
total_combinations = len(grid) * len(HELPER_LIPID_OPTIONS) * len(ROUTE_OPTIONS)
|
||||
logger.info(f"Total combinations: {total_combinations}")
|
||||
|
||||
# 创建 DataFrame
|
||||
df = create_dataframe_from_formulations(
|
||||
smiles, grid, HELPER_LIPID_OPTIONS, ROUTE_OPTIONS
|
||||
)
|
||||
|
||||
# 预测
|
||||
logger.info("Running predictions...")
|
||||
df = predict_biodist(model, df, device, batch_size)
|
||||
|
||||
# 选择 top-k
|
||||
seeds = select_top_k(df, organ, top_k)
|
||||
# ==================== 后续迭代:层级局部搜索 ====================
|
||||
# 对每个种子点分别搜索,各自保留局部最优
|
||||
logger.info(f"Hierarchical local search around {len(seeds)} seeds...")
|
||||
|
||||
all_local_best = []
|
||||
|
||||
for seed_idx, seed in enumerate(seeds):
|
||||
# 为当前种子点生成邻域网格
|
||||
local_grid = generate_single_seed_grid(seed, step_size, radius=2, comp_ranges=comp_ranges)
|
||||
|
||||
if len(local_grid) == 0:
|
||||
# 如果没有新的网格点,保留原种子
|
||||
all_local_best.append(seed)
|
||||
continue
|
||||
|
||||
# 创建 DataFrame
|
||||
df = create_dataframe_from_formulations(
|
||||
smiles, local_grid, [seed.helper_lipid], [seed.route]
|
||||
)
|
||||
|
||||
# 预测
|
||||
df = predict_biodist(model, df, device, batch_size)
|
||||
|
||||
# 选择该种子邻域内的 top top_per_seed 个局部最优
|
||||
local_top = select_top_k(df, organ, top_per_seed, scoring_weights)
|
||||
all_local_best.extend(local_top)
|
||||
|
||||
if seed_idx == 0 or (seed_idx + 1) % 5 == 0:
|
||||
logger.info(f" Seed {seed_idx + 1}/{len(seeds)}: local grid size={len(local_grid)}, "
|
||||
f"local best score={_score(local_top[0]):.4f}")
|
||||
|
||||
# 更新种子为所有局部最优点(去重)
|
||||
seen_keys = set()
|
||||
unique_local_best = []
|
||||
# 先按综合评分排序,确保保留最优的
|
||||
all_local_best_sorted = sorted(all_local_best, key=_score, reverse=True)
|
||||
for f in all_local_best_sorted:
|
||||
key = f.unique_key()
|
||||
if key not in seen_keys:
|
||||
seen_keys.add(key)
|
||||
unique_local_best.append(f)
|
||||
|
||||
seeds = unique_local_best
|
||||
logger.info(f"Collected {len(seeds)} unique local best formulations (from {len(all_local_best)} candidates)")
|
||||
|
||||
# 显示当前最优
|
||||
best = seeds[0]
|
||||
logger.info(f"Current best Biodistribution_{organ}: {best.get_biodist(organ):.4f}")
|
||||
best = max(seeds, key=_score)
|
||||
logger.info(f"Current best score: {_score(best):.4f} (biodist_{organ}={best.get_biodist(organ):.4f})")
|
||||
logger.info(f"Best formulation: {best.to_dict()}")
|
||||
|
||||
return seeds
|
||||
# 最终去重、按综合评分排序并返回 top_k
|
||||
seeds_sorted = sorted(seeds, key=_score, reverse=True)
|
||||
|
||||
# 去重:保留每个唯一配方中得分最高的(已排序,所以第一个出现的就是最高的)
|
||||
seen_keys = set()
|
||||
unique_results = []
|
||||
for f in seeds_sorted:
|
||||
key = f.unique_key()
|
||||
if key not in seen_keys:
|
||||
seen_keys.add(key)
|
||||
unique_results.append(f)
|
||||
|
||||
logger.info(f"Final results: {len(unique_results)} unique formulations (from {len(seeds)} candidates)")
|
||||
|
||||
return unique_results[:top_k]
|
||||
|
||||
|
||||
def format_results(formulations: List[Formulation], organ: str) -> pd.DataFrame:
|
||||
@ -481,33 +944,54 @@ def main(
|
||||
help="Output CSV path (optional)"
|
||||
),
|
||||
top_k: int = typer.Option(20, "--top-k", "-k", help="Number of top formulations to return"),
|
||||
num_seeds: Optional[int] = typer.Option(None, "--num-seeds", "-n", help="Number of seed points from first iteration (default: top_k * 5)"),
|
||||
top_per_seed: int = typer.Option(1, "--top-per-seed", "-t", help="Number of local best to keep per seed"),
|
||||
step_sizes: Optional[str] = typer.Option(None, "--step-sizes", "-S", help="Comma-separated step sizes (e.g., '0.10,0.02,0.01')"),
|
||||
batch_size: int = typer.Option(256, "--batch-size", "-b", help="Prediction batch size"),
|
||||
device: str = typer.Option("cuda" if torch.cuda.is_available() else "cpu", "--device", "-d", help="Device"),
|
||||
):
|
||||
"""
|
||||
配方优化程序:寻找最大化目标器官 Biodistribution 的最优 LNP 配方。
|
||||
|
||||
采用层级搜索策略:
|
||||
1. 第一次迭代:全局稀疏搜索,选择 top num_seeds 个分散的种子点
|
||||
2. 后续迭代:对每个种子点分别在其邻域内搜索,各自保留 top_per_seed 个局部最优
|
||||
3. 这样可以保持搜索的多样性,避免结果集中在单一区域
|
||||
|
||||
示例:
|
||||
python -m app.optimize --smiles "CC(C)..." --organ liver
|
||||
python -m app.optimize -s "CC(C)..." -o spleen -k 10
|
||||
python -m app.optimize -s "CC(C)..." -o spleen -k 10 -n 30 -t 2
|
||||
python -m app.optimize -s "CC(C)..." -o liver -S "0.10,0.05,0.02"
|
||||
"""
|
||||
# 验证器官
|
||||
if organ not in AVAILABLE_ORGANS:
|
||||
logger.error(f"Invalid organ: {organ}. Available: {AVAILABLE_ORGANS}")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# 解析步长
|
||||
parsed_step_sizes = None
|
||||
if step_sizes:
|
||||
try:
|
||||
parsed_step_sizes = [float(s.strip()) for s in step_sizes.split(",")]
|
||||
except ValueError:
|
||||
logger.error(f"Invalid step sizes format: {step_sizes}")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# 加载模型
|
||||
logger.info(f"Loading model from {model_path}")
|
||||
device = torch.device(device)
|
||||
model = load_model(model_path, device)
|
||||
|
||||
# 执行优化
|
||||
# 执行优化(层级搜索策略)
|
||||
results = optimize(
|
||||
smiles=smiles,
|
||||
organ=organ,
|
||||
model=model,
|
||||
device=device,
|
||||
top_k=top_k,
|
||||
num_seeds=num_seeds,
|
||||
top_per_seed=top_per_seed,
|
||||
step_sizes=parsed_step_sizes,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
data/raw/.~internal_deleted_uncorrected.xlsx
Normal file
BIN
data/raw/.~internal_deleted_uncorrected.xlsx
Normal file
Binary file not shown.
54
docker-compose-gpu.yml
Normal file
54
docker-compose-gpu.yml
Normal file
@ -0,0 +1,54 @@
|
||||
services:
|
||||
# FastAPI 后端服务
|
||||
api:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
target: api
|
||||
container_name: lnp-api
|
||||
environment:
|
||||
- MODEL_PATH=/app/models/final/model.pt
|
||||
volumes:
|
||||
# 挂载模型目录以便更新模型
|
||||
- ./models/final:/app/models/final:ro
|
||||
- ./models/mpnn:/app/models/mpnn:ro
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 60s
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
|
||||
# Streamlit 前端服务
|
||||
streamlit:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
target: streamlit
|
||||
container_name: lnp-streamlit
|
||||
ports:
|
||||
- "8501:8501"
|
||||
environment:
|
||||
- API_URL=http://api:8000
|
||||
depends_on:
|
||||
api:
|
||||
condition: service_started
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8501/_stcore/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 30s
|
||||
|
||||
networks:
|
||||
default:
|
||||
name: lnp-network
|
||||
49
docker-compose.yml
Normal file
49
docker-compose.yml
Normal file
@ -0,0 +1,49 @@
|
||||
services:
|
||||
# FastAPI 后端服务
|
||||
api:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
target: api
|
||||
container_name: lnp-api
|
||||
environment:
|
||||
- MODEL_PATH=/app/models/final/model.pt
|
||||
- DEVICE=cpu
|
||||
volumes:
|
||||
# 挂载模型目录以便更新模型
|
||||
- ./models/final:/app/models/final:ro
|
||||
- ./models/mpnn:/app/models/mpnn:ro
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 60s
|
||||
|
||||
# Streamlit 前端服务
|
||||
streamlit:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
target: streamlit
|
||||
container_name: lnp-streamlit
|
||||
ports:
|
||||
- "8501:8501"
|
||||
environment:
|
||||
- API_URL=http://api:8000
|
||||
depends_on:
|
||||
api:
|
||||
condition: service_started
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8501/_stcore/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 30s
|
||||
|
||||
networks:
|
||||
default:
|
||||
name: lnp-network
|
||||
|
||||
612
lnp_ml/modeling/final_train_optuna_cv.py
Normal file
612
lnp_ml/modeling/final_train_optuna_cv.py
Normal file
@ -0,0 +1,612 @@
|
||||
"""
|
||||
最终训练脚本:3-fold Optuna 调参 + 全量数据训练
|
||||
|
||||
1. 使用全量数据做 3-fold StratifiedKFold Optuna 超参搜索
|
||||
2. 固定最优超参后,使用全量数据训练(不使用 early-stopping)
|
||||
3. 使用 CosineAnnealingLR + 可选 SWA 防止过拟合
|
||||
|
||||
使用方法:
|
||||
python -m lnp_ml.modeling.final_train_optuna_cv
|
||||
|
||||
或通过 Makefile:
|
||||
make final_optuna DEVICE=cuda
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
from sklearn.model_selection import StratifiedKFold
|
||||
from loguru import logger
|
||||
import typer
|
||||
|
||||
try:
|
||||
import optuna
|
||||
from optuna.samplers import TPESampler
|
||||
except ImportError:
|
||||
optuna = None
|
||||
TPESampler = None
|
||||
|
||||
from lnp_ml.config import MODELS_DIR, INTERIM_DATA_DIR
|
||||
from lnp_ml.dataset import (
|
||||
LNPDataset,
|
||||
collate_fn,
|
||||
process_dataframe,
|
||||
TARGET_CLASSIFICATION_PDI,
|
||||
TARGET_CLASSIFICATION_EE,
|
||||
TARGET_TOXIC,
|
||||
)
|
||||
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
|
||||
from lnp_ml.modeling.trainer_balanced import (
|
||||
ClassWeights,
|
||||
LossWeightsBalanced,
|
||||
compute_class_weights_from_loader,
|
||||
train_with_early_stopping,
|
||||
train_fixed_epochs,
|
||||
)
|
||||
from lnp_ml.modeling.visualization import plot_multitask_loss_curves
|
||||
|
||||
# MPNN ensemble 默认路径
|
||||
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
# ============ CompositeStrata 复合分层标签(复用 nested_cv_optuna 的逻辑) ============
|
||||
|
||||
def build_composite_strata(
|
||||
df: pd.DataFrame,
|
||||
min_stratum_count: int = 5,
|
||||
) -> Tuple[np.ndarray, Dict]:
|
||||
"""
|
||||
构建复合分层标签(toxic × PDI × EE)。
|
||||
|
||||
Args:
|
||||
df: 处理后的 DataFrame
|
||||
min_stratum_count: 每个 stratum 最少样本数,低于此值合并为 RARE
|
||||
|
||||
Returns:
|
||||
(strata_array, strata_info)
|
||||
"""
|
||||
n = len(df)
|
||||
strata_labels = []
|
||||
|
||||
for i in range(n):
|
||||
# Toxic stratum
|
||||
if TARGET_TOXIC in df.columns:
|
||||
toxic_val = df[TARGET_TOXIC].iloc[i]
|
||||
if pd.notna(toxic_val) and toxic_val >= 0:
|
||||
toxic_str = str(int(toxic_val))
|
||||
else:
|
||||
toxic_str = "NA"
|
||||
else:
|
||||
toxic_str = "NA"
|
||||
|
||||
# PDI stratum
|
||||
if all(col in df.columns for col in TARGET_CLASSIFICATION_PDI):
|
||||
pdi_vals = df[TARGET_CLASSIFICATION_PDI].iloc[i].values
|
||||
if pdi_vals.sum() > 0:
|
||||
pdi_str = str(int(np.argmax(pdi_vals)))
|
||||
else:
|
||||
pdi_str = "NA"
|
||||
else:
|
||||
pdi_str = "NA"
|
||||
|
||||
# EE stratum
|
||||
if all(col in df.columns for col in TARGET_CLASSIFICATION_EE):
|
||||
ee_vals = df[TARGET_CLASSIFICATION_EE].iloc[i].values
|
||||
if ee_vals.sum() > 0:
|
||||
ee_str = str(int(np.argmax(ee_vals)))
|
||||
else:
|
||||
ee_str = "NA"
|
||||
else:
|
||||
ee_str = "NA"
|
||||
|
||||
strata_labels.append(f"T{toxic_str}|P{pdi_str}|E{ee_str}")
|
||||
|
||||
# 统计各 stratum 的样本数
|
||||
unique_strata, counts = np.unique(strata_labels, return_counts=True)
|
||||
strata_counts = dict(zip(unique_strata, counts))
|
||||
|
||||
# 将稀疏 strata 合并为 RARE
|
||||
rare_strata = [s for s, c in strata_counts.items() if c < min_stratum_count]
|
||||
|
||||
final_labels = []
|
||||
for label in strata_labels:
|
||||
if label in rare_strata:
|
||||
final_labels.append("RARE")
|
||||
else:
|
||||
final_labels.append(label)
|
||||
|
||||
# 编码为整数
|
||||
unique_final, encoded = np.unique(final_labels, return_inverse=True)
|
||||
|
||||
strata_info = {
|
||||
"original_strata_counts": strata_counts,
|
||||
"rare_strata": rare_strata,
|
||||
"final_strata": list(unique_final),
|
||||
"final_strata_counts": dict(zip(*np.unique(final_labels, return_counts=True))),
|
||||
"n_rare_merged": sum(strata_counts[s] for s in rare_strata) if rare_strata else 0,
|
||||
}
|
||||
|
||||
logger.info(f"Built composite strata: {len(unique_final)} unique strata")
|
||||
logger.info(f" Rare strata merged: {len(rare_strata)} types, {strata_info['n_rare_merged']} samples")
|
||||
|
||||
return encoded.astype(np.int64), strata_info
|
||||
|
||||
|
||||
# ============ 模型创建 ============
|
||||
|
||||
def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
|
||||
"""自动查找 MPNN ensemble 的 model.pt 文件。"""
|
||||
model_paths = sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt"))
|
||||
if not model_paths:
|
||||
raise FileNotFoundError(f"No model.pt files found in {base_dir}")
|
||||
return [str(p) for p in model_paths]
|
||||
|
||||
|
||||
def create_model(
|
||||
d_model: int = 256,
|
||||
num_heads: int = 8,
|
||||
n_attn_layers: int = 4,
|
||||
fusion_strategy: str = "attention",
|
||||
head_hidden_dim: int = 128,
|
||||
dropout: float = 0.1,
|
||||
use_mpnn: bool = False,
|
||||
mpnn_device: str = "cpu",
|
||||
) -> Union[LNPModel, LNPModelWithoutMPNN]:
|
||||
"""创建模型"""
|
||||
if use_mpnn:
|
||||
ensemble_paths = find_mpnn_ensemble_paths()
|
||||
return LNPModel(
|
||||
d_model=d_model,
|
||||
num_heads=num_heads,
|
||||
n_attn_layers=n_attn_layers,
|
||||
fusion_strategy=fusion_strategy,
|
||||
head_hidden_dim=head_hidden_dim,
|
||||
dropout=dropout,
|
||||
mpnn_ensemble_paths=ensemble_paths,
|
||||
mpnn_device=mpnn_device,
|
||||
)
|
||||
else:
|
||||
return LNPModelWithoutMPNN(
|
||||
d_model=d_model,
|
||||
num_heads=num_heads,
|
||||
n_attn_layers=n_attn_layers,
|
||||
fusion_strategy=fusion_strategy,
|
||||
head_hidden_dim=head_hidden_dim,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
|
||||
# ============ 预训练权重加载 ============
|
||||
|
||||
def load_pretrain_weights_to_model(
|
||||
model: Union[LNPModel, LNPModelWithoutMPNN],
|
||||
pretrain_state_dict: Dict,
|
||||
d_model: int,
|
||||
pretrain_config: Dict,
|
||||
load_delivery_head: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
加载预训练权重到模型。
|
||||
|
||||
Returns:
|
||||
是否成功加载
|
||||
"""
|
||||
if pretrain_config.get("d_model") != d_model:
|
||||
logger.warning(
|
||||
f"d_model mismatch: pretrain={pretrain_config.get('d_model')}, "
|
||||
f"current={d_model}. Skipping pretrain loading."
|
||||
)
|
||||
return False
|
||||
|
||||
model.load_pretrain_weights(
|
||||
pretrain_state_dict=pretrain_state_dict,
|
||||
load_delivery_head=load_delivery_head,
|
||||
strict=False,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
# ============ 3-fold Optuna 调参 ============
|
||||
|
||||
def run_optuna_cv(
|
||||
full_dataset: LNPDataset,
|
||||
strata: np.ndarray,
|
||||
device: torch.device,
|
||||
n_trials: int = 20,
|
||||
epochs_per_trial: int = 30,
|
||||
patience: int = 10,
|
||||
batch_size: int = 32,
|
||||
n_folds: int = 3,
|
||||
use_mpnn: bool = False,
|
||||
seed: int = 42,
|
||||
study_path: Optional[Path] = None,
|
||||
pretrain_state_dict: Optional[Dict] = None,
|
||||
pretrain_config: Optional[Dict] = None,
|
||||
load_delivery_head: bool = True,
|
||||
) -> Tuple[Dict, int, optuna.Study]:
|
||||
"""
|
||||
使用全量数据做 3-fold CV Optuna 超参搜索。
|
||||
|
||||
Args:
|
||||
full_dataset: 完整数据集
|
||||
strata: 每个样本的分层标签
|
||||
device: 设备
|
||||
n_trials: Optuna 试验数
|
||||
epochs_per_trial: 每个试验的最大 epoch
|
||||
patience: 早停耐心值
|
||||
batch_size: 批次大小
|
||||
n_folds: 折数
|
||||
use_mpnn: 是否使用 MPNN
|
||||
seed: 随机种子
|
||||
study_path: 可选的 study 持久化路径
|
||||
pretrain_state_dict: 预训练权重
|
||||
pretrain_config: 预训练配置
|
||||
load_delivery_head: 是否加载 delivery head 权重
|
||||
|
||||
Returns:
|
||||
(best_params, epoch_mean, study)
|
||||
"""
|
||||
if optuna is None:
|
||||
raise ImportError("Optuna not installed. Run: pip install optuna")
|
||||
|
||||
n_samples = len(full_dataset)
|
||||
indices = np.arange(n_samples)
|
||||
|
||||
def objective(trial: optuna.Trial) -> float:
|
||||
# 采样超参数
|
||||
d_model = trial.suggest_categorical("d_model", [128, 256, 512])
|
||||
num_heads = trial.suggest_categorical("num_heads", [4, 8])
|
||||
n_attn_layers = trial.suggest_int("n_attn_layers", 2, 6)
|
||||
fusion_strategy = trial.suggest_categorical(
|
||||
"fusion_strategy", ["attention", "avg", "max"]
|
||||
)
|
||||
head_hidden_dim = trial.suggest_categorical("head_hidden_dim", [64, 128, 256])
|
||||
dropout = trial.suggest_float("dropout", 0.05, 0.3)
|
||||
lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
|
||||
weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True)
|
||||
|
||||
# 3-fold CV
|
||||
cv = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
|
||||
|
||||
fold_val_losses = []
|
||||
fold_best_epochs = []
|
||||
|
||||
for fold, (train_idx, val_idx) in enumerate(cv.split(indices, strata)):
|
||||
# 创建 DataLoader
|
||||
train_subset = Subset(full_dataset, train_idx.tolist())
|
||||
val_subset = Subset(full_dataset, val_idx.tolist())
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_subset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_subset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
|
||||
)
|
||||
|
||||
# 计算类权重
|
||||
class_weights = compute_class_weights_from_loader(train_loader)
|
||||
|
||||
# 创建模型
|
||||
model = create_model(
|
||||
d_model=d_model,
|
||||
num_heads=num_heads,
|
||||
n_attn_layers=n_attn_layers,
|
||||
fusion_strategy=fusion_strategy,
|
||||
head_hidden_dim=head_hidden_dim,
|
||||
dropout=dropout,
|
||||
use_mpnn=use_mpnn,
|
||||
mpnn_device=device.type,
|
||||
)
|
||||
|
||||
# 加载预训练权重
|
||||
if pretrain_state_dict is not None and pretrain_config is not None:
|
||||
load_pretrain_weights_to_model(
|
||||
model, pretrain_state_dict, d_model, pretrain_config, load_delivery_head
|
||||
)
|
||||
|
||||
# 训练(带早停)
|
||||
result = train_with_early_stopping(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
device=device,
|
||||
lr=lr,
|
||||
weight_decay=weight_decay,
|
||||
epochs=epochs_per_trial,
|
||||
patience=patience,
|
||||
class_weights=class_weights,
|
||||
)
|
||||
|
||||
fold_val_losses.append(result["best_val_loss"])
|
||||
fold_best_epochs.append(result["best_epoch"])
|
||||
|
||||
# 记录 epoch_mean 到 trial
|
||||
epoch_mean = int(round(np.mean(fold_best_epochs)))
|
||||
trial.set_user_attr("epoch_mean", epoch_mean)
|
||||
trial.set_user_attr("fold_best_epochs", fold_best_epochs)
|
||||
trial.set_user_attr("fold_val_losses", fold_val_losses)
|
||||
|
||||
return np.mean(fold_val_losses)
|
||||
|
||||
# 创建 study
|
||||
storage = None
|
||||
if study_path is not None:
|
||||
storage = f"sqlite:///{study_path}"
|
||||
|
||||
study = optuna.create_study(
|
||||
direction="minimize",
|
||||
sampler=TPESampler(seed=seed),
|
||||
storage=storage,
|
||||
study_name="final_optuna_cv",
|
||||
load_if_exists=True,
|
||||
)
|
||||
|
||||
study.optimize(objective, n_trials=n_trials, show_progress_bar=True)
|
||||
|
||||
best_params = study.best_trial.params
|
||||
epoch_mean = study.best_trial.user_attrs.get("epoch_mean", epochs_per_trial)
|
||||
|
||||
logger.info(f"Best trial: {study.best_trial.number}")
|
||||
logger.info(f"Best val_loss: {study.best_trial.value:.4f}")
|
||||
logger.info(f"Best params: {best_params}")
|
||||
logger.info(f"Epoch mean from best trial: {epoch_mean}")
|
||||
|
||||
return best_params, epoch_mean, study
|
||||
|
||||
|
||||
# ============ 主流程 ============
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
input_path: Path = INTERIM_DATA_DIR / "internal_corrected.csv",
|
||||
output_dir: Path = MODELS_DIR / "final_optuna",
|
||||
# CV 参数
|
||||
n_folds: int = 3,
|
||||
min_stratum_count: int = 5,
|
||||
seed: int = 42,
|
||||
# Optuna 参数
|
||||
n_trials: int = 20,
|
||||
epochs_per_trial: int = 30,
|
||||
patience: int = 10,
|
||||
# 训练参数
|
||||
batch_size: int = 32,
|
||||
# 最终训练参数
|
||||
use_swa: bool = False,
|
||||
swa_start_ratio: float = 0.75,
|
||||
# 预训练权重
|
||||
init_from_pretrain: Optional[Path] = None,
|
||||
load_delivery_head: bool = True,
|
||||
# MPNN
|
||||
use_mpnn: bool = False,
|
||||
# 设备
|
||||
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
||||
):
|
||||
"""
|
||||
最终训练:3-fold Optuna 调参 + 全量数据训练。
|
||||
|
||||
1. 使用全量数据做 3-fold StratifiedKFold Optuna 超参搜索
|
||||
2. 固定最优超参后,使用全量数据训练(不使用 early-stopping)
|
||||
3. epoch 数使用 3-fold CV 中 best trial 的 epoch_mean
|
||||
4. 使用 CosineAnnealingLR,可选 SWA
|
||||
|
||||
使用 --init-from-pretrain 从预训练 checkpoint 初始化模型权重。
|
||||
"""
|
||||
if optuna is None:
|
||||
logger.error("Optuna not installed. Run: pip install optuna")
|
||||
raise typer.Exit(1)
|
||||
|
||||
logger.info(f"Using device: {device}")
|
||||
device = torch.device(device)
|
||||
|
||||
# 加载预训练权重(如果指定)
|
||||
pretrain_state_dict = None
|
||||
pretrain_config = None
|
||||
if init_from_pretrain is not None:
|
||||
if init_from_pretrain.exists():
|
||||
logger.info(f"Loading pretrain weights from {init_from_pretrain}")
|
||||
checkpoint = torch.load(init_from_pretrain, map_location="cpu")
|
||||
pretrain_state_dict = checkpoint["model_state_dict"]
|
||||
pretrain_config = checkpoint.get("config", {})
|
||||
logger.success(f"Loaded pretrain checkpoint (d_model={pretrain_config.get('d_model')})")
|
||||
else:
|
||||
logger.warning(f"Pretrain checkpoint not found: {init_from_pretrain}, skipping")
|
||||
|
||||
# 创建输出目录
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"Output directory: {output_dir}")
|
||||
|
||||
# 加载数据
|
||||
logger.info(f"Loading data from {input_path}")
|
||||
df = pd.read_csv(input_path)
|
||||
logger.info(f"Loaded {len(df)} samples")
|
||||
|
||||
# 处理数据
|
||||
logger.info("Processing dataframe...")
|
||||
df = process_dataframe(df)
|
||||
|
||||
# 构建复合分层标签
|
||||
logger.info("Building composite strata...")
|
||||
strata, strata_info = build_composite_strata(df, min_stratum_count)
|
||||
|
||||
# 保存 strata 信息
|
||||
with open(output_dir / "strata_info.json", "w") as f:
|
||||
json.dump(strata_info, f, indent=2, default=str)
|
||||
|
||||
# 创建完整数据集
|
||||
full_dataset = LNPDataset(df)
|
||||
|
||||
# 运行 Optuna 调参
|
||||
logger.info(f"\nRunning {n_folds}-fold Optuna with {n_trials} trials...")
|
||||
study_path = output_dir / "optuna_study.sqlite3"
|
||||
|
||||
best_params, epoch_mean, study = run_optuna_cv(
|
||||
full_dataset=full_dataset,
|
||||
strata=strata,
|
||||
device=device,
|
||||
n_trials=n_trials,
|
||||
epochs_per_trial=epochs_per_trial,
|
||||
patience=patience,
|
||||
batch_size=batch_size,
|
||||
n_folds=n_folds,
|
||||
use_mpnn=use_mpnn,
|
||||
seed=seed,
|
||||
study_path=study_path,
|
||||
pretrain_state_dict=pretrain_state_dict,
|
||||
pretrain_config=pretrain_config,
|
||||
load_delivery_head=load_delivery_head,
|
||||
)
|
||||
|
||||
# 保存最佳参数
|
||||
with open(output_dir / "best_params.json", "w") as f:
|
||||
json.dump(best_params, f, indent=2)
|
||||
|
||||
with open(output_dir / "epoch_mean.json", "w") as f:
|
||||
json.dump({"epoch_mean": epoch_mean}, f)
|
||||
|
||||
# 保存 Optuna 试验历史
|
||||
trials_history = []
|
||||
for trial in study.trials:
|
||||
trials_history.append({
|
||||
"number": trial.number,
|
||||
"value": trial.value,
|
||||
"params": trial.params,
|
||||
"user_attrs": trial.user_attrs,
|
||||
"state": str(trial.state),
|
||||
})
|
||||
|
||||
with open(output_dir / "optuna_trials.json", "w") as f:
|
||||
json.dump(trials_history, f, indent=2)
|
||||
|
||||
# 全量数据训练
|
||||
logger.info(f"\n{'='*60}")
|
||||
logger.info("FINAL TRAINING ON FULL DATA")
|
||||
logger.info(f"{'='*60}")
|
||||
logger.info(f"Using best params with epochs={epoch_mean}")
|
||||
logger.info(f"SWA: {use_swa}")
|
||||
|
||||
# 创建全量 DataLoader
|
||||
full_loader = DataLoader(
|
||||
full_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
|
||||
)
|
||||
|
||||
# 计算类权重
|
||||
class_weights = compute_class_weights_from_loader(full_loader)
|
||||
|
||||
# 保存类权重信息
|
||||
class_weights_info = {
|
||||
"pdi": class_weights.pdi.tolist() if class_weights.pdi is not None else None,
|
||||
"ee": class_weights.ee.tolist() if class_weights.ee is not None else None,
|
||||
"toxic": class_weights.toxic.tolist() if class_weights.toxic is not None else None,
|
||||
}
|
||||
with open(output_dir / "class_weights.json", "w") as f:
|
||||
json.dump(class_weights_info, f, indent=2)
|
||||
|
||||
# 创建模型
|
||||
model = create_model(
|
||||
d_model=best_params["d_model"],
|
||||
num_heads=best_params["num_heads"],
|
||||
n_attn_layers=best_params["n_attn_layers"],
|
||||
fusion_strategy=best_params["fusion_strategy"],
|
||||
head_hidden_dim=best_params["head_hidden_dim"],
|
||||
dropout=best_params["dropout"],
|
||||
use_mpnn=use_mpnn,
|
||||
mpnn_device=device.type,
|
||||
)
|
||||
|
||||
# 加载预训练权重
|
||||
if pretrain_state_dict is not None and pretrain_config is not None:
|
||||
loaded = load_pretrain_weights_to_model(
|
||||
model, pretrain_state_dict, best_params["d_model"],
|
||||
pretrain_config, load_delivery_head
|
||||
)
|
||||
if loaded:
|
||||
logger.info("Loaded pretrain weights for final training")
|
||||
|
||||
# 打印模型信息
|
||||
n_params_total = sum(p.numel() for p in model.parameters())
|
||||
n_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
logger.info(f"Model parameters: {n_params_total:,} total, {n_params_trainable:,} trainable")
|
||||
|
||||
# 训练(固定 epoch,不 early-stop)
|
||||
swa_start = int(epoch_mean * swa_start_ratio) if use_swa else None
|
||||
|
||||
train_result = train_fixed_epochs(
|
||||
model=model,
|
||||
train_loader=full_loader,
|
||||
val_loader=None, # 全量训练,无验证集
|
||||
device=device,
|
||||
lr=best_params["lr"],
|
||||
weight_decay=best_params["weight_decay"],
|
||||
epochs=epoch_mean,
|
||||
class_weights=class_weights,
|
||||
use_cosine_annealing=True,
|
||||
use_swa=use_swa,
|
||||
swa_start_epoch=swa_start,
|
||||
)
|
||||
|
||||
# 加载最终权重
|
||||
model.load_state_dict(train_result["final_state"])
|
||||
|
||||
# 保存模型
|
||||
config = {
|
||||
"d_model": best_params["d_model"],
|
||||
"num_heads": best_params["num_heads"],
|
||||
"n_attn_layers": best_params["n_attn_layers"],
|
||||
"fusion_strategy": best_params["fusion_strategy"],
|
||||
"head_hidden_dim": best_params["head_hidden_dim"],
|
||||
"dropout": best_params["dropout"],
|
||||
"use_mpnn": use_mpnn,
|
||||
}
|
||||
|
||||
torch.save({
|
||||
"model_state_dict": train_result["final_state"],
|
||||
"config": config,
|
||||
"best_params": best_params,
|
||||
"epoch_mean": epoch_mean,
|
||||
"use_swa": use_swa,
|
||||
}, output_dir / "model.pt")
|
||||
|
||||
logger.success(f"Saved model to {output_dir / 'model.pt'}")
|
||||
|
||||
# 保存训练历史
|
||||
with open(output_dir / "history.json", "w") as f:
|
||||
json.dump(train_result["history"], f, indent=2)
|
||||
|
||||
# 绘制损失曲线
|
||||
if train_result["history"]["train"]:
|
||||
try:
|
||||
plot_multitask_loss_curves(
|
||||
history=train_result["history"],
|
||||
output_path=output_dir / "loss_curves.png",
|
||||
title="Final Training Loss Curves",
|
||||
)
|
||||
logger.info(f"Saved loss curves to {output_dir / 'loss_curves.png'}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to plot loss curves: {e}")
|
||||
|
||||
# 打印最终信息
|
||||
logger.info(f"\n{'='*60}")
|
||||
logger.info("FINAL TRAINING COMPLETE")
|
||||
logger.info(f"{'='*60}")
|
||||
logger.info(f"Best params: {best_params}")
|
||||
logger.info(f"Epochs trained: {epoch_mean}")
|
||||
logger.info(f"Output directory: {output_dir}")
|
||||
|
||||
# 提示如何使用
|
||||
logger.info("\n[How to use the trained model]")
|
||||
logger.info(f" 1. Set environment variable: MODEL_PATH={output_dir / 'model.pt'}")
|
||||
logger.info(" 2. Or copy to default location: cp {output_dir}/model.pt models/final/model.pt")
|
||||
logger.info(" 3. Start API: make api")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
|
||||
774
lnp_ml/modeling/nested_cv_optuna.py
Normal file
774
lnp_ml/modeling/nested_cv_optuna.py
Normal file
@ -0,0 +1,774 @@
|
||||
"""
|
||||
嵌套交叉验证 + Optuna 超参调优
|
||||
|
||||
外层 5-fold StratifiedKFold(20% test / 80% train)
|
||||
内层 3-fold StratifiedKFold(在 80% 上做 Optuna 超参搜索)
|
||||
|
||||
使用方法:
|
||||
python -m lnp_ml.modeling.nested_cv_optuna
|
||||
|
||||
或通过 Makefile:
|
||||
make nested_cv_tune DEVICE=cuda
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
from sklearn.model_selection import StratifiedKFold
|
||||
from loguru import logger
|
||||
import typer
|
||||
|
||||
try:
|
||||
import optuna
|
||||
from optuna.samplers import TPESampler
|
||||
except ImportError:
|
||||
optuna = None
|
||||
TPESampler = None
|
||||
|
||||
from lnp_ml.config import MODELS_DIR, INTERIM_DATA_DIR
|
||||
from lnp_ml.dataset import (
|
||||
LNPDataset,
|
||||
collate_fn,
|
||||
process_dataframe,
|
||||
TARGET_CLASSIFICATION_PDI,
|
||||
TARGET_CLASSIFICATION_EE,
|
||||
TARGET_TOXIC,
|
||||
)
|
||||
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
|
||||
from lnp_ml.modeling.trainer_balanced import (
|
||||
ClassWeights,
|
||||
LossWeightsBalanced,
|
||||
compute_class_weights_from_loader,
|
||||
train_with_early_stopping,
|
||||
train_fixed_epochs,
|
||||
validate_balanced,
|
||||
)
|
||||
|
||||
# MPNN ensemble 默认路径
|
||||
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
# ============ CompositeStrata 复合分层标签 ============
|
||||
|
||||
def build_composite_strata(
    df: pd.DataFrame,
    min_stratum_count: int = 5,
) -> Tuple[np.ndarray, Dict]:
    """
    Build composite stratification labels (toxic x PDI x EE).

    Each sample gets a label "T{t}|P{p}|E{e}" where each part is a class
    index, or "NA" when the corresponding target is missing for that row.
    Strata with fewer than ``min_stratum_count`` samples are merged into a
    single "RARE" stratum so StratifiedKFold does not fail on tiny groups.

    Args:
        df: processed DataFrame.
        min_stratum_count: minimum samples per stratum; rarer strata are
            merged into "RARE".

    Returns:
        (strata_array, strata_info)
        - strata_array: int64 stratum code per sample.
        - strata_info: bookkeeping stats (original counts, merged strata, ...).
    """
    n = len(df)
    strata_labels = []

    for i in range(n):
        # Toxic stratum: the binary class index, or "NA" when missing/negative.
        if TARGET_TOXIC in df.columns:
            toxic_val = df[TARGET_TOXIC].iloc[i]
            if pd.notna(toxic_val) and toxic_val >= 0:
                toxic_str = str(int(toxic_val))
            else:
                toxic_str = "NA"
        else:
            toxic_str = "NA"

        # PDI stratum: argmax over the one-hot PDI columns, "NA" if all-zero.
        if all(col in df.columns for col in TARGET_CLASSIFICATION_PDI):
            pdi_vals = df[TARGET_CLASSIFICATION_PDI].iloc[i].values
            if pdi_vals.sum() > 0:
                pdi_str = str(int(np.argmax(pdi_vals)))
            else:
                pdi_str = "NA"
        else:
            pdi_str = "NA"

        # EE stratum: same scheme as PDI.
        if all(col in df.columns for col in TARGET_CLASSIFICATION_EE):
            ee_vals = df[TARGET_CLASSIFICATION_EE].iloc[i].values
            if ee_vals.sum() > 0:
                ee_str = str(int(np.argmax(ee_vals)))
            else:
                ee_str = "NA"
        else:
            ee_str = "NA"

        strata_labels.append(f"T{toxic_str}|P{pdi_str}|E{ee_str}")

    # Count samples per stratum.
    unique_strata, counts = np.unique(strata_labels, return_counts=True)
    strata_counts = dict(zip(unique_strata, counts))

    # Strata too sparse to stratify on get merged into "RARE".
    rare_strata = [s for s, c in strata_counts.items() if c < min_stratum_count]
    # Fix: membership test was a linear scan over the rare list inside an
    # O(n) loop; use a set for O(1) lookups.
    rare_set = set(rare_strata)

    final_labels = [
        "RARE" if label in rare_set else label for label in strata_labels
    ]

    # Encode the final labels as integers for StratifiedKFold.
    unique_final, encoded = np.unique(final_labels, return_inverse=True)

    strata_info = {
        "original_strata_counts": strata_counts,
        "rare_strata": rare_strata,
        "final_strata": list(unique_final),
        "final_strata_counts": dict(zip(*np.unique(final_labels, return_counts=True))),
        "n_rare_merged": sum(strata_counts[s] for s in rare_strata) if rare_strata else 0,
    }

    logger.info(f"Built composite strata: {len(unique_final)} unique strata")
    logger.info(f" Rare strata merged: {len(rare_strata)} types, {strata_info['n_rare_merged']} samples")

    return encoded.astype(np.int64), strata_info
|
||||
|
||||
|
||||
# ============ 模型创建 ============
|
||||
|
||||
def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
    """Locate every ``model.pt`` of the MPNN ensemble under ``base_dir``.

    Searches the ``cv_*/fold_*/model_*/model.pt`` layout; the result is
    sorted so the ensemble ordering is deterministic.

    Raises:
        FileNotFoundError: when no checkpoint matches the pattern.
    """
    found = sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt"))
    if not found:
        raise FileNotFoundError(f"No model.pt files found in {base_dir}")
    return list(map(str, found))
|
||||
|
||||
|
||||
def create_model(
    d_model: int = 256,
    num_heads: int = 8,
    n_attn_layers: int = 4,
    fusion_strategy: str = "attention",
    head_hidden_dim: int = 128,
    dropout: float = 0.1,
    use_mpnn: bool = False,
    mpnn_device: str = "cpu",
) -> Union[LNPModel, LNPModelWithoutMPNN]:
    """Instantiate an LNP model, with or without the MPNN branch.

    Both variants share the same transformer/head hyperparameters; the
    MPNN variant additionally wraps the pretrained chemprop ensemble.
    """
    # Architecture kwargs common to both model variants.
    shared_kwargs = dict(
        d_model=d_model,
        num_heads=num_heads,
        n_attn_layers=n_attn_layers,
        fusion_strategy=fusion_strategy,
        head_hidden_dim=head_hidden_dim,
        dropout=dropout,
    )
    if not use_mpnn:
        return LNPModelWithoutMPNN(**shared_kwargs)
    return LNPModel(
        mpnn_ensemble_paths=find_mpnn_ensemble_paths(),
        mpnn_device=mpnn_device,
        **shared_kwargs,
    )
|
||||
|
||||
|
||||
# ============ 评估指标 ============
|
||||
|
||||
def evaluate_on_test(
    model: torch.nn.Module,
    test_loader: DataLoader,
    device: torch.device,
) -> Dict:
    """Evaluate the model on a test loader.

    Collects masked predictions/targets per task, then reports:
    - size / delivery (regression): MSE, RMSE, MAE, R^2
    - pdi / ee / toxic (classification): accuracy, macro precision/recall/F1
    - biodist (distribution): mean KL and JS divergence

    Tasks with no labelled samples are simply absent from the result dict.
    """
    # Local imports keep scipy/sklearn out of module import time.
    from scipy.special import rel_entr
    from sklearn.metrics import (
        mean_squared_error,
        mean_absolute_error,
        r2_score,
        accuracy_score,
        precision_score,
        recall_score,
        f1_score,
    )

    model.eval()

    # Per-task accumulators of flattened predictions and ground truth.
    preds = {
        "size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [], "biodist": []
    }
    targets = {
        "size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [], "biodist": []
    }

    with torch.no_grad():
        for batch in test_loader:
            smiles = batch["smiles"]
            tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
            tgts = batch["targets"]
            masks = batch["mask"]

            outputs = model(smiles, tabular)

            # Regression tasks: collect masked scalar predictions.
            for task in ["size", "delivery"]:
                if task in masks and masks[task].any():
                    m = masks[task]
                    # NOTE(review): key is always equal to task here; kept as-is.
                    key = task if task == "size" else "delivery"
                    preds[task].extend(outputs[key].squeeze(-1)[m].cpu().numpy().tolist())
                    targets[task].extend(tgts[key][m].cpu().numpy().tolist())

            # Classification tasks: store predicted class indices (argmax).
            for task in ["pdi", "ee", "toxic"]:
                if task in masks and masks[task].any():
                    m = masks[task]
                    preds[task].extend(outputs[task][m].argmax(dim=-1).cpu().numpy().tolist())
                    targets[task].extend(tgts[task][m].cpu().numpy().tolist())

            # Distribution task: store the full predicted distribution vectors.
            if "biodist" in masks and masks["biodist"].any():
                m = masks["biodist"]
                preds["biodist"].extend(outputs["biodist"][m].cpu().numpy().tolist())
                targets["biodist"].extend(tgts["biodist"][m].cpu().numpy().tolist())

    # Compute metrics per task.
    results = {}

    # Regression metrics.
    for task in ["size", "delivery"]:
        if preds[task]:
            p = np.array(preds[task])
            t = np.array(targets[task])
            results[task] = {
                "n_samples": len(p),
                "mse": float(mean_squared_error(t, p)),
                "rmse": float(np.sqrt(mean_squared_error(t, p))),
                "mae": float(mean_absolute_error(t, p)),
                "r2": float(r2_score(t, p)),
            }

    # Classification metrics (macro-averaged; zero_division=0 silences
    # warnings for classes never predicted).
    for task in ["pdi", "ee", "toxic"]:
        if preds[task]:
            p = np.array(preds[task])
            t = np.array(targets[task])
            results[task] = {
                "n_samples": len(p),
                "accuracy": float(accuracy_score(t, p)),
                "precision": float(precision_score(t, p, average="macro", zero_division=0)),
                "recall": float(recall_score(t, p, average="macro", zero_division=0)),
                "f1": float(f1_score(t, p, average="macro", zero_division=0)),
            }

    # Distribution metrics: clip to [eps, 1] before divergence to avoid
    # log-of-zero; divergences are averaged over samples.
    if preds["biodist"]:
        p = np.array(preds["biodist"])
        t = np.array(targets["biodist"])

        def kl_divergence(p_arr, q_arr, eps=1e-10):
            # KL(p || q), summed over the distribution axis, mean over samples.
            p_arr = np.clip(p_arr, eps, 1.0)
            q_arr = np.clip(q_arr, eps, 1.0)
            return float(np.sum(rel_entr(p_arr, q_arr), axis=-1).mean())

        def js_divergence(p_arr, q_arr, eps=1e-10):
            # Symmetric Jensen-Shannon divergence via the mixture m.
            p_arr = np.clip(p_arr, eps, 1.0)
            q_arr = np.clip(q_arr, eps, 1.0)
            m = 0.5 * (p_arr + q_arr)
            return float(0.5 * (np.sum(rel_entr(p_arr, m), axis=-1) + np.sum(rel_entr(q_arr, m), axis=-1)).mean())

        results["biodist"] = {
            "n_samples": len(p),
            "kl_divergence": kl_divergence(t, p),
            "js_divergence": js_divergence(t, p),
        }

    return results
|
||||
|
||||
|
||||
# ============ 预训练权重加载 ============
|
||||
|
||||
def load_pretrain_weights_to_model(
    model: Union[LNPModel, LNPModelWithoutMPNN],
    pretrain_state_dict: Dict,
    d_model: int,
    pretrain_config: Dict,
    load_delivery_head: bool = True,
) -> bool:
    """
    Copy pretraining weights into ``model``.

    The transfer is skipped entirely when the pretrain checkpoint was built
    with a different ``d_model`` (the tensor shapes would not match).

    Returns:
        True when weights were loaded, False when skipped.
    """
    pretrain_d_model = pretrain_config.get("d_model")
    if pretrain_d_model != d_model:
        logger.warning(
            f"d_model mismatch: pretrain={pretrain_config.get('d_model')}, "
            f"current={d_model}. Skipping pretrain loading."
        )
        return False

    # Non-strict load: heads absent from the checkpoint keep their
    # fresh initialization.
    model.load_pretrain_weights(
        pretrain_state_dict=pretrain_state_dict,
        load_delivery_head=load_delivery_head,
        strict=False,
    )
    return True
|
||||
|
||||
|
||||
# ============ 内层 Optuna 调参 ============
|
||||
|
||||
def run_inner_optuna(
    full_dataset: LNPDataset,
    inner_train_indices: np.ndarray,
    strata: np.ndarray,
    device: torch.device,
    n_trials: int = 20,
    epochs_per_trial: int = 30,
    patience: int = 10,
    batch_size: int = 32,
    n_inner_folds: int = 3,
    use_mpnn: bool = False,
    seed: int = 42,
    study_path: Optional[Path] = None,
    pretrain_state_dict: Optional[Dict] = None,
    pretrain_config: Optional[Dict] = None,
    load_delivery_head: bool = True,
) -> Tuple[Dict, int, optuna.Study]:
    """
    Run the inner Optuna hyperparameter search on the outer-train data.

    Each trial trains one model per inner stratified fold (with early
    stopping) and is scored by the mean best validation loss across folds.
    The mean best epoch across folds is stored as the trial's "epoch_mean"
    user attribute, to be reused as the fixed epoch count of the outer fit.

    Args:
        full_dataset: the complete dataset.
        inner_train_indices: indices (into full_dataset) of the outer-train split.
        strata: stratification label per sample of the full dataset.
        device: torch device.
        n_trials: number of Optuna trials.
        epochs_per_trial: max epochs per inner training run.
        patience: early-stopping patience.
        batch_size: batch size.
        n_inner_folds: number of inner CV folds.
        use_mpnn: whether to build the MPNN model variant.
        seed: random seed (CV splits and TPE sampler).
        study_path: optional sqlite path to persist/resume the study.
        pretrain_state_dict: optional pretraining weights.
        pretrain_config: config of the pretraining checkpoint.
        load_delivery_head: whether to also load the delivery head weights.

    Returns:
        (best_params, epoch_mean, study)
    """
    if optuna is None:
        raise ImportError("Optuna not installed. Run: pip install optuna")

    # Strata restricted to the outer-train subset, used for inner splits.
    inner_strata = strata[inner_train_indices]

    def objective(trial: optuna.Trial) -> float:
        # Sample the hyperparameters for this trial.
        d_model = trial.suggest_categorical("d_model", [128, 256, 512])
        num_heads = trial.suggest_categorical("num_heads", [4, 8])
        n_attn_layers = trial.suggest_int("n_attn_layers", 2, 6)
        fusion_strategy = trial.suggest_categorical(
            "fusion_strategy", ["attention", "avg", "max"]
        )
        head_hidden_dim = trial.suggest_categorical("head_hidden_dim", [64, 128, 256])
        dropout = trial.suggest_float("dropout", 0.05, 0.3)
        lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
        weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True)

        # Inner stratified CV (same split for every trial: fixed random_state).
        inner_cv = StratifiedKFold(
            n_splits=n_inner_folds, shuffle=True, random_state=seed
        )

        fold_val_losses = []
        fold_best_epochs = []

        for inner_fold, (inner_train_idx, inner_val_idx) in enumerate(
            inner_cv.split(inner_train_indices, inner_strata)
        ):
            # Map fold-local positions back to indices into full_dataset.
            actual_train_idx = inner_train_indices[inner_train_idx]
            actual_val_idx = inner_train_indices[inner_val_idx]

            # Build DataLoaders over the inner subsets.
            train_subset = Subset(full_dataset, actual_train_idx.tolist())
            val_subset = Subset(full_dataset, actual_val_idx.tolist())

            train_loader = DataLoader(
                train_subset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
            )
            val_loader = DataLoader(
                val_subset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
            )

            # Class weights computed only from the inner-train data (no leakage).
            class_weights = compute_class_weights_from_loader(train_loader)

            # Fresh model per fold with the trial's hyperparameters.
            model = create_model(
                d_model=d_model,
                num_heads=num_heads,
                n_attn_layers=n_attn_layers,
                fusion_strategy=fusion_strategy,
                head_hidden_dim=head_hidden_dim,
                dropout=dropout,
                use_mpnn=use_mpnn,
                mpnn_device=device.type,
            )

            # Optionally warm-start from the pretraining checkpoint.
            if pretrain_state_dict is not None and pretrain_config is not None:
                load_pretrain_weights_to_model(
                    model, pretrain_state_dict, d_model, pretrain_config, load_delivery_head
                )

            # Train with early stopping on the inner validation fold.
            result = train_with_early_stopping(
                model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                device=device,
                lr=lr,
                weight_decay=weight_decay,
                epochs=epochs_per_trial,
                patience=patience,
                class_weights=class_weights,
            )

            fold_val_losses.append(result["best_val_loss"])
            fold_best_epochs.append(result["best_epoch"])

        # Persist the mean best epoch on the trial for the outer fit to reuse.
        epoch_mean = int(round(np.mean(fold_best_epochs)))
        trial.set_user_attr("epoch_mean", epoch_mean)
        trial.set_user_attr("fold_best_epochs", fold_best_epochs)

        return np.mean(fold_val_losses)

    # Optional sqlite persistence so an interrupted search can resume.
    storage = None
    if study_path is not None:
        storage = f"sqlite:///{study_path}"

    study = optuna.create_study(
        direction="minimize",
        sampler=TPESampler(seed=seed),
        storage=storage,
        study_name="inner_optuna",
        load_if_exists=True,
    )

    study.optimize(objective, n_trials=n_trials, show_progress_bar=True)

    best_params = study.best_trial.params
    # Fallback to epochs_per_trial if the attribute is missing (e.g. resumed study).
    epoch_mean = study.best_trial.user_attrs.get("epoch_mean", epochs_per_trial)

    logger.info(f"Best trial: {study.best_trial.number}")
    logger.info(f"Best val_loss: {study.best_trial.value:.4f}")
    logger.info(f"Best params: {best_params}")
    logger.info(f"Epoch mean: {epoch_mean}")

    return best_params, epoch_mean, study
|
||||
|
||||
|
||||
# ============ 主流程 ============
|
||||
|
||||
@app.command()
def main(
    input_path: Path = INTERIM_DATA_DIR / "internal_corrected.csv",
    output_dir: Path = MODELS_DIR / "nested_cv",
    # CV parameters
    n_outer_folds: int = 5,
    n_inner_folds: int = 3,
    min_stratum_count: int = 5,
    seed: int = 42,
    # Optuna parameters
    n_trials: int = 20,
    epochs_per_trial: int = 30,
    inner_patience: int = 10,
    # Training parameters
    batch_size: int = 32,
    # Pretraining weights
    init_from_pretrain: Optional[Path] = None,
    load_delivery_head: bool = True,
    # MPNN
    use_mpnn: bool = False,
    # Device
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Nested cross-validation with Optuna hyperparameter tuning.

    Outer 5-fold (20% test / 80% train); inner 3-fold Optuna search on the
    80%. The outer fit does NOT use early stopping: its epoch count is the
    epoch_mean of the inner best trial.

    Use --init-from-pretrain to initialize model weights from a
    pretraining checkpoint.
    """
    if optuna is None:
        logger.error("Optuna not installed. Run: pip install optuna")
        raise typer.Exit(1)

    logger.info(f"Using device: {device}")
    device = torch.device(device)

    # Load the pretraining checkpoint once (if requested); it is reused by
    # every inner trial and every outer fit.
    pretrain_state_dict = None
    pretrain_config = None
    if init_from_pretrain is not None:
        if init_from_pretrain.exists():
            logger.info(f"Loading pretrain weights from {init_from_pretrain}")
            checkpoint = torch.load(init_from_pretrain, map_location="cpu")
            pretrain_state_dict = checkpoint["model_state_dict"]
            pretrain_config = checkpoint.get("config", {})
            logger.success(f"Loaded pretrain checkpoint (d_model={pretrain_config.get('d_model')})")
        else:
            logger.warning(f"Pretrain checkpoint not found: {init_from_pretrain}, skipping")

    # Timestamped run directory so repeated runs never overwrite each other.
    run_name = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = output_dir / run_name
    run_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Output directory: {run_dir}")

    # Load data.
    logger.info(f"Loading data from {input_path}")
    df = pd.read_csv(input_path)
    logger.info(f"Loaded {len(df)} samples")

    # Process data.
    logger.info("Processing dataframe...")
    df = process_dataframe(df)

    # Build composite stratification labels (toxic x PDI x EE).
    logger.info("Building composite strata...")
    strata, strata_info = build_composite_strata(df, min_stratum_count)

    # Persist strata statistics for reproducibility.
    with open(run_dir / "strata_info.json", "w") as f:
        json.dump(strata_info, f, indent=2, default=str)

    # Full dataset; folds are expressed as index subsets into it.
    full_dataset = LNPDataset(df)
    n_samples = len(full_dataset)

    # Outer stratified CV.
    outer_cv = StratifiedKFold(
        n_splits=n_outer_folds, shuffle=True, random_state=seed
    )

    outer_results = []

    for outer_fold, (outer_train_idx, outer_test_idx) in enumerate(
        outer_cv.split(np.arange(n_samples), strata)
    ):
        logger.info(f"\n{'='*60}")
        logger.info(f"OUTER FOLD {outer_fold}")
        logger.info(f"{'='*60}")
        logger.info(f"Train: {len(outer_train_idx)}, Test: {len(outer_test_idx)}")

        fold_dir = run_dir / f"outer_fold_{outer_fold}"
        fold_dir.mkdir(parents=True, exist_ok=True)

        # Save split indices so the exact folds can be reproduced.
        splits = {
            "outer_train_idx": outer_train_idx.tolist(),
            "outer_test_idx": outer_test_idx.tolist(),
        }
        with open(fold_dir / "splits.json", "w") as f:
            json.dump(splits, f)

        # Inner Optuna search on the outer-train split.
        logger.info(f"\nRunning inner Optuna with {n_trials} trials...")
        study_path = fold_dir / "optuna_study.sqlite3"

        best_params, epoch_mean, study = run_inner_optuna(
            full_dataset=full_dataset,
            inner_train_indices=outer_train_idx,
            strata=strata,
            device=device,
            n_trials=n_trials,
            epochs_per_trial=epochs_per_trial,
            patience=inner_patience,
            batch_size=batch_size,
            n_inner_folds=n_inner_folds,
            use_mpnn=use_mpnn,
            seed=seed + outer_fold,
            study_path=study_path,
            pretrain_state_dict=pretrain_state_dict,
            pretrain_config=pretrain_config,
            load_delivery_head=load_delivery_head,
        )

        # Save the tuned hyperparameters and the derived epoch budget.
        with open(fold_dir / "best_params.json", "w") as f:
            json.dump(best_params, f, indent=2)

        with open(fold_dir / "epoch_mean.json", "w") as f:
            json.dump({"epoch_mean": epoch_mean}, f)

        # Outer fit: best hyperparameters, fixed epoch count, no early stopping.
        logger.info(f"\nTraining outer fold with best params, epochs={epoch_mean}...")

        # DataLoaders for the outer split.
        train_subset = Subset(full_dataset, outer_train_idx.tolist())
        test_subset = Subset(full_dataset, outer_test_idx.tolist())

        train_loader = DataLoader(
            train_subset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
        )
        test_loader = DataLoader(
            test_subset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
        )

        # Class weights from the outer-train data only.
        class_weights = compute_class_weights_from_loader(train_loader)

        # Build the model with the tuned architecture.
        model = create_model(
            d_model=best_params["d_model"],
            num_heads=best_params["num_heads"],
            n_attn_layers=best_params["n_attn_layers"],
            fusion_strategy=best_params["fusion_strategy"],
            head_hidden_dim=best_params["head_hidden_dim"],
            dropout=best_params["dropout"],
            use_mpnn=use_mpnn,
            mpnn_device=device.type,
        )

        # Optional warm start from the pretraining checkpoint.
        if pretrain_state_dict is not None and pretrain_config is not None:
            loaded = load_pretrain_weights_to_model(
                model, pretrain_state_dict, best_params["d_model"],
                pretrain_config, load_delivery_head
            )
            if loaded:
                logger.info(f"Loaded pretrain weights for outer fold {outer_fold}")

        # Train for a fixed number of epochs (no early stopping).
        train_result = train_fixed_epochs(
            model=model,
            train_loader=train_loader,
            val_loader=None,  # no validation set in the outer fit
            device=device,
            lr=best_params["lr"],
            weight_decay=best_params["weight_decay"],
            epochs=epoch_mean,
            class_weights=class_weights,
            use_cosine_annealing=True,
        )

        # Restore the final weights returned by the trainer.
        model.load_state_dict(train_result["final_state"])
        model = model.to(device)

        # Persist the model together with its architecture config.
        config = {
            "d_model": best_params["d_model"],
            "num_heads": best_params["num_heads"],
            "n_attn_layers": best_params["n_attn_layers"],
            "fusion_strategy": best_params["fusion_strategy"],
            "head_hidden_dim": best_params["head_hidden_dim"],
            "dropout": best_params["dropout"],
            "use_mpnn": use_mpnn,
        }

        torch.save({
            "model_state_dict": train_result["final_state"],
            "config": config,
            "epoch_mean": epoch_mean,
            "best_params": best_params,
        }, fold_dir / "model.pt")

        # Save the training history.
        with open(fold_dir / "history.json", "w") as f:
            json.dump(train_result["history"], f, indent=2)

        # Evaluate on the held-out outer test split.
        logger.info("Evaluating on outer test set...")
        test_metrics = evaluate_on_test(model, test_loader, device)

        with open(fold_dir / "test_metrics.json", "w") as f:
            json.dump(test_metrics, f, indent=2)

        # Log a compact per-task summary for this fold.
        logger.info(f"\nOuter Fold {outer_fold} Test Results:")
        for task, metrics in test_metrics.items():
            if "rmse" in metrics:
                logger.info(f" {task}: RMSE={metrics['rmse']:.4f}, R²={metrics['r2']:.4f}")
            elif "accuracy" in metrics:
                logger.info(f" {task}: Acc={metrics['accuracy']:.4f}, F1={metrics['f1']:.4f}")
            elif "kl_divergence" in metrics:
                logger.info(f" {task}: KL={metrics['kl_divergence']:.4f}, JS={metrics['js_divergence']:.4f}")

        outer_results.append({
            "fold": outer_fold,
            "best_params": best_params,
            "epoch_mean": epoch_mean,
            "test_metrics": test_metrics,
        })

    # Aggregate results across outer folds.
    logger.info("\n" + "=" * 60)
    logger.info("NESTED CV COMPLETE")
    logger.info("=" * 60)

    summary = {"fold_results": outer_results}

    # Collect every metric value per task across folds (n_samples excluded).
    tasks_with_metrics = {}
    for result in outer_results:
        for task, metrics in result["test_metrics"].items():
            if task not in tasks_with_metrics:
                tasks_with_metrics[task] = {k: [] for k in metrics.keys() if k != "n_samples"}
            for k, v in metrics.items():
                if k != "n_samples":
                    tasks_with_metrics[task][k].append(v)

    # Mean and std per task/metric over the outer folds.
    summary["summary_stats"] = {}
    for task, metrics_dict in tasks_with_metrics.items():
        summary["summary_stats"][task] = {}
        for metric_name, values in metrics_dict.items():
            summary["summary_stats"][task][f"{metric_name}_mean"] = float(np.mean(values))
            summary["summary_stats"][task][f"{metric_name}_std"] = float(np.std(values))

    # Print the aggregate summary.
    logger.info("\n[Summary Statistics]")
    for task, stats in summary["summary_stats"].items():
        if "rmse_mean" in stats:
            logger.info(
                f" {task}: RMSE={stats['rmse_mean']:.4f}±{stats['rmse_std']:.4f}, "
                f"R²={stats['r2_mean']:.4f}±{stats['r2_std']:.4f}"
            )
        elif "accuracy_mean" in stats:
            logger.info(
                f" {task}: Acc={stats['accuracy_mean']:.4f}±{stats['accuracy_std']:.4f}, "
                f"F1={stats['f1_mean']:.4f}±{stats['f1_std']:.4f}"
            )
        elif "kl_divergence_mean" in stats:
            logger.info(
                f" {task}: KL={stats['kl_divergence_mean']:.4f}±{stats['kl_divergence_std']:.4f}, "
                f"JS={stats['js_divergence_mean']:.4f}±{stats['js_divergence_std']:.4f}"
            )

    # Persist the full summary.
    with open(run_dir / "summary.json", "w") as f:
        json.dump(summary, f, indent=2)

    logger.success(f"\nAll results saved to {run_dir}")
|
||||
|
||||
|
||||
# Script entry point: dispatch to the Typer CLI app defined above.
if __name__ == "__main__":
    app()
|
||||
|
||||
474
lnp_ml/modeling/trainer_balanced.py
Normal file
474
lnp_ml/modeling/trainer_balanced.py
Normal file
@ -0,0 +1,474 @@
|
||||
"""带类权重的训练器:处理分类任务的数据不均衡问题"""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
@dataclass
class ClassWeights:
    """Per-class weight tensors for the classification tasks (None = unweighted)."""
    pdi: Optional[torch.Tensor] = None    # [4] for 4 PDI classes
    ee: Optional[torch.Tensor] = None     # [3] for 3 EE classes
    toxic: Optional[torch.Tensor] = None  # [2] for binary toxic


@dataclass
class LossWeightsBalanced:
    """Per-task loss weights (field-compatible with trainer.py's LossWeights)."""
    size: float = 1.0
    pdi: float = 1.0
    ee: float = 1.0
    delivery: float = 1.0
    biodist: float = 1.0
    toxic: float = 1.0


def compute_class_weights_from_loader(
    loader: DataLoader,
    n_pdi_classes: int = 4,
    n_ee_classes: int = 3,
    n_toxic_classes: int = 2,
    smoothing: float = 0.1,
) -> ClassWeights:
    """
    Scan a DataLoader once, count class frequencies, and derive class weights.

    Uses smoothed inverse frequency, weight_c = N / (n_classes * (count_c + s)),
    then normalizes to mean 1 so the overall loss scale is unchanged.

    Args:
        loader: DataLoader (iterated exactly once).
        n_pdi_classes: number of PDI classes.
        n_ee_classes: number of EE classes.
        n_toxic_classes: number of toxic classes.
        smoothing: additive smoothing (avoids division by zero / extreme weights).

    Returns:
        ClassWeights with one weight tensor per classification task, or None
        for a task that had no labelled samples.
    """
    # One counter tensor per classification task. All three tasks share the
    # same counting logic, previously triplicated inline.
    counters = {
        "pdi": torch.zeros(n_pdi_classes),
        "ee": torch.zeros(n_ee_classes),
        "toxic": torch.zeros(n_toxic_classes),
    }

    for batch in loader:
        targets = batch["targets"]
        mask = batch["mask"]
        for task, counts in counters.items():
            if task in targets and task in mask:
                m = mask[task]
                if m.any():
                    labels = targets[task][m]
                    # Per-class equality scan: out-of-range labels are
                    # silently ignored, matching the original behavior.
                    for c in range(len(counts)):
                        counts[c] += (labels == c).sum().item()

    def counts_to_weights(counts: torch.Tensor, n_classes: int) -> Optional[torch.Tensor]:
        """Convert raw counts to smoothed, mean-normalized inverse-frequency weights."""
        total = counts.sum().item()
        if total == 0:
            return None  # task never observed: leave it unweighted
        # Inverse frequency with additive smoothing.
        counts = counts + smoothing
        weights = total / (n_classes * counts)
        # Normalize to mean=1 so the task's loss magnitude is preserved.
        weights = weights / weights.mean()
        return weights

    return ClassWeights(
        pdi=counts_to_weights(counters["pdi"], n_pdi_classes),
        ee=counts_to_weights(counters["ee"], n_ee_classes),
        toxic=counts_to_weights(counters["toxic"], n_toxic_classes),
    )
|
||||
|
||||
|
||||
def compute_multitask_loss_balanced(
    outputs: Dict[str, torch.Tensor],
    targets: Dict[str, torch.Tensor],
    mask: Dict[str, torch.Tensor],
    task_weights: Optional["LossWeightsBalanced"] = None,
    class_weights: Optional["ClassWeights"] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Compute the class-weighted multi-task loss.

    Regression tasks (size, delivery) use MSE, classification tasks
    (pdi, ee, toxic) use cross-entropy with optional per-class weights,
    and the biodistribution task uses KL divergence.

    Args:
        outputs: model outputs per task.
        targets: ground-truth labels per task.
        mask: boolean masks marking which samples carry a label per task.
        task_weights: per-task loss weights (defaults to all 1.0).
        class_weights: per-class weights for the classification tasks.

    Returns:
        (total_loss, loss_dict): the weighted sum, and the unweighted
        per-task losses that contributed to it.
    """
    task_weights = task_weights or LossWeightsBalanced()
    class_weights = class_weights or ClassWeights()

    losses: Dict[str, torch.Tensor] = {}
    device = next(iter(outputs.values())).device
    total_loss = torch.tensor(0.0, device=device)

    # --- Regression tasks: plain MSE (class weights do not apply). ---
    # Fix: guard on key presence in *both* targets and mask; the original
    # indexed mask["size"]/mask["delivery"] directly and would raise a
    # KeyError for a task absent from the batch mask (the other helpers in
    # this module already check both dicts).
    for task, weight in (("size", task_weights.size), ("delivery", task_weights.delivery)):
        if task in targets and task in mask and mask[task].any():
            m = mask[task]
            pred = outputs[task][m].squeeze(-1)
            tgt = targets[task][m]
            losses[task] = F.mse_loss(pred, tgt)
            total_loss = total_loss + weight * losses[task]

    # --- Classification tasks: cross-entropy with optional class weights. ---
    cls_specs = (
        ("pdi", task_weights.pdi, class_weights.pdi),
        ("ee", task_weights.ee, class_weights.ee),
        ("toxic", task_weights.toxic, class_weights.toxic),
    )
    for task, weight, cls_w in cls_specs:
        if task in targets and task in mask and mask[task].any():
            m = mask[task]
            pred = outputs[task][m]
            tgt = targets[task][m]
            w = cls_w.to(device) if cls_w is not None else None
            losses[task] = F.cross_entropy(pred, tgt, weight=w)
            total_loss = total_loss + weight * losses[task]

    # --- Biodistribution: KL divergence between distributions. ---
    if "biodist" in targets and "biodist" in mask and mask["biodist"].any():
        m = mask["biodist"]
        pred = outputs["biodist"][m]
        tgt = targets["biodist"][m]
        # NOTE(review): assumes outputs["biodist"] holds probabilities (not
        # log-probs); log() is clamped to avoid -inf on zero entries.
        losses["biodist"] = F.kl_div(
            pred.log().clamp(min=-100),
            tgt,
            reduction="batchmean",
        )
        total_loss = total_loss + task_weights.biodist * losses["biodist"]

    return total_loss, losses
|
||||
|
||||
|
||||
def train_epoch_balanced(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    task_weights: Optional[LossWeightsBalanced] = None,
    class_weights: Optional[ClassWeights] = None,
) -> Dict[str, float]:
    """Run one training epoch with class-weighted losses.

    Returns a dict holding the mean total loss plus a ``loss_<task>`` entry
    for every task that accumulated a positive loss during the epoch.
    """
    model.train()

    task_names = ["size", "pdi", "ee", "delivery", "biodist", "toxic"]
    loss_sums = dict.fromkeys(task_names, 0.0)
    running_total = 0.0
    batch_count = 0

    for batch in tqdm(loader, desc="Training", leave=False):
        smiles = batch["smiles"]
        # Move every tensor of the batch onto the training device.
        tabular = {key: tensor.to(device) for key, tensor in batch["tabular"].items()}
        targets = {key: tensor.to(device) for key, tensor in batch["targets"].items()}
        mask = {key: tensor.to(device) for key, tensor in batch["mask"].items()}

        optimizer.zero_grad()
        outputs = model(smiles, tabular)
        loss, losses = compute_multitask_loss_balanced(
            outputs, targets, mask, task_weights, class_weights
        )

        loss.backward()
        # Clip gradients to stabilize multi-task training.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_total += loss.item()
        for task, task_loss in losses.items():
            loss_sums[task] += task_loss.item()
        batch_count += 1

    report = {"loss": running_total / batch_count}
    for task, summed in loss_sums.items():
        # Report only the tasks that actually contributed loss this epoch.
        if summed > 0:
            report[f"loss_{task}"] = summed / batch_count
    return report
|
||||
|
||||
|
||||
@torch.no_grad()
def validate_balanced(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    task_weights: Optional[LossWeightsBalanced] = None,
    class_weights: Optional[ClassWeights] = None,
) -> Dict[str, float]:
    """Evaluate the model with the class-weighted multitask loss.

    In addition to the per-task losses, reports classification accuracy
    (``acc_pdi``, ``acc_ee``, ``acc_toxic``) over the masked samples.

    Args:
        model: Multitask model taking (smiles, tabular) inputs.
        loader: Validation DataLoader yielding dicts with "smiles",
            "tabular", "targets" and "mask" entries.
        device: Device to move batch tensors to.
        task_weights: Optional per-task loss weights.
        class_weights: Optional per-class weights for classification heads.

    Returns:
        Dict with mean "loss", ``loss_<task>`` entries for tasks with a
        positive accumulated loss, and ``acc_<task>`` for classification
        tasks that had at least one masked sample.
    """
    model.eval()

    task_names = ["size", "pdi", "ee", "delivery", "biodist", "toxic"]
    cls_tasks = ["pdi", "ee", "toxic"]
    running_total = 0.0
    running_task = {name: 0.0 for name in task_names}
    batch_count = 0

    # Running counts for classification accuracy.
    n_correct = {name: 0 for name in cls_tasks}
    n_total = {name: 0 for name in cls_tasks}

    for batch in tqdm(loader, desc="Validating", leave=False):
        smiles = batch["smiles"]
        tabular = {key: val.to(device) for key, val in batch["tabular"].items()}
        targets = {key: val.to(device) for key, val in batch["targets"].items()}
        mask = {key: val.to(device) for key, val in batch["mask"].items()}

        outputs = model(smiles, tabular)
        loss, per_task = compute_multitask_loss_balanced(
            outputs, targets, mask, task_weights, class_weights
        )

        running_total += loss.item()
        for name, value in per_task.items():
            running_task[name] += value.item()
        batch_count += 1

        # Accumulate accuracy only over samples the mask marks as labelled.
        for name in cls_tasks:
            if name in targets and mask[name].any():
                m = mask[name]
                predicted = outputs[name][m].argmax(dim=-1)
                n_correct[name] += (predicted == targets[name][m]).sum().item()
                n_total[name] += m.sum().item()

    metrics = {"loss": running_total / batch_count}
    for name, value in running_task.items():
        if value > 0:
            metrics[f"loss_{name}"] = value / batch_count

    for name in cls_tasks:
        if n_total[name] > 0:
            metrics[f"acc_{name}"] = n_correct[name] / n_total[name]

    return metrics
|
||||
|
||||
|
||||
class EarlyStoppingBalanced:
    """Early-stopping helper (API-compatible with trainer.py's EarlyStopping).

    Tracks the best validation loss seen so far. When the loss fails to
    improve by more than ``min_delta`` for ``patience`` consecutive calls,
    ``should_stop`` latches to True and every subsequent call returns True.
    """

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience        # allowed consecutive non-improving epochs
        self.min_delta = min_delta      # minimum improvement that resets the counter
        self.counter = 0                # current streak of non-improving epochs
        self.best_loss = float("inf")   # lowest validation loss observed so far
        self.best_epoch = 0             # 0-indexed epoch at which best_loss occurred
        self.should_stop = False        # latched stop flag

    def __call__(self, val_loss: float, epoch: int = 0) -> bool:
        """Record this epoch's validation loss; return True when training should stop."""
        improved = val_loss < self.best_loss - self.min_delta
        if improved:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        return self.should_stop

    def get_best_epoch(self) -> int:
        """Return the best epoch as a 1-indexed number."""
        return self.best_epoch + 1
|
||||
|
||||
|
||||
def train_with_early_stopping(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    epochs: int = 100,
    patience: int = 15,
    task_weights: Optional[LossWeightsBalanced] = None,
    class_weights: Optional[ClassWeights] = None,
) -> Dict:
    """Full training loop with LR-plateau scheduling and early stopping.

    Trains for at most ``epochs`` epochs, halves the learning rate when
    the validation loss plateaus, snapshots the best weights, and stops
    once validation loss fails to improve for ``patience`` epochs. On
    return the model is loaded with the best checkpoint.

    Args:
        model: Model to train (moved to ``device``).
        train_loader: Training data.
        val_loader: Validation data used for scheduling / early stopping.
        device: Compute device.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        epochs: Maximum number of epochs.
        patience: Early-stopping patience in epochs.
        task_weights: Optional per-task loss weights.
        class_weights: Optional per-class weights for classification heads.

    Returns:
        Dict with keys: history, best_val_loss, best_epoch, best_state,
        epochs_trained.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )
    stopper = EarlyStoppingBalanced(patience=patience)

    history = {"train": [], "val": []}
    best_val_loss = float("inf")
    best_state = None

    for epoch in range(epochs):
        train_metrics = train_epoch_balanced(
            model, train_loader, optimizer, device, task_weights, class_weights
        )
        val_metrics = validate_balanced(
            model, val_loader, device, task_weights, class_weights
        )

        history["train"].append(train_metrics)
        history["val"].append(val_metrics)

        current_val = val_metrics["loss"]

        # Halve the learning rate when validation loss plateaus.
        scheduler.step(current_val)

        # Keep a CPU copy of the best-performing weights.
        if current_val < best_val_loss:
            best_val_loss = current_val
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

        if stopper(current_val, epoch):
            break

    # Reload the best checkpoint before returning.
    if best_state is not None:
        model.load_state_dict(best_state)

    return {
        "history": history,
        "best_val_loss": best_val_loss,
        "best_epoch": stopper.get_best_epoch(),
        "best_state": best_state,
        "epochs_trained": len(history["train"]),
    }
|
||||
|
||||
|
||||
def train_fixed_epochs(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: Optional[DataLoader],
    device: torch.device,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    epochs: int = 50,
    task_weights: Optional[LossWeightsBalanced] = None,
    class_weights: Optional[ClassWeights] = None,
    use_cosine_annealing: bool = True,
    use_swa: bool = False,
    swa_start_epoch: Optional[int] = None,
) -> Dict:
    """Train for a fixed number of epochs (no early stopping).

    Intended for outer-CV training and the final fit, where the epoch
    budget has already been chosen by the inner loop.

    Args:
        model: Model to train (moved to ``device``).
        train_loader: Training data.
        val_loader: Optional validation data, used for monitoring only.
        device: Compute device.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        epochs: Number of epochs to run.
        task_weights: Optional per-task loss weights.
        class_weights: Optional per-class weights for classification heads.
        use_cosine_annealing: Whether to schedule lr with CosineAnnealingLR.
        use_swa: Whether to apply stochastic weight averaging.
        swa_start_epoch: Epoch at which SWA begins (defaults to 75% of
            ``epochs``; note a value of 0 also falls back to the default).

    Returns:
        Dict with keys: history, final_state, epochs_trained.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    scheduler = (
        torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
        if use_cosine_annealing
        else None
    )

    # Optional stochastic weight averaging over the tail of training.
    swa_model = None
    swa_scheduler = None
    if use_swa:
        from torch.optim.swa_utils import AveragedModel, SWALR

        swa_model = AveragedModel(model)
        swa_start = swa_start_epoch or int(epochs * 0.75)
        swa_scheduler = SWALR(optimizer, swa_lr=lr * 0.1)

    history = {"train": [], "val": []}

    for epoch in range(epochs):
        train_metrics = train_epoch_balanced(
            model, train_loader, optimizer, device, task_weights, class_weights
        )
        history["train"].append(train_metrics)

        # Validation is optional and only recorded for monitoring.
        if val_loader is not None:
            history["val"].append(
                validate_balanced(model, val_loader, device, task_weights, class_weights)
            )

        # During the SWA phase the averaged model takes over lr scheduling.
        if use_swa and epoch >= swa_start:
            swa_model.update_parameters(model)
            swa_scheduler.step()
        elif scheduler is not None:
            scheduler.step()

    # Materialise the final weights (SWA-averaged when enabled) on CPU.
    if use_swa and swa_model is not None:
        # Recompute batch-norm statistics for the averaged weights.
        torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
        final_state = {k: v.cpu().clone() for k, v in swa_model.module.state_dict().items()}
    else:
        final_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    return {
        "history": history,
        "final_state": final_state,
        "epochs_trained": epochs,
    }
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
models/model.pt
BIN
models/model.pt
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,10 @@
|
||||
{
|
||||
"d_model": 512,
|
||||
"num_heads": 8,
|
||||
"n_attn_layers": 5,
|
||||
"fusion_strategy": "attention",
|
||||
"head_hidden_dim": 128,
|
||||
"dropout": 0.14666736316838325,
|
||||
"lr": 0.0001295888795454003,
|
||||
"weight_decay": 7.732380983243132e-05
|
||||
}
|
||||
@ -0,0 +1 @@
|
||||
{"epoch_mean": 13}
|
||||
122
models/nested_cv/20260130_183653/outer_fold_0/history.json
Normal file
122
models/nested_cv/20260130_183653/outer_fold_0/history.json
Normal file
@ -0,0 +1,122 @@
|
||||
{
|
||||
"train": [
|
||||
{
|
||||
"loss": 9.688567074862393,
|
||||
"loss_size": 4.525775753638961,
|
||||
"loss_pdi": 1.2179414359006016,
|
||||
"loss_ee": 1.091066924008456,
|
||||
"loss_delivery": 1.1003740294413134,
|
||||
"loss_biodist": 1.045865768736059,
|
||||
"loss_toxic": 0.7075429206544702
|
||||
},
|
||||
{
|
||||
"loss": 4.477882081812078,
|
||||
"loss_size": 0.30304074490612204,
|
||||
"loss_pdi": 0.9405507716265592,
|
||||
"loss_ee": 0.9982529336755926,
|
||||
"loss_delivery": 1.1088530285791918,
|
||||
"loss_biodist": 0.6725774082270536,
|
||||
"loss_toxic": 0.45460730520161713
|
||||
},
|
||||
{
|
||||
"loss": 3.7592489285902544,
|
||||
"loss_size": 0.3136644837531177,
|
||||
"loss_pdi": 0.7663570263169028,
|
||||
"loss_ee": 0.9424018914049322,
|
||||
"loss_delivery": 0.9587065902623263,
|
||||
"loss_biodist": 0.5116219466382806,
|
||||
"loss_toxic": 0.2664969251914458
|
||||
},
|
||||
{
|
||||
"loss": 3.2862942435524682,
|
||||
"loss_size": 0.25651484592394397,
|
||||
"loss_pdi": 0.729118591005152,
|
||||
"loss_ee": 0.8862769928845492,
|
||||
"loss_delivery": 0.8572801744396036,
|
||||
"loss_biodist": 0.42606663974848663,
|
||||
"loss_toxic": 0.13103695366192947
|
||||
},
|
||||
{
|
||||
"loss": 3.032014153220437,
|
||||
"loss_size": 0.2605769471688704,
|
||||
"loss_pdi": 0.6875471039251848,
|
||||
"loss_ee": 0.8450987447391857,
|
||||
"loss_delivery": 0.7942117723551664,
|
||||
"loss_biodist": 0.3333369114182212,
|
||||
"loss_toxic": 0.11124271154403687
|
||||
},
|
||||
{
|
||||
"loss": 2.747604175047441,
|
||||
"loss_size": 0.23051185499538074,
|
||||
"loss_pdi": 0.6304752420295369,
|
||||
"loss_ee": 0.8027611483227123,
|
||||
"loss_delivery": 0.6954045072197914,
|
||||
"loss_biodist": 0.31399699503725226,
|
||||
"loss_toxic": 0.07445447621020404
|
||||
},
|
||||
{
|
||||
"loss": 2.496532602743669,
|
||||
"loss_size": 0.21782836656678806,
|
||||
"loss_pdi": 0.6259297322143208,
|
||||
"loss_ee": 0.7789664485237815,
|
||||
"loss_delivery": 0.5751941861076788,
|
||||
"loss_biodist": 0.25595076788555493,
|
||||
"loss_toxic": 0.04266304531219331
|
||||
},
|
||||
{
|
||||
"loss": 2.474623289975253,
|
||||
"loss_size": 0.2284239645708691,
|
||||
"loss_pdi": 0.6423451900482178,
|
||||
"loss_ee": 0.7340371066873724,
|
||||
"loss_delivery": 0.5827173766764727,
|
||||
"loss_biodist": 0.24910570003769614,
|
||||
"loss_toxic": 0.03799399394880642
|
||||
},
|
||||
{
|
||||
"loss": 2.3557864644310693,
|
||||
"loss_size": 0.197183218869296,
|
||||
"loss_pdi": 0.5410721220753409,
|
||||
"loss_ee": 0.7328030331568285,
|
||||
"loss_delivery": 0.6107787185094573,
|
||||
"loss_biodist": 0.2399107339707288,
|
||||
"loss_toxic": 0.03403859864920378
|
||||
},
|
||||
{
|
||||
"loss": 2.187597805803472,
|
||||
"loss_size": 0.19923549348657782,
|
||||
"loss_pdi": 0.514803715727546,
|
||||
"loss_ee": 0.7146885395050049,
|
||||
"loss_delivery": 0.49052428555759514,
|
||||
"loss_biodist": 0.22473187744617462,
|
||||
"loss_toxic": 0.043613933196121994
|
||||
},
|
||||
{
|
||||
"loss": 2.1395224874669854,
|
||||
"loss_size": 0.20194762064652008,
|
||||
"loss_pdi": 0.5207281329415061,
|
||||
"loss_ee": 0.6976469484242526,
|
||||
"loss_delivery": 0.4825741987336766,
|
||||
"loss_biodist": 0.21454968913034958,
|
||||
"loss_toxic": 0.022075910629196602
|
||||
},
|
||||
{
|
||||
"loss": 2.0708962462165137,
|
||||
"loss_size": 0.18447093801064926,
|
||||
"loss_pdi": 0.52463944662701,
|
||||
"loss_ee": 0.6666911081834273,
|
||||
"loss_delivery": 0.4482487521388314,
|
||||
"loss_biodist": 0.21374160457741131,
|
||||
"loss_toxic": 0.033104426227509975
|
||||
},
|
||||
{
|
||||
"loss": 2.065277782353488,
|
||||
"loss_size": 0.17790686813267795,
|
||||
"loss_pdi": 0.5256167894059961,
|
||||
"loss_ee": 0.6798818870024248,
|
||||
"loss_delivery": 0.43902007693594153,
|
||||
"loss_biodist": 0.21479454501108688,
|
||||
"loss_toxic": 0.02805758648636666
|
||||
}
|
||||
],
|
||||
"val": []
|
||||
}
|
||||
3
models/nested_cv/20260130_183653/outer_fold_0/model.pt
Normal file
3
models/nested_cv/20260130_183653/outer_fold_0/model.pt
Normal file
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:04f3103a630de23aeb971629aa769e1a9cdb3c5247c193eefb808c6a4e17b9cc
|
||||
size 133133308
|
||||
Binary file not shown.
@ -0,0 +1 @@
|
||||
{"outer_train_idx": [0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 32, 33, 34, 35, 37, 38, 39, 40, 41, 43, 44, 45, 46, 47, 48, 49, 50, 52, 53, 56, 57, 58, 59, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 73, 74, 75, 77, 78, 79, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 102, 103, 104, 105, 106, 107, 108, 109, 110, 112, 113, 114, 115, 117, 118, 119, 121, 123, 126, 127, 128, 129, 131, 132, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 146, 147, 148, 149, 150, 152, 153, 155, 156, 158, 159, 163, 167, 169, 170, 171, 172, 173, 175, 176, 177, 178, 179, 180, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 207, 208, 211, 214, 215, 216, 217, 218, 219, 221, 222, 223, 224, 227, 228, 229, 230, 231, 233, 234, 236, 237, 238, 239, 243, 244, 245, 247, 248, 249, 251, 252, 255, 256, 257, 258, 260, 261, 262, 263, 264, 265, 266, 268, 269, 270, 271, 272, 273, 274, 275, 276, 279, 282, 283, 284, 286, 287, 288, 290, 291, 292, 293, 294, 295, 297, 298, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 319, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 335, 337, 338, 339, 340, 341, 343, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 358, 359, 360, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 379, 380, 381, 382, 383, 384, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 398, 399, 400, 402, 403, 404, 408, 409, 410, 411, 412, 413, 415, 416, 417, 418, 419, 420, 421, 422, 423, 425, 426, 428, 429, 432, 433], "outer_test_idx": [7, 12, 13, 25, 31, 36, 42, 51, 54, 55, 60, 72, 76, 80, 101, 111, 116, 120, 122, 124, 125, 130, 133, 145, 151, 154, 157, 160, 161, 162, 164, 165, 166, 168, 174, 181, 206, 209, 210, 212, 213, 220, 225, 226, 232, 235, 240, 241, 242, 246, 250, 253, 254, 259, 267, 277, 278, 280, 281, 285, 289, 296, 299, 300, 318, 320, 
321, 334, 336, 342, 344, 356, 357, 361, 377, 378, 385, 397, 401, 405, 406, 407, 414, 424, 427, 430, 431]}
|
||||
@ -0,0 +1,42 @@
|
||||
{
|
||||
"size": {
|
||||
"n_samples": 87,
|
||||
"mse": 0.26366087871128757,
|
||||
"rmse": 0.5134791901443403,
|
||||
"mae": 0.25157783223294666,
|
||||
"r2": 0.21208517006410577
|
||||
},
|
||||
"delivery": {
|
||||
"n_samples": 61,
|
||||
"mse": 0.40443344562739025,
|
||||
"rmse": 0.63595082013265,
|
||||
"mae": 0.3928790920429298,
|
||||
"r2": 0.2300258531372983
|
||||
},
|
||||
"pdi": {
|
||||
"n_samples": 87,
|
||||
"accuracy": 0.7241379310344828,
|
||||
"precision": 0.35141509433962265,
|
||||
"recall": 0.35351966873706003,
|
||||
"f1": 0.348405985686402
|
||||
},
|
||||
"ee": {
|
||||
"n_samples": 87,
|
||||
"accuracy": 0.6666666666666666,
|
||||
"precision": 0.6188811188811189,
|
||||
"recall": 0.6375291375291375,
|
||||
"f1": 0.6217948717948718
|
||||
},
|
||||
"toxic": {
|
||||
"n_samples": 62,
|
||||
"accuracy": 0.967741935483871,
|
||||
"precision": 0.8,
|
||||
"recall": 0.9830508474576272,
|
||||
"f1": 0.8663793103448275
|
||||
},
|
||||
"biodist": {
|
||||
"n_samples": 61,
|
||||
"kl_divergence": 0.14776465556036145,
|
||||
"js_divergence": 0.03926150329301917
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,10 @@
|
||||
{
|
||||
"d_model": 512,
|
||||
"num_heads": 4,
|
||||
"n_attn_layers": 2,
|
||||
"fusion_strategy": "avg",
|
||||
"head_hidden_dim": 64,
|
||||
"dropout": 0.05188345993471756,
|
||||
"lr": 4.21188892865021e-05,
|
||||
"weight_decay": 4.086499445232577e-05
|
||||
}
|
||||
@ -0,0 +1 @@
|
||||
{"epoch_mean": 23}
|
||||
212
models/nested_cv/20260130_183653/outer_fold_1/history.json
Normal file
212
models/nested_cv/20260130_183653/outer_fold_1/history.json
Normal file
@ -0,0 +1,212 @@
|
||||
{
|
||||
"train": [
|
||||
{
|
||||
"loss": 19.564044865694914,
|
||||
"loss_size": 14.219423033974387,
|
||||
"loss_pdi": 1.3700515248558738,
|
||||
"loss_ee": 1.0948690609498457,
|
||||
"loss_delivery": 0.925118706443093,
|
||||
"loss_biodist": 1.2897993542931296,
|
||||
"loss_toxic": 0.6647828817367554
|
||||
},
|
||||
{
|
||||
"loss": 9.569591695612127,
|
||||
"loss_size": 4.443219217387113,
|
||||
"loss_pdi": 1.3034031066027554,
|
||||
"loss_ee": 1.0691871643066406,
|
||||
"loss_delivery": 0.986666976050897,
|
||||
"loss_biodist": 1.1807570999318904,
|
||||
"loss_toxic": 0.5863581760363146
|
||||
},
|
||||
{
|
||||
"loss": 5.09315148266879,
|
||||
"loss_size": 0.4028705805540085,
|
||||
"loss_pdi": 1.2596685236150569,
|
||||
"loss_ee": 1.0527758598327637,
|
||||
"loss_delivery": 0.8668525652451948,
|
||||
"loss_biodist": 1.0210744521834634,
|
||||
"loss_toxic": 0.48990951072085986
|
||||
},
|
||||
{
|
||||
"loss": 4.664039200002497,
|
||||
"loss_size": 0.26930826157331467,
|
||||
"loss_pdi": 1.1607397361235186,
|
||||
"loss_ee": 1.0082263296300715,
|
||||
"loss_delivery": 0.9969787990505045,
|
||||
"loss_biodist": 0.8467015407302163,
|
||||
"loss_toxic": 0.3820846405896274
|
||||
},
|
||||
{
|
||||
"loss": 4.224615703929555,
|
||||
"loss_size": 0.28240104629234836,
|
||||
"loss_pdi": 1.048582375049591,
|
||||
"loss_ee": 0.9691753116520968,
|
||||
"loss_delivery": 0.8995198593898253,
|
||||
"loss_biodist": 0.7173566547307101,
|
||||
"loss_toxic": 0.3075804940678857
|
||||
},
|
||||
{
|
||||
"loss": 3.810542041605169,
|
||||
"loss_size": 0.23119159516963092,
|
||||
"loss_pdi": 0.9792777408253063,
|
||||
"loss_ee": 0.9330252029679038,
|
||||
"loss_delivery": 0.8217983476140283,
|
||||
"loss_biodist": 0.6248252906582572,
|
||||
"loss_toxic": 0.22042379054156216
|
||||
},
|
||||
{
|
||||
"loss": 3.5365114862268623,
|
||||
"loss_size": 0.21511441333727402,
|
||||
"loss_pdi": 0.9201341759074818,
|
||||
"loss_ee": 0.9007513360543684,
|
||||
"loss_delivery": 0.8033261312679811,
|
||||
"loss_biodist": 0.5451555441726338,
|
||||
"loss_toxic": 0.15202991325746884
|
||||
},
|
||||
{
|
||||
"loss": 3.405807386745106,
|
||||
"loss_size": 0.20161619240587408,
|
||||
"loss_pdi": 0.8445044376633384,
|
||||
"loss_ee": 0.8931786905635487,
|
||||
"loss_delivery": 0.831847377798774,
|
||||
"loss_biodist": 0.49607704173434863,
|
||||
"loss_toxic": 0.13858359239318155
|
||||
},
|
||||
{
|
||||
"loss": 3.103480577468872,
|
||||
"loss_size": 0.18636523119427942,
|
||||
"loss_pdi": 0.8065886064009233,
|
||||
"loss_ee": 0.8574360067194159,
|
||||
"loss_delivery": 0.6954484595493837,
|
||||
"loss_biodist": 0.45561750639568677,
|
||||
"loss_toxic": 0.10202483257109468
|
||||
},
|
||||
{
|
||||
"loss": 3.0312788052992388,
|
||||
"loss_size": 0.19699390774423425,
|
||||
"loss_pdi": 0.7626179348338734,
|
||||
"loss_ee": 0.8408515670082786,
|
||||
"loss_delivery": 0.7123599113388495,
|
||||
"loss_biodist": 0.41937256130305206,
|
||||
"loss_toxic": 0.09908294406804172
|
||||
},
|
||||
{
|
||||
"loss": 2.8196144104003906,
|
||||
"loss_size": 0.1769183948636055,
|
||||
"loss_pdi": 0.717121189290827,
|
||||
"loss_ee": 0.8330178802663629,
|
||||
"loss_delivery": 0.6163532761010256,
|
||||
"loss_biodist": 0.3913883079182018,
|
||||
"loss_toxic": 0.08481534672054378
|
||||
},
|
||||
{
|
||||
"loss": 2.7494440295479516,
|
||||
"loss_size": 0.17335935140197928,
|
||||
"loss_pdi": 0.7235767028548501,
|
||||
"loss_ee": 0.812805939804424,
|
||||
"loss_delivery": 0.5859575948931954,
|
||||
"loss_biodist": 0.38201813806187024,
|
||||
"loss_toxic": 0.07172633301128041
|
||||
},
|
||||
{
|
||||
"loss": 2.6807472705841064,
|
||||
"loss_size": 0.20867897028272803,
|
||||
"loss_pdi": 0.6913418119603937,
|
||||
"loss_ee": 0.8103034550493414,
|
||||
"loss_delivery": 0.5472286993807013,
|
||||
"loss_biodist": 0.34767021103338763,
|
||||
"loss_toxic": 0.07552417537028139
|
||||
},
|
||||
{
|
||||
"loss": 2.5547089793465356,
|
||||
"loss_size": 0.17158158326690848,
|
||||
"loss_pdi": 0.6350900205698881,
|
||||
"loss_ee": 0.7878076163205233,
|
||||
"loss_delivery": 0.5484779531305487,
|
||||
"loss_biodist": 0.34313417022878473,
|
||||
"loss_toxic": 0.06861767376011069
|
||||
},
|
||||
{
|
||||
"loss": 2.557967185974121,
|
||||
"loss_size": 0.14963022822683508,
|
||||
"loss_pdi": 0.6706653941761364,
|
||||
"loss_ee": 0.7864666526967828,
|
||||
"loss_delivery": 0.5519421967593107,
|
||||
"loss_biodist": 0.3387457403269681,
|
||||
"loss_toxic": 0.06051696803082119
|
||||
},
|
||||
{
|
||||
"loss": 2.5365889939394863,
|
||||
"loss_size": 0.17580546641891653,
|
||||
"loss_pdi": 0.6637366847558455,
|
||||
"loss_ee": 0.7831195484508168,
|
||||
"loss_delivery": 0.5293790704824708,
|
||||
"loss_biodist": 0.3297005011276765,
|
||||
"loss_toxic": 0.054847732186317444
|
||||
},
|
||||
{
|
||||
"loss": 2.51393402706493,
|
||||
"loss_size": 0.1777868609536778,
|
||||
"loss_pdi": 0.6462894515557722,
|
||||
"loss_ee": 0.7651460604234175,
|
||||
"loss_delivery": 0.5493429905988954,
|
||||
"loss_biodist": 0.317649554122578,
|
||||
"loss_toxic": 0.05771913768892938
|
||||
},
|
||||
{
|
||||
"loss": 2.425347761674361,
|
||||
"loss_size": 0.14695054224946283,
|
||||
"loss_pdi": 0.6508165923031893,
|
||||
"loss_ee": 0.7596850720318881,
|
||||
"loss_delivery": 0.49562479284676636,
|
||||
"loss_biodist": 0.31901306862180884,
|
||||
"loss_toxic": 0.053257653659040276
|
||||
},
|
||||
{
|
||||
"loss": 2.450036742470481,
|
||||
"loss_size": 0.15424400500275873,
|
||||
"loss_pdi": 0.655444394458424,
|
||||
"loss_ee": 0.7591585137627341,
|
||||
"loss_delivery": 0.5164471390572462,
|
||||
"loss_biodist": 0.3094088543545116,
|
||||
"loss_toxic": 0.05533381686969237
|
||||
},
|
||||
{
|
||||
"loss": 2.3860875476490366,
|
||||
"loss_size": 0.1556895158507607,
|
||||
"loss_pdi": 0.6400747786868702,
|
||||
"loss_ee": 0.7600220929492604,
|
||||
"loss_delivery": 0.46673478253863077,
|
||||
"loss_biodist": 0.3111781979149038,
|
||||
"loss_toxic": 0.05238820222968405
|
||||
},
|
||||
{
|
||||
"loss": 2.3954180804165928,
|
||||
"loss_size": 0.1462622725150802,
|
||||
"loss_pdi": 0.619796259836717,
|
||||
"loss_ee": 0.7609411152926359,
|
||||
"loss_delivery": 0.5080444128675894,
|
||||
"loss_biodist": 0.30634408511898736,
|
||||
"loss_toxic": 0.05402998389168219
|
||||
},
|
||||
{
|
||||
"loss": 2.3838103034279565,
|
||||
"loss_size": 0.13035081394694067,
|
||||
"loss_pdi": 0.6347457062114369,
|
||||
"loss_ee": 0.7560921203006398,
|
||||
"loss_delivery": 0.4973590536551042,
|
||||
"loss_biodist": 0.3073451383547349,
|
||||
"loss_toxic": 0.05791747265241363
|
||||
},
|
||||
{
|
||||
"loss": 2.4303664727644487,
|
||||
"loss_size": 0.1789045144211162,
|
||||
"loss_pdi": 0.631924033164978,
|
||||
"loss_ee": 0.7499593008648265,
|
||||
"loss_delivery": 0.502868037332188,
|
||||
"loss_biodist": 0.3101512383330952,
|
||||
"loss_toxic": 0.05655933002179319
|
||||
}
|
||||
],
|
||||
"val": []
|
||||
}
|
||||
3
models/nested_cv/20260130_183653/outer_fold_1/model.pt
Normal file
3
models/nested_cv/20260130_183653/outer_fold_1/model.pt
Normal file
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b1c9190b09f7ab586a7d5190b74452f4be67c5d2fb753aae202f9b723523ef00
|
||||
size 55606866
|
||||
Binary file not shown.
@ -0,0 +1 @@
|
||||
{"outer_train_idx": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 30, 31, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 47, 49, 50, 51, 52, 54, 55, 56, 57, 58, 59, 60, 61, 62, 64, 65, 70, 71, 72, 73, 74, 75, 76, 77, 79, 80, 81, 82, 84, 87, 89, 90, 91, 92, 93, 96, 98, 99, 101, 102, 103, 105, 106, 108, 111, 112, 114, 115, 116, 119, 120, 121, 122, 123, 124, 125, 126, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 148, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 164, 165, 166, 167, 168, 171, 172, 173, 174, 175, 176, 177, 178, 179, 181, 185, 186, 187, 190, 191, 192, 196, 198, 203, 204, 206, 207, 208, 209, 210, 211, 212, 213, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 246, 248, 249, 250, 251, 253, 254, 255, 256, 257, 259, 260, 261, 262, 264, 265, 266, 267, 268, 269, 270, 271, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 295, 296, 298, 299, 300, 301, 302, 303, 304, 306, 307, 308, 309, 310, 312, 313, 314, 316, 317, 318, 319, 320, 321, 322, 323, 324, 326, 327, 329, 330, 331, 332, 334, 335, 336, 337, 340, 341, 342, 344, 346, 347, 348, 349, 350, 351, 352, 353, 354, 356, 357, 359, 360, 361, 362, 363, 364, 365, 366, 369, 370, 372, 374, 376, 377, 378, 380, 381, 382, 384, 385, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 421, 422, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433], "outer_test_idx": [29, 32, 35, 46, 48, 53, 63, 66, 67, 68, 69, 78, 83, 85, 86, 88, 94, 95, 97, 100, 104, 107, 109, 110, 113, 117, 118, 127, 128, 146, 147, 149, 150, 163, 169, 170, 180, 182, 183, 184, 188, 189, 193, 194, 195, 197, 199, 200, 201, 202, 205, 214, 228, 229, 245, 247, 252, 258, 263, 272, 293, 294, 297, 305, 311, 315, 
325, 328, 333, 338, 339, 343, 345, 355, 358, 367, 368, 371, 373, 375, 379, 383, 386, 408, 409, 420, 423]}
|
||||
@ -0,0 +1,42 @@
|
||||
{
|
||||
"size": {
|
||||
"n_samples": 87,
|
||||
"mse": 0.20102046206409732,
|
||||
"rmse": 0.44835305515196094,
|
||||
"mae": 0.2436335881551107,
|
||||
"r2": 0.1900957049146058
|
||||
},
|
||||
"delivery": {
|
||||
"n_samples": 62,
|
||||
"mse": 0.5899936276993041,
|
||||
"rmse": 0.7681104267612203,
|
||||
"mae": 0.46896366539576484,
|
||||
"r2": 0.4017743543115936
|
||||
},
|
||||
"pdi": {
|
||||
"n_samples": 87,
|
||||
"accuracy": 0.6666666666666666,
|
||||
"precision": 0.41556437389770723,
|
||||
"recall": 0.4042119565217391,
|
||||
"f1": 0.40777777777777774
|
||||
},
|
||||
"ee": {
|
||||
"n_samples": 87,
|
||||
"accuracy": 0.6666666666666666,
|
||||
"precision": 0.6414141414141414,
|
||||
"recall": 0.6782661782661782,
|
||||
"f1": 0.6387485970819304
|
||||
},
|
||||
"toxic": {
|
||||
"n_samples": 62,
|
||||
"accuracy": 1.0,
|
||||
"precision": 1.0,
|
||||
"recall": 1.0,
|
||||
"f1": 1.0
|
||||
},
|
||||
"biodist": {
|
||||
"n_samples": 62,
|
||||
"kl_divergence": 0.30758161166563297,
|
||||
"js_divergence": 0.08759221465023677
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,10 @@
|
||||
{
|
||||
"d_model": 512,
|
||||
"num_heads": 8,
|
||||
"n_attn_layers": 4,
|
||||
"fusion_strategy": "attention",
|
||||
"head_hidden_dim": 64,
|
||||
"dropout": 0.11433976976282646,
|
||||
"lr": 5.3015812445144804e-05,
|
||||
"weight_decay": 6.704431817743382e-06
|
||||
}
|
||||
@ -0,0 +1 @@
|
||||
{"epoch_mean": 18}
|
||||
167
models/nested_cv/20260130_183653/outer_fold_2/history.json
Normal file
167
models/nested_cv/20260130_183653/outer_fold_2/history.json
Normal file
@ -0,0 +1,167 @@
|
||||
{
|
||||
"train": [
|
||||
{
|
||||
"loss": 16.328427488153633,
|
||||
"loss_size": 10.961490392684937,
|
||||
"loss_pdi": 1.3189676566557451,
|
||||
"loss_ee": 1.0652603994716296,
|
||||
"loss_delivery": 1.093609056012197,
|
||||
"loss_biodist": 1.2020217722112483,
|
||||
"loss_toxic": 0.6870784759521484
|
||||
},
|
||||
{
|
||||
"loss": 5.557661620053378,
|
||||
"loss_size": 0.6549029702490027,
|
||||
"loss_pdi": 1.2172227664427324,
|
||||
"loss_ee": 1.0503052473068237,
|
||||
"loss_delivery": 0.9982107078487222,
|
||||
"loss_biodist": 1.0446247458457947,
|
||||
"loss_toxic": 0.5923952026800676
|
||||
},
|
||||
{
|
||||
"loss": 4.66534686088562,
|
||||
"loss_size": 0.40050424906340515,
|
||||
"loss_pdi": 0.9867187413302335,
|
||||
"loss_ee": 0.9837820909240029,
|
||||
"loss_delivery": 1.079543113708496,
|
||||
"loss_biodist": 0.7991420409896157,
|
||||
"loss_toxic": 0.41565650972453033
|
||||
},
|
||||
{
|
||||
"loss": 4.042153250087392,
|
||||
"loss_size": 0.39615295827388763,
|
||||
"loss_pdi": 0.8758815581148321,
|
||||
"loss_ee": 0.9381269162351434,
|
||||
"loss_delivery": 0.9179368994452737,
|
||||
"loss_biodist": 0.6232685229995034,
|
||||
"loss_toxic": 0.2907863679257306
|
||||
},
|
||||
{
|
||||
"loss": 3.5052005377682773,
|
||||
"loss_size": 0.3407063511284915,
|
||||
"loss_pdi": 0.8075345104390924,
|
||||
"loss_ee": 0.8666415322910656,
|
||||
"loss_delivery": 0.821596086025238,
|
||||
"loss_biodist": 0.49901550195433875,
|
||||
"loss_toxic": 0.16970657895911823
|
||||
},
|
||||
{
|
||||
"loss": 3.227808865633878,
|
||||
"loss_size": 0.3172220086509531,
|
||||
"loss_pdi": 0.7554353800686923,
|
||||
"loss_ee": 0.8366058685562827,
|
||||
"loss_delivery": 0.7592829249121926,
|
||||
"loss_biodist": 0.41754513708027924,
|
||||
"loss_toxic": 0.14171750030734323
|
||||
},
|
||||
{
|
||||
"loss": 3.015385866165161,
|
||||
"loss_size": 0.3014552891254425,
|
||||
"loss_pdi": 0.7061856659975919,
|
||||
"loss_ee": 0.8024854118173773,
|
||||
"loss_delivery": 0.6936024874448776,
|
||||
"loss_biodist": 0.37025675448504364,
|
||||
"loss_toxic": 0.1414001943035559
|
||||
},
|
||||
{
|
||||
"loss": 2.8741718639026987,
|
||||
"loss_size": 0.30145836960185657,
|
||||
"loss_pdi": 0.6485596244985407,
|
||||
"loss_ee": 0.7874462875452909,
|
||||
"loss_delivery": 0.7069157646460966,
|
||||
"loss_biodist": 0.31511187282475556,
|
||||
"loss_toxic": 0.11467998136173595
|
||||
},
|
||||
{
|
||||
"loss": 2.7043218179182573,
|
||||
"loss_size": 0.30878364227034827,
|
||||
"loss_pdi": 0.6446920850060203,
|
||||
"loss_ee": 0.763309196992354,
|
||||
"loss_delivery": 0.6044047027826309,
|
||||
"loss_biodist": 0.2966968539086255,
|
||||
"loss_toxic": 0.08643539994955063
|
||||
},
|
||||
{
|
||||
"loss": 2.592982042919506,
|
||||
"loss_size": 0.301474930210547,
|
||||
"loss_pdi": 0.586654855446382,
|
||||
"loss_ee": 0.7498630989681591,
|
||||
"loss_delivery": 0.5886639884927056,
|
||||
"loss_biodist": 0.2792905040762641,
|
||||
"loss_toxic": 0.08703470907428047
|
||||
},
|
||||
{
|
||||
"loss": 2.475934158671986,
|
||||
"loss_size": 0.3227222941138528,
|
||||
"loss_pdi": 0.5525102777914568,
|
||||
"loss_ee": 0.7191228595646945,
|
||||
"loss_delivery": 0.5444187271324071,
|
||||
"loss_biodist": 0.254928475076502,
|
||||
"loss_toxic": 0.08223151076923717
|
||||
},
|
||||
{
|
||||
"loss": 2.4397390105507593,
|
||||
"loss_size": 0.28487288003618066,
|
||||
"loss_pdi": 0.5408996912566099,
|
||||
"loss_ee": 0.7313735810193148,
|
||||
"loss_delivery": 0.5359987101771615,
|
||||
"loss_biodist": 0.2598937614397569,
|
||||
"loss_toxic": 0.08670033531432803
|
||||
},
|
||||
{
|
||||
"loss": 2.454329328103499,
|
||||
"loss_size": 0.2546617639335719,
|
||||
"loss_pdi": 0.549032907594334,
|
||||
"loss_ee": 0.723411272872578,
|
||||
"loss_delivery": 0.6058986783027649,
|
||||
"loss_biodist": 0.2590403082695874,
|
||||
"loss_toxic": 0.06228439848531376
|
||||
},
|
||||
{
|
||||
"loss": 2.459202441302213,
|
||||
"loss_size": 0.30333411490375345,
|
||||
"loss_pdi": 0.5448255024173043,
|
||||
"loss_ee": 0.7321870706298135,
|
||||
"loss_delivery": 0.5674931427294557,
|
||||
"loss_biodist": 0.2446274757385254,
|
||||
"loss_toxic": 0.06673513505269181
|
||||
},
|
||||
{
|
||||
"loss": 2.4072633764960547,
|
||||
"loss_size": 0.3075956946069544,
|
||||
"loss_pdi": 0.5372236939993772,
|
||||
"loss_ee": 0.6942113583738153,
|
||||
"loss_delivery": 0.541759736158631,
|
||||
"loss_biodist": 0.2502128630876541,
|
||||
"loss_toxic": 0.07625993527472019
|
||||
},
|
||||
{
|
||||
"loss": 2.3814159631729126,
|
||||
"loss_size": 0.2884528948502107,
|
||||
"loss_pdi": 0.5177338475530798,
|
||||
"loss_ee": 0.7055766040628607,
|
||||
"loss_delivery": 0.564139807766134,
|
||||
"loss_biodist": 0.24335274642164056,
|
||||
"loss_toxic": 0.062160092321309174
|
||||
},
|
||||
{
|
||||
"loss": 2.288854566487399,
|
||||
"loss_size": 0.2864968058737842,
|
||||
"loss_pdi": 0.5151266103441065,
|
||||
"loss_ee": 0.6858331777832725,
|
||||
"loss_delivery": 0.492813152345744,
|
||||
"loss_biodist": 0.23307398774407126,
|
||||
"loss_toxic": 0.07551077617840334
|
||||
},
|
||||
{
|
||||
"loss": 2.308527339588512,
|
||||
"loss_size": 0.3206997500224547,
|
||||
"loss_pdi": 0.48468891328031366,
|
||||
"loss_ee": 0.6778702302412554,
|
||||
"loss_delivery": 0.5191753757270899,
|
||||
"loss_biodist": 0.2486932264132933,
|
||||
"loss_toxic": 0.05739984508942474
|
||||
}
|
||||
],
|
||||
"val": []
|
||||
}
|
||||
3
models/nested_cv/20260130_183653/outer_fold_2/model.pt
Normal file
3
models/nested_cv/20260130_183653/outer_fold_2/model.pt
Normal file
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6285ba67132f7fff122c780ba4a63a2f52ed2b8a8775342015c1ccf00e2351e9
|
||||
size 107113964
|
||||
Binary file not shown.
@ -0,0 +1 @@
|
||||
{"outer_train_idx": [0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 16, 18, 20, 21, 22, 23, 24, 25, 27, 29, 31, 32, 35, 36, 39, 40, 41, 42, 43, 45, 46, 47, 48, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 63, 65, 66, 67, 68, 69, 70, 71, 72, 74, 75, 76, 77, 78, 80, 81, 83, 84, 85, 86, 87, 88, 89, 91, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 133, 134, 135, 137, 138, 139, 140, 141, 142, 143, 145, 146, 147, 148, 149, 150, 151, 153, 154, 155, 157, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 199, 200, 201, 202, 203, 204, 205, 206, 207, 209, 210, 212, 213, 214, 215, 216, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 234, 235, 239, 240, 241, 242, 243, 244, 245, 246, 247, 249, 250, 252, 253, 254, 255, 258, 259, 260, 262, 263, 264, 265, 267, 268, 272, 274, 275, 276, 277, 278, 280, 281, 283, 284, 285, 286, 287, 289, 290, 291, 293, 294, 296, 297, 298, 299, 300, 301, 305, 308, 309, 310, 311, 313, 314, 315, 317, 318, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 342, 343, 344, 345, 348, 349, 352, 353, 355, 356, 357, 358, 359, 360, 361, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 375, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 392, 395, 396, 397, 399, 400, 401, 403, 404, 405, 406, 407, 408, 409, 410, 412, 413, 414, 415, 416, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431], "outer_test_idx": [3, 14, 15, 17, 19, 26, 28, 30, 33, 34, 37, 38, 44, 49, 50, 62, 64, 73, 79, 82, 90, 92, 93, 105, 106, 121, 132, 136, 144, 152, 156, 158, 173, 187, 198, 208, 211, 217, 218, 233, 236, 237, 238, 248, 251, 256, 257, 261, 266, 269, 270, 271, 273, 279, 282, 288, 292, 295, 302, 303, 304, 306, 307, 312, 316, 319, 
340, 341, 346, 347, 350, 351, 354, 362, 374, 376, 390, 391, 393, 394, 398, 402, 411, 417, 418, 432, 433]}
|
||||
@ -0,0 +1,42 @@
|
||||
{
|
||||
"size": {
|
||||
"n_samples": 85,
|
||||
"mse": 0.07855156229299931,
|
||||
"rmse": 0.28027051627490057,
|
||||
"mae": 0.2201253890991211,
|
||||
"r2": 0.009407939412061861
|
||||
},
|
||||
"delivery": {
|
||||
"n_samples": 62,
|
||||
"mse": 0.4162270771403472,
|
||||
"rmse": 0.645156629928227,
|
||||
"mae": 0.41306305523856635,
|
||||
"r2": 0.37384819758564136
|
||||
},
|
||||
"pdi": {
|
||||
"n_samples": 87,
|
||||
"accuracy": 0.7126436781609196,
|
||||
"precision": 0.3963383838383838,
|
||||
"recall": 0.5856060606060606,
|
||||
"f1": 0.43013891331106036
|
||||
},
|
||||
"ee": {
|
||||
"n_samples": 87,
|
||||
"accuracy": 0.6781609195402298,
|
||||
"precision": 0.6215366001209922,
|
||||
"recall": 0.65781362712309,
|
||||
"f1": 0.6235867752721687
|
||||
},
|
||||
"toxic": {
|
||||
"n_samples": 62,
|
||||
"accuracy": 0.967741935483871,
|
||||
"precision": 0.8,
|
||||
"recall": 0.9830508474576272,
|
||||
"f1": 0.8663793103448275
|
||||
},
|
||||
"biodist": {
|
||||
"n_samples": 62,
|
||||
"kl_divergence": 0.3189032824921648,
|
||||
"js_divergence": 0.07944133611635379
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,10 @@
|
||||
{
|
||||
"d_model": 512,
|
||||
"num_heads": 4,
|
||||
"n_attn_layers": 5,
|
||||
"fusion_strategy": "attention",
|
||||
"head_hidden_dim": 64,
|
||||
"dropout": 0.11746271741188277,
|
||||
"lr": 0.0001939220403760229,
|
||||
"weight_decay": 2.4722550292920085e-06
|
||||
}
|
||||
@ -0,0 +1 @@
|
||||
{"epoch_mean": 10}
|
||||
95
models/nested_cv/20260130_183653/outer_fold_3/history.json
Normal file
95
models/nested_cv/20260130_183653/outer_fold_3/history.json
Normal file
@ -0,0 +1,95 @@
|
||||
{
|
||||
"train": [
|
||||
{
|
||||
"loss": 9.790851376273416,
|
||||
"loss_size": 4.886490476402369,
|
||||
"loss_pdi": 1.3631455573168667,
|
||||
"loss_ee": 1.088216781616211,
|
||||
"loss_delivery": 0.7040194340727546,
|
||||
"loss_biodist": 1.1505104249173945,
|
||||
"loss_toxic": 0.5984687263315375
|
||||
},
|
||||
{
|
||||
"loss": 4.359956351193515,
|
||||
"loss_size": 0.43619183789600025,
|
||||
"loss_pdi": 1.1665386232462795,
|
||||
"loss_ee": 0.9901693842627786,
|
||||
"loss_delivery": 0.6073603630065918,
|
||||
"loss_biodist": 0.7433811046860435,
|
||||
"loss_toxic": 0.41631498661908234
|
||||
},
|
||||
{
|
||||
"loss": 3.756383180618286,
|
||||
"loss_size": 0.4402667825872248,
|
||||
"loss_pdi": 1.0114682045849888,
|
||||
"loss_ee": 0.9900745803659613,
|
||||
"loss_delivery": 0.5448389879681848,
|
||||
"loss_biodist": 0.5919120474295183,
|
||||
"loss_toxic": 0.17782256481322375
|
||||
},
|
||||
{
|
||||
"loss": 3.4846310182051226,
|
||||
"loss_size": 0.46992606737396936,
|
||||
"loss_pdi": 1.000607517632571,
|
||||
"loss_ee": 0.8866635181687095,
|
||||
"loss_delivery": 0.5218790579925884,
|
||||
"loss_biodist": 0.47943955388936127,
|
||||
"loss_toxic": 0.12611534412611614
|
||||
},
|
||||
{
|
||||
"loss": 3.0496469844471323,
|
||||
"loss_size": 0.36307603120803833,
|
||||
"loss_pdi": 0.8360648371956565,
|
||||
"loss_ee": 0.8720264759930697,
|
||||
"loss_delivery": 0.45130031081763183,
|
||||
"loss_biodist": 0.4275714944709431,
|
||||
"loss_toxic": 0.09960775822401047
|
||||
},
|
||||
{
|
||||
"loss": 2.8820750496604224,
|
||||
"loss_size": 0.35799529267983005,
|
||||
"loss_pdi": 0.7969891158017245,
|
||||
"loss_ee": 0.8362645967440172,
|
||||
"loss_delivery": 0.4626781473105604,
|
||||
"loss_biodist": 0.358451091430404,
|
||||
"loss_toxic": 0.0696967878294262
|
||||
},
|
||||
{
|
||||
"loss": 2.645532087846236,
|
||||
"loss_size": 0.3313806341453032,
|
||||
"loss_pdi": 0.7585093595764854,
|
||||
"loss_ee": 0.797249972820282,
|
||||
"loss_delivery": 0.38693068108775397,
|
||||
"loss_biodist": 0.29899653656916186,
|
||||
"loss_toxic": 0.07246483176607978
|
||||
},
|
||||
{
|
||||
"loss": 2.5514606779271904,
|
||||
"loss_size": 0.33785366063768213,
|
||||
"loss_pdi": 0.710435076193376,
|
||||
"loss_ee": 0.7790363105860624,
|
||||
"loss_delivery": 0.3592266860333356,
|
||||
"loss_biodist": 0.2742794535376809,
|
||||
"loss_toxic": 0.09062948763709176
|
||||
},
|
||||
{
|
||||
"loss": 2.4680945114655928,
|
||||
"loss_size": 0.3564724339680238,
|
||||
"loss_pdi": 0.6821299493312836,
|
||||
"loss_ee": 0.7459230910647999,
|
||||
"loss_delivery": 0.36073175072669983,
|
||||
"loss_biodist": 0.26153207096186554,
|
||||
"loss_toxic": 0.061305238187990406
|
||||
},
|
||||
{
|
||||
"loss": 2.3807498541745273,
|
||||
"loss_size": 0.29239929941567505,
|
||||
"loss_pdi": 0.7071054875850677,
|
||||
"loss_ee": 0.7487604834816672,
|
||||
"loss_delivery": 0.34709482369097794,
|
||||
"loss_biodist": 0.2424463846466758,
|
||||
"loss_toxic": 0.04294342412190004
|
||||
}
|
||||
],
|
||||
"val": []
|
||||
}
|
||||
3
models/nested_cv/20260130_183653/outer_fold_3/model.pt
Normal file
3
models/nested_cv/20260130_183653/outer_fold_3/model.pt
Normal file
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:93cd6e31c539b5604fa03b14699fc985cbee10dd8cbbb05f6cebc43f65e394da
|
||||
size 132340732
|
||||
Binary file not shown.
@ -0,0 +1 @@
|
||||
{"outer_train_idx": [0, 1, 2, 3, 7, 8, 12, 13, 14, 15, 17, 19, 20, 21, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 41, 42, 43, 44, 46, 48, 49, 50, 51, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 71, 72, 73, 74, 76, 78, 79, 80, 82, 83, 85, 86, 88, 89, 90, 91, 92, 93, 94, 95, 97, 99, 100, 101, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 127, 128, 129, 130, 132, 133, 134, 136, 138, 140, 143, 144, 145, 146, 147, 149, 150, 151, 152, 153, 154, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 173, 174, 175, 177, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 197, 198, 199, 200, 201, 202, 203, 205, 206, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 220, 222, 225, 226, 227, 228, 229, 231, 232, 233, 235, 236, 237, 238, 239, 240, 241, 242, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 265, 266, 267, 269, 270, 271, 272, 273, 276, 277, 278, 279, 280, 281, 282, 284, 285, 286, 288, 289, 292, 293, 294, 295, 296, 297, 299, 300, 301, 302, 303, 304, 305, 306, 307, 309, 310, 311, 312, 314, 315, 316, 318, 319, 320, 321, 324, 325, 328, 329, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 361, 362, 363, 365, 366, 367, 368, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 383, 384, 385, 386, 387, 389, 390, 391, 393, 394, 396, 397, 398, 401, 402, 405, 406, 407, 408, 409, 410, 411, 413, 414, 415, 416, 417, 418, 420, 421, 423, 424, 425, 427, 428, 430, 431, 432, 433], "outer_test_idx": [4, 5, 6, 9, 10, 11, 16, 18, 22, 23, 24, 39, 40, 45, 47, 52, 70, 75, 77, 81, 84, 87, 96, 98, 102, 114, 126, 131, 135, 137, 139, 141, 142, 148, 155, 172, 176, 178, 179, 196, 204, 207, 219, 221, 223, 224, 230, 234, 243, 244, 264, 268, 274, 275, 283, 287, 290, 291, 298, 308, 313, 317, 322, 323, 326, 327, 
330, 331, 332, 349, 360, 364, 369, 370, 382, 388, 392, 395, 399, 400, 403, 404, 412, 419, 422, 426, 429]}
|
||||
@ -0,0 +1,42 @@
|
||||
{
|
||||
"size": {
|
||||
"n_samples": 87,
|
||||
"mse": 0.08391054707837464,
|
||||
"rmse": 0.28967317286620564,
|
||||
"mae": 0.22691357272794876,
|
||||
"r2": 0.26719931627457305
|
||||
},
|
||||
"delivery": {
|
||||
"n_samples": 62,
|
||||
"mse": 2.0453160934060053,
|
||||
"rmse": 1.4301454798047664,
|
||||
"mae": 0.5443450972558029,
|
||||
"r2": 0.0777919381248442
|
||||
},
|
||||
"pdi": {
|
||||
"n_samples": 87,
|
||||
"accuracy": 0.6896551724137931,
|
||||
"precision": 0.3994245524296675,
|
||||
"recall": 0.5978021978021978,
|
||||
"f1": 0.417895167895168
|
||||
},
|
||||
"ee": {
|
||||
"n_samples": 87,
|
||||
"accuracy": 0.5862068965517241,
|
||||
"precision": 0.5469462969462969,
|
||||
"recall": 0.5874125874125874,
|
||||
"f1": 0.5375569894616209
|
||||
},
|
||||
"toxic": {
|
||||
"n_samples": 63,
|
||||
"accuracy": 0.9682539682539683,
|
||||
"precision": 0.8,
|
||||
"recall": 0.9833333333333334,
|
||||
"f1": 0.8665254237288135
|
||||
},
|
||||
"biodist": {
|
||||
"n_samples": 63,
|
||||
"kl_divergence": 0.28367776789683485,
|
||||
"js_divergence": 0.07318286384043993
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,10 @@
|
||||
{
|
||||
"d_model": 128,
|
||||
"num_heads": 4,
|
||||
"n_attn_layers": 6,
|
||||
"fusion_strategy": "max",
|
||||
"head_hidden_dim": 128,
|
||||
"dropout": 0.15658744776638445,
|
||||
"lr": 0.00031005155898680676,
|
||||
"weight_decay": 5.422040924441196e-05
|
||||
}
|
||||
@ -0,0 +1 @@
|
||||
{"epoch_mean": 14}
|
||||
131
models/nested_cv/20260130_183653/outer_fold_4/history.json
Normal file
131
models/nested_cv/20260130_183653/outer_fold_4/history.json
Normal file
@ -0,0 +1,131 @@
|
||||
{
|
||||
"train": [
|
||||
{
|
||||
"loss": 15.068541526794434,
|
||||
"loss_size": 9.661331350153143,
|
||||
"loss_pdi": 1.3961028402501887,
|
||||
"loss_ee": 1.1041727607900447,
|
||||
"loss_delivery": 1.131193908778104,
|
||||
"loss_biodist": 1.1299291415648027,
|
||||
"loss_toxic": 0.6458113518628207
|
||||
},
|
||||
{
|
||||
"loss": 5.692212278192693,
|
||||
"loss_size": 0.9680004715919495,
|
||||
"loss_pdi": 1.2202669165351174,
|
||||
"loss_ee": 1.0751874880357222,
|
||||
"loss_delivery": 1.075642000545155,
|
||||
"loss_biodist": 0.8836204951459711,
|
||||
"loss_toxic": 0.46949480880390515
|
||||
},
|
||||
{
|
||||
"loss": 4.2831949970938945,
|
||||
"loss_size": 0.4278720129619945,
|
||||
"loss_pdi": 1.0048133405772122,
|
||||
"loss_ee": 0.9688023762269453,
|
||||
"loss_delivery": 0.9665126570246436,
|
||||
"loss_biodist": 0.649872666055506,
|
||||
"loss_toxic": 0.2653219618580558
|
||||
},
|
||||
{
|
||||
"loss": 3.620699340646917,
|
||||
"loss_size": 0.34346921877427533,
|
||||
"loss_pdi": 0.909794731573625,
|
||||
"loss_ee": 0.8586091019890525,
|
||||
"loss_delivery": 0.8243018510666761,
|
||||
"loss_biodist": 0.5371287275444377,
|
||||
"loss_toxic": 0.14739569039507347
|
||||
},
|
||||
{
|
||||
"loss": 3.2791861187327993,
|
||||
"loss_size": 0.3260936520316384,
|
||||
"loss_pdi": 0.8282042199915106,
|
||||
"loss_ee": 0.8434794707731768,
|
||||
"loss_delivery": 0.7491147694262591,
|
||||
"loss_biodist": 0.4268476908857172,
|
||||
"loss_toxic": 0.10544633001766422
|
||||
},
|
||||
{
|
||||
"loss": 3.000173742120916,
|
||||
"loss_size": 0.3082492568276145,
|
||||
"loss_pdi": 0.6911327242851257,
|
||||
"loss_ee": 0.7926767305894331,
|
||||
"loss_delivery": 0.7263242087580941,
|
||||
"loss_biodist": 0.3646425713192333,
|
||||
"loss_toxic": 0.11714824627746236
|
||||
},
|
||||
{
|
||||
"loss": 2.8449384082447398,
|
||||
"loss_size": 0.3491906361146407,
|
||||
"loss_pdi": 0.7018052475018934,
|
||||
"loss_ee": 0.7752535722472451,
|
||||
"loss_delivery": 0.6207486174323342,
|
||||
"loss_biodist": 0.3087055358019742,
|
||||
"loss_toxic": 0.08923475528982552
|
||||
},
|
||||
{
|
||||
"loss": 2.762920791452581,
|
||||
"loss_size": 0.27234369922767987,
|
||||
"loss_pdi": 0.6848637001080946,
|
||||
"loss_ee": 0.7332382093776356,
|
||||
"loss_delivery": 0.7119220888072794,
|
||||
"loss_biodist": 0.27469146658073773,
|
||||
"loss_toxic": 0.08586158154701645
|
||||
},
|
||||
{
|
||||
"loss": 2.6228579716248945,
|
||||
"loss_size": 0.28220029175281525,
|
||||
"loss_pdi": 0.6514047113331881,
|
||||
"loss_ee": 0.7080714106559753,
|
||||
"loss_delivery": 0.6849517280405218,
|
||||
"loss_biodist": 0.23796464096416126,
|
||||
"loss_toxic": 0.05826510641385208
|
||||
},
|
||||
{
|
||||
"loss": 2.53473145311529,
|
||||
"loss_size": 0.24789857864379883,
|
||||
"loss_pdi": 0.6621171263131228,
|
||||
"loss_ee": 0.7045791528441689,
|
||||
"loss_delivery": 0.6228310723196376,
|
||||
"loss_biodist": 0.24999592927369205,
|
||||
"loss_toxic": 0.047309636138379574
|
||||
},
|
||||
{
|
||||
"loss": 2.439385262402621,
|
||||
"loss_size": 0.2714946825395931,
|
||||
"loss_pdi": 0.6569395525888964,
|
||||
"loss_ee": 0.679518764669245,
|
||||
"loss_delivery": 0.5314246158708226,
|
||||
"loss_biodist": 0.23756521533836017,
|
||||
"loss_toxic": 0.062442497265609825
|
||||
},
|
||||
{
|
||||
"loss": 2.4029004248705776,
|
||||
"loss_size": 0.25774127651344647,
|
||||
"loss_pdi": 0.5974260785362937,
|
||||
"loss_ee": 0.6959952928803184,
|
||||
"loss_delivery": 0.5820360441099514,
|
||||
"loss_biodist": 0.22288588908585635,
|
||||
"loss_toxic": 0.04681587845764377
|
||||
},
|
||||
{
|
||||
"loss": 2.390383416956121,
|
||||
"loss_size": 0.24576269225640732,
|
||||
"loss_pdi": 0.6070916679772463,
|
||||
"loss_ee": 0.6997986544262279,
|
||||
"loss_delivery": 0.5448949479243972,
|
||||
"loss_biodist": 0.2274868596683849,
|
||||
"loss_toxic": 0.065348598936742
|
||||
},
|
||||
{
|
||||
"loss": 2.450988466089422,
|
||||
"loss_size": 0.21830374882979828,
|
||||
"loss_pdi": 0.6097230477766558,
|
||||
"loss_ee": 0.6919564171270891,
|
||||
"loss_delivery": 0.6387598026882518,
|
||||
"loss_biodist": 0.22940932078794998,
|
||||
"loss_toxic": 0.06283610728992657
|
||||
}
|
||||
],
|
||||
"val": []
|
||||
}
|
||||
3
models/nested_cv/20260130_183653/outer_fold_4/model.pt
Normal file
3
models/nested_cv/20260130_183653/outer_fold_4/model.pt
Normal file
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cd138dcfe3697ec06e8caa6ec6c35a16635b1eb5a790b07e5c49eaf638dd9379
|
||||
size 11112658
|
||||
Binary file not shown.
@ -0,0 +1 @@
|
||||
{"outer_train_idx": [3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 42, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 60, 62, 63, 64, 66, 67, 68, 69, 70, 72, 73, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 90, 92, 93, 94, 95, 96, 97, 98, 100, 101, 102, 104, 105, 106, 107, 109, 110, 111, 113, 114, 116, 117, 118, 120, 121, 122, 124, 125, 126, 127, 128, 130, 131, 132, 133, 135, 136, 137, 139, 141, 142, 144, 145, 146, 147, 148, 149, 150, 151, 152, 154, 155, 156, 157, 158, 160, 161, 162, 163, 164, 165, 166, 168, 169, 170, 172, 173, 174, 176, 178, 179, 180, 181, 182, 183, 184, 187, 188, 189, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 217, 218, 219, 220, 221, 223, 224, 225, 226, 228, 229, 230, 232, 233, 234, 235, 236, 237, 238, 240, 241, 242, 243, 244, 245, 246, 247, 248, 250, 251, 252, 253, 254, 256, 257, 258, 259, 261, 263, 264, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 277, 278, 279, 280, 281, 282, 283, 285, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 302, 303, 304, 305, 306, 307, 308, 311, 312, 313, 315, 316, 317, 318, 319, 320, 321, 322, 323, 325, 326, 327, 328, 330, 331, 332, 333, 334, 336, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 349, 350, 351, 354, 355, 356, 357, 358, 360, 361, 362, 364, 367, 368, 369, 370, 371, 373, 374, 375, 376, 377, 378, 379, 382, 383, 385, 386, 388, 390, 391, 392, 393, 394, 395, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 411, 412, 414, 417, 418, 419, 420, 422, 423, 424, 426, 427, 429, 430, 431, 432, 433], "outer_test_idx": [0, 1, 2, 8, 20, 21, 27, 41, 43, 56, 57, 58, 59, 61, 65, 71, 74, 89, 91, 99, 103, 108, 112, 115, 119, 123, 129, 134, 138, 140, 143, 153, 159, 167, 171, 175, 177, 185, 186, 190, 191, 192, 203, 215, 216, 222, 227, 231, 239, 249, 255, 260, 262, 265, 276, 284, 286, 301, 309, 310, 314, 324, 329, 335, 337, 
348, 352, 353, 359, 363, 365, 366, 372, 380, 381, 384, 387, 389, 396, 410, 413, 415, 416, 421, 425, 428]}
|
||||
@ -0,0 +1,42 @@
|
||||
{
|
||||
"size": {
|
||||
"n_samples": 86,
|
||||
"mse": 0.17175675509369362,
|
||||
"rmse": 0.41443546553558086,
|
||||
"mae": 0.2695355332174966,
|
||||
"r2": -0.28030669239092476
|
||||
},
|
||||
"delivery": {
|
||||
"n_samples": 63,
|
||||
"mse": 0.35801504662479033,
|
||||
"rmse": 0.5983435857638906,
|
||||
"mae": 0.4112293718545328,
|
||||
"r2": 0.3391331751101827
|
||||
},
|
||||
"pdi": {
|
||||
"n_samples": 86,
|
||||
"accuracy": 0.7209302325581395,
|
||||
"precision": 0.5444444444444444,
|
||||
"recall": 0.746031746031746,
|
||||
"f1": 0.5894308943089431
|
||||
},
|
||||
"ee": {
|
||||
"n_samples": 86,
|
||||
"accuracy": 0.5930232558139535,
|
||||
"precision": 0.5120650953984287,
|
||||
"recall": 0.5383076043453402,
|
||||
"f1": 0.5184148203294006
|
||||
},
|
||||
"toxic": {
|
||||
"n_samples": 64,
|
||||
"accuracy": 0.9375,
|
||||
"precision": 0.6666666666666666,
|
||||
"recall": 0.967741935483871,
|
||||
"f1": 0.7333333333333333
|
||||
},
|
||||
"biodist": {
|
||||
"n_samples": 63,
|
||||
"kl_divergence": 0.2828013605689281,
|
||||
"js_divergence": 0.07105864428768947
|
||||
}
|
||||
}
|
||||
60
models/nested_cv/20260130_183653/strata_info.json
Normal file
60
models/nested_cv/20260130_183653/strata_info.json
Normal file
@ -0,0 +1,60 @@
|
||||
{
|
||||
"original_strata_counts": {
|
||||
"T0|P0|E0": "5",
|
||||
"T0|P0|E1": "63",
|
||||
"T0|P0|E2": "175",
|
||||
"T0|P1|E0": "1",
|
||||
"T0|P1|E1": "19",
|
||||
"T0|P1|E2": "33",
|
||||
"T0|P2|E2": "3",
|
||||
"T1|P0|E2": "9",
|
||||
"T1|P1|E2": "5",
|
||||
"TNA|P0|E0": "28",
|
||||
"TNA|P0|E1": "21",
|
||||
"TNA|P0|E2": "20",
|
||||
"TNA|P1|E0": "29",
|
||||
"TNA|P1|E1": "7",
|
||||
"TNA|P1|E2": "14",
|
||||
"TNA|P2|E2": "1",
|
||||
"TNA|P3|E0": "1"
|
||||
},
|
||||
"rare_strata": [
|
||||
"T0|P1|E0",
|
||||
"T0|P2|E2",
|
||||
"TNA|P2|E2",
|
||||
"TNA|P3|E0"
|
||||
],
|
||||
"final_strata": [
|
||||
"RARE",
|
||||
"T0|P0|E0",
|
||||
"T0|P0|E1",
|
||||
"T0|P0|E2",
|
||||
"T0|P1|E1",
|
||||
"T0|P1|E2",
|
||||
"T1|P0|E2",
|
||||
"T1|P1|E2",
|
||||
"TNA|P0|E0",
|
||||
"TNA|P0|E1",
|
||||
"TNA|P0|E2",
|
||||
"TNA|P1|E0",
|
||||
"TNA|P1|E1",
|
||||
"TNA|P1|E2"
|
||||
],
|
||||
"final_strata_counts": {
|
||||
"RARE": "6",
|
||||
"T0|P0|E0": "5",
|
||||
"T0|P0|E1": "63",
|
||||
"T0|P0|E2": "175",
|
||||
"T0|P1|E1": "19",
|
||||
"T0|P1|E2": "33",
|
||||
"T1|P0|E2": "9",
|
||||
"T1|P1|E2": "5",
|
||||
"TNA|P0|E0": "28",
|
||||
"TNA|P0|E1": "21",
|
||||
"TNA|P0|E2": "20",
|
||||
"TNA|P1|E0": "29",
|
||||
"TNA|P1|E1": "7",
|
||||
"TNA|P1|E2": "14"
|
||||
},
|
||||
"n_rare_merged": "6"
|
||||
}
|
||||
342
models/nested_cv/20260130_183653/summary.json
Normal file
342
models/nested_cv/20260130_183653/summary.json
Normal file
@ -0,0 +1,342 @@
|
||||
{
|
||||
"fold_results": [
|
||||
{
|
||||
"fold": 0,
|
||||
"best_params": {
|
||||
"d_model": 512,
|
||||
"num_heads": 8,
|
||||
"n_attn_layers": 5,
|
||||
"fusion_strategy": "attention",
|
||||
"head_hidden_dim": 128,
|
||||
"dropout": 0.14666736316838325,
|
||||
"lr": 0.0001295888795454003,
|
||||
"weight_decay": 7.732380983243132e-05
|
||||
},
|
||||
"epoch_mean": 13,
|
||||
"test_metrics": {
|
||||
"size": {
|
||||
"n_samples": 87,
|
||||
"mse": 0.26366087871128757,
|
||||
"rmse": 0.5134791901443403,
|
||||
"mae": 0.25157783223294666,
|
||||
"r2": 0.21208517006410577
|
||||
},
|
||||
"delivery": {
|
||||
"n_samples": 61,
|
||||
"mse": 0.40443344562739025,
|
||||
"rmse": 0.63595082013265,
|
||||
"mae": 0.3928790920429298,
|
||||
"r2": 0.2300258531372983
|
||||
},
|
||||
"pdi": {
|
||||
"n_samples": 87,
|
||||
"accuracy": 0.7241379310344828,
|
||||
"precision": 0.35141509433962265,
|
||||
"recall": 0.35351966873706003,
|
||||
"f1": 0.348405985686402
|
||||
},
|
||||
"ee": {
|
||||
"n_samples": 87,
|
||||
"accuracy": 0.6666666666666666,
|
||||
"precision": 0.6188811188811189,
|
||||
"recall": 0.6375291375291375,
|
||||
"f1": 0.6217948717948718
|
||||
},
|
||||
"toxic": {
|
||||
"n_samples": 62,
|
||||
"accuracy": 0.967741935483871,
|
||||
"precision": 0.8,
|
||||
"recall": 0.9830508474576272,
|
||||
"f1": 0.8663793103448275
|
||||
},
|
||||
"biodist": {
|
||||
"n_samples": 61,
|
||||
"kl_divergence": 0.14776465556036145,
|
||||
"js_divergence": 0.03926150329301917
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"fold": 1,
|
||||
"best_params": {
|
||||
"d_model": 512,
|
||||
"num_heads": 4,
|
||||
"n_attn_layers": 2,
|
||||
"fusion_strategy": "avg",
|
||||
"head_hidden_dim": 64,
|
||||
"dropout": 0.05188345993471756,
|
||||
"lr": 4.21188892865021e-05,
|
||||
"weight_decay": 4.086499445232577e-05
|
||||
},
|
||||
"epoch_mean": 23,
|
||||
"test_metrics": {
|
||||
"size": {
|
||||
"n_samples": 87,
|
||||
"mse": 0.20102046206409732,
|
||||
"rmse": 0.44835305515196094,
|
||||
"mae": 0.2436335881551107,
|
||||
"r2": 0.1900957049146058
|
||||
},
|
||||
"delivery": {
|
||||
"n_samples": 62,
|
||||
"mse": 0.5899936276993041,
|
||||
"rmse": 0.7681104267612203,
|
||||
"mae": 0.46896366539576484,
|
||||
"r2": 0.4017743543115936
|
||||
},
|
||||
"pdi": {
|
||||
"n_samples": 87,
|
||||
"accuracy": 0.6666666666666666,
|
||||
"precision": 0.41556437389770723,
|
||||
"recall": 0.4042119565217391,
|
||||
"f1": 0.40777777777777774
|
||||
},
|
||||
"ee": {
|
||||
"n_samples": 87,
|
||||
"accuracy": 0.6666666666666666,
|
||||
"precision": 0.6414141414141414,
|
||||
"recall": 0.6782661782661782,
|
||||
"f1": 0.6387485970819304
|
||||
},
|
||||
"toxic": {
|
||||
"n_samples": 62,
|
||||
"accuracy": 1.0,
|
||||
"precision": 1.0,
|
||||
"recall": 1.0,
|
||||
"f1": 1.0
|
||||
},
|
||||
"biodist": {
|
||||
"n_samples": 62,
|
||||
"kl_divergence": 0.30758161166563297,
|
||||
"js_divergence": 0.08759221465023677
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"fold": 2,
|
||||
"best_params": {
|
||||
"d_model": 512,
|
||||
"num_heads": 8,
|
||||
"n_attn_layers": 4,
|
||||
"fusion_strategy": "attention",
|
||||
"head_hidden_dim": 64,
|
||||
"dropout": 0.11433976976282646,
|
||||
"lr": 5.3015812445144804e-05,
|
||||
"weight_decay": 6.704431817743382e-06
|
||||
},
|
||||
"epoch_mean": 18,
|
||||
"test_metrics": {
|
||||
"size": {
|
||||
"n_samples": 85,
|
||||
"mse": 0.07855156229299931,
|
||||
"rmse": 0.28027051627490057,
|
||||
"mae": 0.2201253890991211,
|
||||
"r2": 0.009407939412061861
|
||||
},
|
||||
"delivery": {
|
||||
"n_samples": 62,
|
||||
"mse": 0.4162270771403472,
|
||||
"rmse": 0.645156629928227,
|
||||
"mae": 0.41306305523856635,
|
||||
"r2": 0.37384819758564136
|
||||
},
|
||||
"pdi": {
|
||||
"n_samples": 87,
|
||||
"accuracy": 0.7126436781609196,
|
||||
"precision": 0.3963383838383838,
|
||||
"recall": 0.5856060606060606,
|
||||
"f1": 0.43013891331106036
|
||||
},
|
||||
"ee": {
|
||||
"n_samples": 87,
|
||||
"accuracy": 0.6781609195402298,
|
||||
"precision": 0.6215366001209922,
|
||||
"recall": 0.65781362712309,
|
||||
"f1": 0.6235867752721687
|
||||
},
|
||||
"toxic": {
|
||||
"n_samples": 62,
|
||||
"accuracy": 0.967741935483871,
|
||||
"precision": 0.8,
|
||||
"recall": 0.9830508474576272,
|
||||
"f1": 0.8663793103448275
|
||||
},
|
||||
"biodist": {
|
||||
"n_samples": 62,
|
||||
"kl_divergence": 0.3189032824921648,
|
||||
"js_divergence": 0.07944133611635379
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"fold": 3,
|
||||
"best_params": {
|
||||
"d_model": 512,
|
||||
"num_heads": 4,
|
||||
"n_attn_layers": 5,
|
||||
"fusion_strategy": "attention",
|
||||
"head_hidden_dim": 64,
|
||||
"dropout": 0.11746271741188277,
|
||||
"lr": 0.0001939220403760229,
|
||||
"weight_decay": 2.4722550292920085e-06
|
||||
},
|
||||
"epoch_mean": 10,
|
||||
"test_metrics": {
|
||||
"size": {
|
||||
"n_samples": 87,
|
||||
"mse": 0.08391054707837464,
|
||||
"rmse": 0.28967317286620564,
|
||||
"mae": 0.22691357272794876,
|
||||
"r2": 0.26719931627457305
|
||||
},
|
||||
"delivery": {
|
||||
"n_samples": 62,
|
||||
"mse": 2.0453160934060053,
|
||||
"rmse": 1.4301454798047664,
|
||||
"mae": 0.5443450972558029,
|
||||
"r2": 0.0777919381248442
|
||||
},
|
||||
"pdi": {
|
||||
"n_samples": 87,
|
||||
"accuracy": 0.6896551724137931,
|
||||
"precision": 0.3994245524296675,
|
||||
"recall": 0.5978021978021978,
|
||||
"f1": 0.417895167895168
|
||||
},
|
||||
"ee": {
|
||||
"n_samples": 87,
|
||||
"accuracy": 0.5862068965517241,
|
||||
"precision": 0.5469462969462969,
|
||||
"recall": 0.5874125874125874,
|
||||
"f1": 0.5375569894616209
|
||||
},
|
||||
"toxic": {
|
||||
"n_samples": 63,
|
||||
"accuracy": 0.9682539682539683,
|
||||
"precision": 0.8,
|
||||
"recall": 0.9833333333333334,
|
||||
"f1": 0.8665254237288135
|
||||
},
|
||||
"biodist": {
|
||||
"n_samples": 63,
|
||||
"kl_divergence": 0.28367776789683485,
|
||||
"js_divergence": 0.07318286384043993
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"fold": 4,
|
||||
"best_params": {
|
||||
"d_model": 128,
|
||||
"num_heads": 4,
|
||||
"n_attn_layers": 6,
|
||||
"fusion_strategy": "max",
|
||||
"head_hidden_dim": 128,
|
||||
"dropout": 0.15658744776638445,
|
||||
"lr": 0.00031005155898680676,
|
||||
"weight_decay": 5.422040924441196e-05
|
||||
},
|
||||
"epoch_mean": 14,
|
||||
"test_metrics": {
|
||||
"size": {
|
||||
"n_samples": 86,
|
||||
"mse": 0.17175675509369362,
|
||||
"rmse": 0.41443546553558086,
|
||||
"mae": 0.2695355332174966,
|
||||
"r2": -0.28030669239092476
|
||||
},
|
||||
"delivery": {
|
||||
"n_samples": 63,
|
||||
"mse": 0.35801504662479033,
|
||||
"rmse": 0.5983435857638906,
|
||||
"mae": 0.4112293718545328,
|
||||
"r2": 0.3391331751101827
|
||||
},
|
||||
"pdi": {
|
||||
"n_samples": 86,
|
||||
"accuracy": 0.7209302325581395,
|
||||
"precision": 0.5444444444444444,
|
||||
"recall": 0.746031746031746,
|
||||
"f1": 0.5894308943089431
|
||||
},
|
||||
"ee": {
|
||||
"n_samples": 86,
|
||||
"accuracy": 0.5930232558139535,
|
||||
"precision": 0.5120650953984287,
|
||||
"recall": 0.5383076043453402,
|
||||
"f1": 0.5184148203294006
|
||||
},
|
||||
"toxic": {
|
||||
"n_samples": 64,
|
||||
"accuracy": 0.9375,
|
||||
"precision": 0.6666666666666666,
|
||||
"recall": 0.967741935483871,
|
||||
"f1": 0.7333333333333333
|
||||
},
|
||||
"biodist": {
|
||||
"n_samples": 63,
|
||||
"kl_divergence": 0.2828013605689281,
|
||||
"js_divergence": 0.07105864428768947
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"summary_stats": {
|
||||
"size": {
|
||||
"mse_mean": 0.1597800410480905,
|
||||
"mse_std": 0.07069609368925868,
|
||||
"rmse_mean": 0.3892422799945977,
|
||||
"rmse_std": 0.09094222623565874,
|
||||
"mae_mean": 0.24235718308652476,
|
||||
"mae_std": 0.017652592224532318,
|
||||
"r2_mean": 0.07969628765488435,
|
||||
"r2_std": 0.19970720107145462
|
||||
},
|
||||
"delivery": {
|
||||
"mse_mean": 0.7627970580995674,
|
||||
"mse_std": 0.6460804607386303,
|
||||
"rmse_mean": 0.8155413884781508,
|
||||
"rmse_std": 0.31255287837212004,
|
||||
"mae_mean": 0.44609605635751937,
|
||||
"mae_std": 0.055343855526290356,
|
||||
"r2_mean": 0.28451470365391207,
|
||||
"r2_std": 0.11867334397685028
|
||||
},
|
||||
"pdi": {
|
||||
"accuracy_mean": 0.7028067361668003,
|
||||
"accuracy_std": 0.021722406295571848,
|
||||
"precision_mean": 0.4214373697899651,
|
||||
"precision_std": 0.06508897722192637,
|
||||
"recall_mean": 0.5374343259397607,
|
||||
"recall_std": 0.1421622175476529,
|
||||
"f1_mean": 0.4387297477958702,
|
||||
"f1_std": 0.0804178141542625
|
||||
},
|
||||
"ee": {
|
||||
"accuracy_mean": 0.6381448810478482,
|
||||
"accuracy_std": 0.039904343482039924,
|
||||
"precision_mean": 0.5881686505521957,
|
||||
"precision_std": 0.049765031116764336,
|
||||
"recall_mean": 0.6198658269352666,
|
||||
"recall_std": 0.05072984435881663,
|
||||
"f1_mean": 0.5880204107879985,
|
||||
"f1_std": 0.049740374945613494
|
||||
},
|
||||
"toxic": {
|
||||
"accuracy_mean": 0.9682475678443421,
|
||||
"accuracy_std": 0.01976937654694873,
|
||||
"precision_mean": 0.8133333333333335,
|
||||
"precision_std": 0.10666666666666666,
|
||||
"recall_mean": 0.9834353927464917,
|
||||
"recall_std": 0.010207614614579376,
|
||||
"f1_mean": 0.8665234755503602,
|
||||
"f1_std": 0.08432750219703594
|
||||
},
|
||||
"biodist": {
|
||||
"kl_divergence_mean": 0.26814573563678445,
|
||||
"kl_divergence_std": 0.06177240919631341,
|
||||
"js_divergence_mean": 0.07010731243754784,
|
||||
"js_divergence_std": 0.016460095953094674
|
||||
}
|
||||
}
|
||||
}
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
211
pixi.lock
211
pixi.lock
@ -53,6 +53,7 @@ environments:
|
||||
- conda: https://conda.anaconda.org/conda-forge/linux-64/xz-tools-5.8.1-hb9d3cd8_2.conda
|
||||
- conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda
|
||||
- pypi: https://files.pythonhosted.org/packages/64/88/c7083fc61120ab661c5d0b82cb77079fc1429d3f913a456c1c82cf4658f7/alabaster-0.7.13-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/54/7e/ac0991d1745f7d755fc1cd381b3990a45b404b4d008fc75e2a983516fbfe/alembic-1.14.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/9b/52/4a86a4fa1cc2aae79137cc9510b7080c3e5aede2310d14fae5486feec7f7/altair-5.4.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl
|
||||
@ -65,6 +66,7 @@ environments:
|
||||
- pypi: https://files.pythonhosted.org/packages/bf/fa/cf5bb2409a385f78750e78c8d2e24780964976acdaaed65dbd6083ae5b40/charset_normalizer-3.4.4-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/a0/ad/72361203906e2dbe9baa776c64e9246d555b516808dd0cce385f07f4cf71/chemprop-1.7.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/6d/c1/e419ef3723a074172b68aaa89c9f3de486ed4c2399e2dbd8113a4fdcaf9e/colorlog-6.10.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/8e/71/7f20855592cc929bc206810432b991ec4c702dc26b0567b132e52c85536f/contourpy-1.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/15/1a/c6eae628480aa1fc5f6f85437c7d8ec0d1172597acd1c61182202a902c0f/cramjam-2.11.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl
|
||||
@ -82,6 +84,7 @@ environments:
|
||||
- pypi: https://files.pythonhosted.org/packages/da/71/ae30dadffc90b9006d77af76b393cb9dfbfc9629f339fc1574a1c52e6806/future-1.0.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/6a/09/e21df6aef1e1ffc0c816f0522ddc3f6dcded766c3261813131c78a704470/gitpython-3.1.46-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/d8/88/0ce16c0afb2d71d85562a7bcd9b092fec80a7767ab5b5f7e1bbbca8200f8/greenlet-3.1.1-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl
|
||||
@ -96,6 +99,7 @@ environments:
|
||||
- pypi: https://files.pythonhosted.org/packages/69/4a/4f9dbeb84e8850557c02365a0eee0649abe5eb1d84af92a25731c6c0f922/jsonschema-4.23.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/ee/07/44bd408781594c4d0a027666ef27fab1e441b109dc3b76b4f836f8fd04fe/jsonschema_specifications-2023.12.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/76/36/ae40d7a3171e06f55ac77fe5536079e7be1d8be2a8210e08975c7f9b4d54/kiwisolver-1.4.7-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/c7/bd/50319665ce81bb10e90d1cf76f9e1aa269ea6f7fa30ab4521f14d122a3df/MarkupSafe-2.1.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/30/33/cc27211d2ffeee4fd7402dca137b6e8a83f6dcae3d4be8d0ad5068555561/matplotlib-3.7.5-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl
|
||||
@ -116,6 +120,7 @@ environments:
|
||||
- pypi: https://files.pythonhosted.org/packages/46/0c/c75bbfb967457a0b7670b8ad267bfc4fffdf341c074e0a80db06c24ccfd4/nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/da/d3/8057f0587683ed2fcd4dbfbdfdfa807b9160b809976099d36b8f60d08f03/nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/c0/da/977ded879c29cbd04de313843e76868e6e13408a94ed6b987245dc7c8506/openpyxl-3.1.5-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/7f/12/cba81286cbaf0f0c3f0473846cfd992cb240bdcea816bf2ef7de8ed0f744/optuna-4.5.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/f8/7f/5b047effafbdd34e52c9e2d7e44f729a0655efafb22198c45cf692cdc157/pandas-2.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/5d/e6/71ed4d95676098159b533c4a4c424cf453fec9614edaff1a0633fe228eef/pandas_flavor-0.7.0-py3-none-any.whl
|
||||
@ -131,6 +136,7 @@ environments:
|
||||
- pypi: https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/25/a2/b725b61ac76a75583ae7104b3209f75ea44b13cfd026aa535ece22b7f22e/PyYAML-6.0.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/3d/84/63b2e66f5c7cb97ce994769afbbef85a1ac364fedbcb7d4a3c0f15d318a5/rdkit-2024.3.5-cp38-cp38-manylinux_2_28_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/b7/59/2056f61236782a2c86b33906c025d4f4a0b17be0161b63b70fd9e8775d36/referencing-0.35.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl
|
||||
@ -150,6 +156,7 @@ environments:
|
||||
- pypi: https://files.pythonhosted.org/packages/c2/42/4c8646762ee83602e3fb3fbe774c2fac12f317deb0b5dbeeedd2d3ba4b77/sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/2b/14/05f9206cf4e9cfca1afb5fd224c7cd434dcc3a433d6d9e4e0264d29c6cdb/sphinxcontrib_qthelp-1.0.3-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/c6/77/5464ec50dd0f1c1037e3c93249b040c8fc8078fdda97530eeb02424b6eea/sphinxcontrib_serializinghtml-1.1.5-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/8a/7b/b9e0eb7f9a15f2e82856603c728edf14c54a07c6738ab228e4f2de049338/sqlalchemy-2.0.46-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/b6/c5/7ae467eeddb57260c8ce17a3a09f9f5edba35820fc022d7c55b7decd5d3a/starlette-0.44.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/9a/14/857d0734989f3d26f2f965b2e3f67568ea7a6e8a60cb9c1ed7f774b6d606/streamlit-1.40.1-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/99/ff/c87e0622b1dadea79d2fb0b25ade9ed98954c9033722eb707053d310d4f3/sympy-1.13.3-py3-none-any.whl
|
||||
@ -206,6 +213,7 @@ environments:
|
||||
- conda: https://conda.anaconda.org/conda-forge/osx-arm64/xz-gpl-tools-5.8.1-h9a6d368_2.conda
|
||||
- conda: https://conda.anaconda.org/conda-forge/osx-arm64/xz-tools-5.8.1-h39f12f2_2.conda
|
||||
- pypi: https://files.pythonhosted.org/packages/64/88/c7083fc61120ab661c5d0b82cb77079fc1429d3f913a456c1c82cf4658f7/alabaster-0.7.13-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/54/7e/ac0991d1745f7d755fc1cd381b3990a45b404b4d008fc75e2a983516fbfe/alembic-1.14.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/9b/52/4a86a4fa1cc2aae79137cc9510b7080c3e5aede2310d14fae5486feec7f7/altair-5.4.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl
|
||||
@ -218,6 +226,7 @@ environments:
|
||||
- pypi: https://files.pythonhosted.org/packages/0a/4e/3926a1c11f0433791985727965263f788af00db3482d89a7545ca5ecc921/charset_normalizer-3.4.4-cp38-cp38-macosx_10_9_universal2.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/a0/ad/72361203906e2dbe9baa776c64e9246d555b516808dd0cce385f07f4cf71/chemprop-1.7.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/6d/c1/e419ef3723a074172b68aaa89c9f3de486ed4c2399e2dbd8113a4fdcaf9e/colorlog-6.10.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/a6/82/29f5ff4ae074c3230e266bc9efef449ebde43721a727b989dd8ef8f97d73/contourpy-1.1.1-cp38-cp38-macosx_11_0_arm64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/65/64/e34ee535519fd14cde3a7f3f8cd3b4ef54483b9df655e4180437eb884aab/cramjam-2.11.0-cp38-cp38-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl
|
||||
@ -249,6 +258,7 @@ environments:
|
||||
- pypi: https://files.pythonhosted.org/packages/69/4a/4f9dbeb84e8850557c02365a0eee0649abe5eb1d84af92a25731c6c0f922/jsonschema-4.23.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/ee/07/44bd408781594c4d0a027666ef27fab1e441b109dc3b76b4f836f8fd04fe/jsonschema_specifications-2023.12.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/14/a7/bb8ab10e12cc8764f4da0245d72dee4731cc720bdec0f085d5e9c6005b98/kiwisolver-1.4.7-cp38-cp38-macosx_11_0_arm64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/f8/ff/2c942a82c35a49df5de3a630ce0a8456ac2969691b230e530ac12314364c/MarkupSafe-2.1.5-cp38-cp38-macosx_10_9_universal2.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/aa/59/4d13e5b6298b1ca5525eea8c68d3806ae93ab6d0bb17ca9846aa3156b92b/matplotlib-3.7.5-cp38-cp38-macosx_11_0_arm64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl
|
||||
@ -257,6 +267,7 @@ environments:
|
||||
- pypi: https://files.pythonhosted.org/packages/a8/05/9d4f9b78ead6b2661d6e8ea772e111fc4a9fbd866ad0c81906c11206b55e/networkx-3.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/a7/ae/f53b7b265fdc701e663fbb322a8e9d4b14d9cb7b2385f45ddfabfc4327e4/numpy-1.24.4-cp38-cp38-macosx_11_0_arm64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/c0/da/977ded879c29cbd04de313843e76868e6e13408a94ed6b987245dc7c8506/openpyxl-3.1.5-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/7f/12/cba81286cbaf0f0c3f0473846cfd992cb240bdcea816bf2ef7de8ed0f744/optuna-4.5.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/53/c3/f8e87361f7fdf42012def602bfa2a593423c729f5cb7c97aed7f51be66ac/pandas-2.0.3-cp38-cp38-macosx_11_0_arm64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/5d/e6/71ed4d95676098159b533c4a4c424cf453fec9614edaff1a0633fe228eef/pandas_flavor-0.7.0-py3-none-any.whl
|
||||
@ -272,6 +283,7 @@ environments:
|
||||
- pypi: https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/05/8e/961c0007c59b8dd7729d542c61a4d537767a59645b82a0b521206e1e25c2/pyyaml-6.0.3.tar.gz
|
||||
- pypi: https://files.pythonhosted.org/packages/bf/cb/c709b60f4815e18c00e1e8639204bdba04cb158e6278791d82f94f51a988/rdkit-2024.3.5-cp38-cp38-macosx_11_0_arm64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/b7/59/2056f61236782a2c86b33906c025d4f4a0b17be0161b63b70fd9e8775d36/referencing-0.35.1-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl
|
||||
@ -291,6 +303,7 @@ environments:
|
||||
- pypi: https://files.pythonhosted.org/packages/c2/42/4c8646762ee83602e3fb3fbe774c2fac12f317deb0b5dbeeedd2d3ba4b77/sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/2b/14/05f9206cf4e9cfca1afb5fd224c7cd434dcc3a433d6d9e4e0264d29c6cdb/sphinxcontrib_qthelp-1.0.3-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/c6/77/5464ec50dd0f1c1037e3c93249b040c8fc8078fdda97530eeb02424b6eea/sphinxcontrib_serializinghtml-1.1.5-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/07/d4/76b9618d7eb1e6a3c26734e1186f8ad7869e4426b1ea7dc425bc4c832e67/sqlalchemy-2.0.46-cp38-cp38-macosx_11_0_arm64.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/b6/c5/7ae467eeddb57260c8ce17a3a09f9f5edba35820fc022d7c55b7decd5d3a/starlette-0.44.0-py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/9a/14/857d0734989f3d26f2f965b2e3f67568ea7a6e8a60cb9c1ed7f774b6d606/streamlit-1.40.1-py2.py3-none-any.whl
|
||||
- pypi: https://files.pythonhosted.org/packages/99/ff/c87e0622b1dadea79d2fb0b25ade9ed98954c9033722eb707053d310d4f3/sympy-1.13.3-py3-none-any.whl
|
||||
@ -336,6 +349,19 @@ packages:
|
||||
version: 0.7.13
|
||||
sha256: 1ee19aca801bbabb5ba3f5f258e4422dfa86f82f3e9cefb0859b283cdd7f62a3
|
||||
requires_python: '>=3.6'
|
||||
- pypi: https://files.pythonhosted.org/packages/54/7e/ac0991d1745f7d755fc1cd381b3990a45b404b4d008fc75e2a983516fbfe/alembic-1.14.1-py3-none-any.whl
|
||||
name: alembic
|
||||
version: 1.14.1
|
||||
sha256: 1acdd7a3a478e208b0503cd73614d5e4c6efafa4e73518bb60e4f2846a37b1c5
|
||||
requires_dist:
|
||||
- sqlalchemy>=1.3.0
|
||||
- mako
|
||||
- importlib-metadata ; python_full_version < '3.9'
|
||||
- importlib-resources ; python_full_version < '3.9'
|
||||
- typing-extensions>=4
|
||||
- backports-zoneinfo ; python_full_version < '3.9' and extra == 'tz'
|
||||
- tzdata ; extra == 'tz'
|
||||
requires_python: '>=3.8'
|
||||
- pypi: https://files.pythonhosted.org/packages/9b/52/4a86a4fa1cc2aae79137cc9510b7080c3e5aede2310d14fae5486feec7f7/altair-5.4.1-py3-none-any.whl
|
||||
name: altair
|
||||
version: 5.4.1
|
||||
@ -589,6 +615,18 @@ packages:
|
||||
- pkg:pypi/colorama?source=hash-mapping
|
||||
size: 25170
|
||||
timestamp: 1666700778190
|
||||
- pypi: https://files.pythonhosted.org/packages/6d/c1/e419ef3723a074172b68aaa89c9f3de486ed4c2399e2dbd8113a4fdcaf9e/colorlog-6.10.1-py3-none-any.whl
|
||||
name: colorlog
|
||||
version: 6.10.1
|
||||
sha256: 2d7e8348291948af66122cff006c9f8da6255d224e7cf8e37d8de2df3bad8c9c
|
||||
requires_dist:
|
||||
- colorama ; sys_platform == 'win32'
|
||||
- black ; extra == 'development'
|
||||
- flake8 ; extra == 'development'
|
||||
- mypy ; extra == 'development'
|
||||
- pytest ; extra == 'development'
|
||||
- types-colorama ; extra == 'development'
|
||||
requires_python: '>=3.6'
|
||||
- pypi: https://files.pythonhosted.org/packages/8e/71/7f20855592cc929bc206810432b991ec4c702dc26b0567b132e52c85536f/contourpy-1.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||
name: contourpy
|
||||
version: 1.1.1
|
||||
@ -1013,6 +1051,16 @@ packages:
|
||||
- sphinx-rtd-theme ; extra == 'doc'
|
||||
- sphinx-autodoc-typehints ; extra == 'doc'
|
||||
requires_python: '>=3.7'
|
||||
- pypi: https://files.pythonhosted.org/packages/d8/88/0ce16c0afb2d71d85562a7bcd9b092fec80a7767ab5b5f7e1bbbca8200f8/greenlet-3.1.1-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
|
||||
name: greenlet
|
||||
version: 3.1.1
|
||||
sha256: 85f3ff71e2e60bd4b4932a043fbbe0f499e263c628390b285cb599154a3b03b1
|
||||
requires_dist:
|
||||
- sphinx ; extra == 'docs'
|
||||
- furo ; extra == 'docs'
|
||||
- objgraph ; extra == 'test'
|
||||
- psutil ; extra == 'test'
|
||||
requires_python: '>=3.7'
|
||||
- pypi: https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl
|
||||
name: h11
|
||||
version: 0.16.0
|
||||
@ -1453,6 +1501,16 @@ packages:
|
||||
- pkg:pypi/loguru?source=hash-mapping
|
||||
size: 97617
|
||||
timestamp: 1695547715271
|
||||
- pypi: https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl
|
||||
name: mako
|
||||
version: 1.3.10
|
||||
sha256: baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59
|
||||
requires_dist:
|
||||
- markupsafe>=0.9.2
|
||||
- pytest ; extra == 'testing'
|
||||
- babel ; extra == 'babel'
|
||||
- lingua ; extra == 'lingua'
|
||||
requires_python: '>=3.8'
|
||||
- conda: https://conda.anaconda.org/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_0.conda
|
||||
sha256: c041b0eaf7a6af3344d5dd452815cdc148d6284fec25a4fa3f4263b3a021e962
|
||||
md5: 93a8e71256479c62074356ef6ebf501b
|
||||
@ -1709,6 +1767,73 @@ packages:
|
||||
purls: []
|
||||
size: 3108371
|
||||
timestamp: 1762839712322
|
||||
- pypi: https://files.pythonhosted.org/packages/7f/12/cba81286cbaf0f0c3f0473846cfd992cb240bdcea816bf2ef7de8ed0f744/optuna-4.5.0-py3-none-any.whl
|
||||
name: optuna
|
||||
version: 4.5.0
|
||||
sha256: 5b8a783e84e448b0742501bc27195344a28d2c77bd2feef5b558544d954851b0
|
||||
requires_dist:
|
||||
- alembic>=1.5.0
|
||||
- colorlog
|
||||
- numpy
|
||||
- packaging>=20.0
|
||||
- sqlalchemy>=1.4.2
|
||||
- tqdm
|
||||
- pyyaml
|
||||
- asv>=0.5.0 ; extra == 'benchmark'
|
||||
- cma ; extra == 'benchmark'
|
||||
- virtualenv ; extra == 'benchmark'
|
||||
- black ; extra == 'checking'
|
||||
- blackdoc ; extra == 'checking'
|
||||
- flake8 ; extra == 'checking'
|
||||
- isort ; extra == 'checking'
|
||||
- mypy ; extra == 'checking'
|
||||
- mypy-boto3-s3 ; extra == 'checking'
|
||||
- scipy-stubs ; python_full_version >= '3.10' and extra == 'checking'
|
||||
- types-pyyaml ; extra == 'checking'
|
||||
- types-redis ; extra == 'checking'
|
||||
- types-setuptools ; extra == 'checking'
|
||||
- types-tqdm ; extra == 'checking'
|
||||
- typing-extensions>=3.10.0.0 ; extra == 'checking'
|
||||
- ase ; extra == 'document'
|
||||
- cmaes>=0.12.0 ; extra == 'document'
|
||||
- fvcore ; extra == 'document'
|
||||
- kaleido<0.4 ; extra == 'document'
|
||||
- lightgbm ; extra == 'document'
|
||||
- matplotlib!=3.6.0 ; extra == 'document'
|
||||
- pandas ; extra == 'document'
|
||||
- pillow ; extra == 'document'
|
||||
- plotly>=4.9.0 ; extra == 'document'
|
||||
- scikit-learn ; extra == 'document'
|
||||
- sphinx ; extra == 'document'
|
||||
- sphinx-copybutton ; extra == 'document'
|
||||
- sphinx-gallery ; extra == 'document'
|
||||
- sphinx-notfound-page ; extra == 'document'
|
||||
- sphinx-rtd-theme>=1.2.0 ; extra == 'document'
|
||||
- torch ; extra == 'document'
|
||||
- torchvision ; extra == 'document'
|
||||
- boto3 ; extra == 'optional'
|
||||
- cmaes>=0.12.0 ; extra == 'optional'
|
||||
- google-cloud-storage ; extra == 'optional'
|
||||
- matplotlib!=3.6.0 ; extra == 'optional'
|
||||
- pandas ; extra == 'optional'
|
||||
- plotly>=4.9.0 ; extra == 'optional'
|
||||
- redis ; extra == 'optional'
|
||||
- scikit-learn>=0.24.2 ; extra == 'optional'
|
||||
- scipy ; extra == 'optional'
|
||||
- torch ; extra == 'optional'
|
||||
- grpcio ; extra == 'optional'
|
||||
- protobuf>=5.28.1 ; extra == 'optional'
|
||||
- coverage ; extra == 'test'
|
||||
- fakeredis[lua] ; extra == 'test'
|
||||
- kaleido<0.4 ; extra == 'test'
|
||||
- moto ; extra == 'test'
|
||||
- pytest ; extra == 'test'
|
||||
- pytest-xdist ; extra == 'test'
|
||||
- scipy>=1.9.2 ; extra == 'test'
|
||||
- torch ; extra == 'test'
|
||||
- grpcio ; extra == 'test'
|
||||
- protobuf>=5.28.1 ; extra == 'test'
|
||||
requires_python: '>=3.8'
|
||||
- pypi: https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl
|
||||
name: packaging
|
||||
version: '24.2'
|
||||
@ -2136,6 +2261,16 @@ packages:
|
||||
name: pytz
|
||||
version: '2025.2'
|
||||
sha256: 5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00
|
||||
- pypi: https://files.pythonhosted.org/packages/05/8e/961c0007c59b8dd7729d542c61a4d537767a59645b82a0b521206e1e25c2/pyyaml-6.0.3.tar.gz
|
||||
name: pyyaml
|
||||
version: 6.0.3
|
||||
sha256: d76623373421df22fb4cf8817020cbb7ef15c725b9d5e45f17e189bfc384190f
|
||||
requires_python: '>=3.8'
|
||||
- pypi: https://files.pythonhosted.org/packages/25/a2/b725b61ac76a75583ae7104b3209f75ea44b13cfd026aa535ece22b7f22e/PyYAML-6.0.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
|
||||
name: pyyaml
|
||||
version: 6.0.3
|
||||
sha256: 22ba7cfcad58ef3ecddc7ed1db3409af68d023b7f940da23c6c2a1890976eda6
|
||||
requires_python: '>=3.8'
|
||||
- pypi: https://files.pythonhosted.org/packages/3d/84/63b2e66f5c7cb97ce994769afbbef85a1ac364fedbcb7d4a3c0f15d318a5/rdkit-2024.3.5-cp38-cp38-manylinux_2_28_x86_64.whl
|
||||
name: rdkit
|
||||
version: 2024.3.5
|
||||
@ -2554,6 +2689,82 @@ packages:
|
||||
- docutils-stubs ; extra == 'lint'
|
||||
- pytest ; extra == 'test'
|
||||
requires_python: '>=3.5'
|
||||
- pypi: https://files.pythonhosted.org/packages/07/d4/76b9618d7eb1e6a3c26734e1186f8ad7869e4426b1ea7dc425bc4c832e67/sqlalchemy-2.0.46-cp38-cp38-macosx_11_0_arm64.whl
|
||||
name: sqlalchemy
|
||||
version: 2.0.46
|
||||
sha256: 6ac245604295b521de49b465bab845e3afe6916bcb2147e5929c8041b4ec0545
|
||||
requires_dist:
|
||||
- typing-extensions>=4.6.0
|
||||
- greenlet>=1 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'
|
||||
- importlib-metadata ; python_full_version < '3.8'
|
||||
- greenlet>=1 ; extra == 'aiomysql'
|
||||
- aiomysql>=0.2.0 ; extra == 'aiomysql'
|
||||
- greenlet>=1 ; extra == 'aioodbc'
|
||||
- aioodbc ; extra == 'aioodbc'
|
||||
- greenlet>=1 ; extra == 'aiosqlite'
|
||||
- aiosqlite ; extra == 'aiosqlite'
|
||||
- typing-extensions!=3.10.0.1 ; extra == 'aiosqlite'
|
||||
- greenlet>=1 ; extra == 'asyncio'
|
||||
- greenlet>=1 ; extra == 'asyncmy'
|
||||
- asyncmy>=0.2.3,!=0.2.4,!=0.2.6 ; extra == 'asyncmy'
|
||||
- mariadb>=1.0.1,!=1.1.2,!=1.1.5,!=1.1.10 ; extra == 'mariadb-connector'
|
||||
- pyodbc ; extra == 'mssql'
|
||||
- pymssql ; extra == 'mssql-pymssql'
|
||||
- pyodbc ; extra == 'mssql-pyodbc'
|
||||
- mypy>=0.910 ; extra == 'mypy'
|
||||
- mysqlclient>=1.4.0 ; extra == 'mysql'
|
||||
- mysql-connector-python ; extra == 'mysql-connector'
|
||||
- cx-oracle>=8 ; extra == 'oracle'
|
||||
- oracledb>=1.0.1 ; extra == 'oracle-oracledb'
|
||||
- psycopg2>=2.7 ; extra == 'postgresql'
|
||||
- greenlet>=1 ; extra == 'postgresql-asyncpg'
|
||||
- asyncpg ; extra == 'postgresql-asyncpg'
|
||||
- pg8000>=1.29.1 ; extra == 'postgresql-pg8000'
|
||||
- psycopg>=3.0.7 ; extra == 'postgresql-psycopg'
|
||||
- psycopg2-binary ; extra == 'postgresql-psycopg2binary'
|
||||
- psycopg2cffi ; extra == 'postgresql-psycopg2cffi'
|
||||
- psycopg[binary]>=3.0.7 ; extra == 'postgresql-psycopgbinary'
|
||||
- pymysql ; extra == 'pymysql'
|
||||
- sqlcipher3-binary ; extra == 'sqlcipher'
|
||||
requires_python: '>=3.7'
|
||||
- pypi: https://files.pythonhosted.org/packages/8a/7b/b9e0eb7f9a15f2e82856603c728edf14c54a07c6738ab228e4f2de049338/sqlalchemy-2.0.46-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
|
||||
name: sqlalchemy
|
||||
version: 2.0.46
|
||||
sha256: 716be5bcabf327b6d5d265dbdc6213a01199be587224eb991ad0d37e83d728fd
|
||||
requires_dist:
|
||||
- typing-extensions>=4.6.0
|
||||
- greenlet>=1 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'
|
||||
- importlib-metadata ; python_full_version < '3.8'
|
||||
- greenlet>=1 ; extra == 'aiomysql'
|
||||
- aiomysql>=0.2.0 ; extra == 'aiomysql'
|
||||
- greenlet>=1 ; extra == 'aioodbc'
|
||||
- aioodbc ; extra == 'aioodbc'
|
||||
- greenlet>=1 ; extra == 'aiosqlite'
|
||||
- aiosqlite ; extra == 'aiosqlite'
|
||||
- typing-extensions!=3.10.0.1 ; extra == 'aiosqlite'
|
||||
- greenlet>=1 ; extra == 'asyncio'
|
||||
- greenlet>=1 ; extra == 'asyncmy'
|
||||
- asyncmy>=0.2.3,!=0.2.4,!=0.2.6 ; extra == 'asyncmy'
|
||||
- mariadb>=1.0.1,!=1.1.2,!=1.1.5,!=1.1.10 ; extra == 'mariadb-connector'
|
||||
- pyodbc ; extra == 'mssql'
|
||||
- pymssql ; extra == 'mssql-pymssql'
|
||||
- pyodbc ; extra == 'mssql-pyodbc'
|
||||
- mypy>=0.910 ; extra == 'mypy'
|
||||
- mysqlclient>=1.4.0 ; extra == 'mysql'
|
||||
- mysql-connector-python ; extra == 'mysql-connector'
|
||||
- cx-oracle>=8 ; extra == 'oracle'
|
||||
- oracledb>=1.0.1 ; extra == 'oracle-oracledb'
|
||||
- psycopg2>=2.7 ; extra == 'postgresql'
|
||||
- greenlet>=1 ; extra == 'postgresql-asyncpg'
|
||||
- asyncpg ; extra == 'postgresql-asyncpg'
|
||||
- pg8000>=1.29.1 ; extra == 'postgresql-pg8000'
|
||||
- psycopg>=3.0.7 ; extra == 'postgresql-psycopg'
|
||||
- psycopg2-binary ; extra == 'postgresql-psycopg2binary'
|
||||
- psycopg2cffi ; extra == 'postgresql-psycopg2cffi'
|
||||
- psycopg[binary]>=3.0.7 ; extra == 'postgresql-psycopgbinary'
|
||||
- pymysql ; extra == 'pymysql'
|
||||
- sqlcipher3-binary ; extra == 'sqlcipher'
|
||||
requires_python: '>=3.7'
|
||||
- pypi: https://files.pythonhosted.org/packages/b6/c5/7ae467eeddb57260c8ce17a3a09f9f5edba35820fc022d7c55b7decd5d3a/starlette-0.44.0-py3-none-any.whl
|
||||
name: starlette
|
||||
version: 0.44.0
|
||||
|
||||
@ -28,3 +28,4 @@ fastapi = ">=0.124.4, <0.125"
|
||||
streamlit = ">=1.40.1, <2"
|
||||
httpx = ">=0.28.1, <0.29"
|
||||
uvicorn = ">=0.33.0, <0.34"
|
||||
optuna = ">=4.5.0, <5"
|
||||
|
||||
21
requirements.txt
Normal file
21
requirements.txt
Normal file
@ -0,0 +1,21 @@
|
||||
# 严格遵循 pixi.toml 的依赖版本
|
||||
# 注意: lnp_ml 本地包在 Dockerfile 中单独安装
|
||||
|
||||
# conda dependencies (in pixi.toml [dependencies])
|
||||
loguru
|
||||
tqdm
|
||||
typer
|
||||
|
||||
# pypi dependencies (in pixi.toml [pypi-dependencies])
|
||||
chemprop==1.7.0
|
||||
setuptools
|
||||
pandas>=2.0.3,<3
|
||||
openpyxl>=3.1.5,<4
|
||||
python-dotenv>=1.0.1,<2
|
||||
pyarrow>=17.0.0,<18
|
||||
fastparquet>=2024.2.0,<2025
|
||||
fastapi>=0.124.4,<0.125
|
||||
streamlit>=1.40.1,<2
|
||||
httpx>=0.28.1,<0.29
|
||||
uvicorn>=0.33.0,<0.34
|
||||
optuna>=4.5.0,<5
|
||||
Loading…
x
Reference in New Issue
Block a user