Update models and UI

RYDE-WORK 2026-02-11 16:49:28 +08:00
parent 3f33f9d233
commit a9392aa780
75 changed files with 4937 additions and 193 deletions

Dockerfile Normal file

@ -0,0 +1,63 @@
# LNP-ML Docker Image
# Multi-stage build supporting both the API and Streamlit services
FROM python:3.8-slim AS base
# Set environment variables
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
PIP_NO_CACHE_DIR=1 \
PIP_DISABLE_PIP_VERSION_CHECK=1
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
libxrender1 \
libxext6 \
curl \
&& rm -rf /var/lib/apt/lists/*
# Copy the dependency manifest
COPY requirements.txt .
# Install Python dependencies
RUN pip install --upgrade pip && \
pip install -r requirements.txt
# Copy project code
COPY pyproject.toml .
COPY README.md .
COPY LICENSE .
COPY lnp_ml/ ./lnp_ml/
COPY app/ ./app/
# Install the project package
RUN pip install -e .
# Copy model files
COPY models/final/ ./models/final/
# ============ API service ============
FROM base AS api
EXPOSE 8000
ENV MODEL_PATH=/app/models/final/model.pt
CMD ["uvicorn", "app.api:app", "--host", "0.0.0.0", "--port", "8000"]
# ============ Streamlit service ============
FROM base AS streamlit
EXPOSE 8501
# Streamlit configuration
ENV STREAMLIT_SERVER_PORT=8501 \
STREAMLIT_SERVER_ADDRESS=0.0.0.0 \
STREAMLIT_SERVER_HEADLESS=true \
STREAMLIT_BROWSER_GATHER_USAGE_STATS=false
CMD ["streamlit", "run", "app/app.py"]

Makefile

@ -164,6 +164,44 @@ test_cv: requirements
tune: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --tune $(MPNN_FLAG) $(DEVICE_FLAG)
# ============ Nested CV + Optuna tuning (StratifiedKFold + class weights) ============
# Common parameters:
# SEED: random seed (default: 42)
# N_TRIALS: number of Optuna trials (default: 20)
# EPOCHS_PER_TRIAL: max epochs per trial (default: 30)
# MIN_STRATUM_COUNT: minimum sample count per composite stratum (default: 5)
# OUTPUT_DIR: output directory (default depends on the command)
# INIT_PRETRAIN: path to pretrained weights (default: models/pretrain_delivery.pt)
SEED_FLAG = $(if $(SEED),--seed $(SEED),)
N_TRIALS_FLAG = $(if $(N_TRIALS),--n-trials $(N_TRIALS),)
EPOCHS_PER_TRIAL_FLAG = $(if $(EPOCHS_PER_TRIAL),--epochs-per-trial $(EPOCHS_PER_TRIAL),)
MIN_STRATUM_FLAG = $(if $(MIN_STRATUM_COUNT),--min-stratum-count $(MIN_STRATUM_COUNT),)
OUTPUT_DIR_FLAG = $(if $(OUTPUT_DIR),--output-dir $(OUTPUT_DIR),)
USE_SWA_FLAG = $(if $(USE_SWA),--use-swa,)
# Pretrained weights are used by default; set NO_PRETRAIN=1 to disable
INIT_PRETRAIN_FLAG = $(if $(NO_PRETRAIN),,--init-from-pretrain $(or $(INIT_PRETRAIN),models/pretrain_delivery.pt))
## Nested CV with Optuna: outer 5-fold (test) + inner 3-fold (tune)
## For model evaluation: the outer 5-fold gives an unbiased performance estimate; the inner 3-fold does the hyperparameter search
## Loads models/pretrain_delivery.pt pretrained weights by default; disable with NO_PRETRAIN=1
## Example: make nested_cv_tune DEVICE=cuda N_TRIALS=30
.PHONY: nested_cv_tune
nested_cv_tune: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.nested_cv_optuna \
$(DEVICE_FLAG) $(MPNN_FLAG) $(SEED_FLAG) $(INIT_PRETRAIN_FLAG) \
$(N_TRIALS_FLAG) $(EPOCHS_PER_TRIAL_FLAG) $(MIN_STRATUM_FLAG) $(OUTPUT_DIR_FLAG)
## Final training with Optuna: 3-fold CV tune + full data train
## For final model training: 3-fold tuning, then training on the full dataset (no early stopping)
## Loads models/pretrain_delivery.pt pretrained weights by default; disable with NO_PRETRAIN=1
## Example: make final_optuna DEVICE=cuda N_TRIALS=30 USE_SWA=1
.PHONY: final_optuna
final_optuna: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.final_train_optuna_cv \
$(DEVICE_FLAG) $(MPNN_FLAG) $(SEED_FLAG) $(INIT_PRETRAIN_FLAG) \
$(N_TRIALS_FLAG) $(EPOCHS_PER_TRIAL_FLAG) $(MIN_STRATUM_FLAG) $(OUTPUT_DIR_FLAG) $(USE_SWA_FLAG)
## Run predictions
.PHONY: predict
predict: requirements
@ -200,6 +238,48 @@ serve:
@echo "Then open: http://localhost:8501"
#################################################################################
# DOCKER COMMANDS #
#################################################################################
## Build Docker images
.PHONY: docker-build
docker-build:
docker compose build
## Start all services with Docker Compose
.PHONY: docker-up
docker-up:
docker compose up -d
## Stop all Docker services
.PHONY: docker-down
docker-down:
docker compose down
## View Docker logs
.PHONY: docker-logs
docker-logs:
docker compose logs -f
## Build and start all services
.PHONY: docker-serve
docker-serve: docker-build docker-up
@echo ""
@echo "🚀 Services are up!"
@echo " - API: http://localhost:8000"
@echo " - Web app: http://localhost:8501"
@echo ""
@echo "View logs: make docker-logs"
@echo "Stop services: make docker-down"
## Clean Docker resources (images, volumes, etc.)
.PHONY: docker-clean
docker-clean:
docker compose down -v --rmi local
docker system prune -f
#################################################################################
# Self Documenting Commands #
#################################################################################

app/SCORE.md Normal file

@ -0,0 +1,15 @@
## regression
biodistribution (selected organ only): score = y * weight, where weight=0.3
quantified_delivery: score = (y-min)/(max-min)*weight, where weight=0.25; (min=-0.798559291, max=4.497814051056962) when route_of_administration=intravenous, (min=-0.794912427, max=10.220042980012716) when route_of_administration=intramuscular
size: score = 0 * weight if y<60, 1 * weight if 60<=y<=150, 0 * weight if y>150, where weight=0.05
## classification
encapsulation_efficiency_0: score = weight, where weight=0
encapsulation_efficiency_1: score = weight, where weight=0.02
encapsulation_efficiency_2: score = weight, where weight=0.08
pdi_0: score = weight, where weight=0.08
pdi_1: score = weight, where weight=0.02
pdi_2: score = weight, where weight=0
pdi_3: score = weight, where weight=0
toxicity_0: score=weight, where weight=0.2
toxicity_1: score=weight, where weight=0
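
Taken together, these rules form a single additive score. A minimal sketch of the combination (illustrative only, not part of this file), using the weights and normalization constants exactly as listed above:

# Composite score per the rules above (illustrative sketch).
DELIVERY_NORM = {"intravenous": (-0.798559291, 4.497814051056962),
                 "intramuscular": (-0.794912427, 10.220042980012716)}

def composite_score(biodist, delivery, size, ee_cls, pdi_cls, toxic_cls, route):
    lo, hi = DELIVERY_NORM[route]
    score = biodist * 0.3                        # biodistribution (selected organ)
    score += (delivery - lo) / (hi - lo) * 0.25  # quantified_delivery, min-max normalized
    score += 0.05 if 60 <= size <= 150 else 0.0  # size window
    score += [0.0, 0.02, 0.08][ee_cls]           # encapsulation_efficiency_{0,1,2}
    score += [0.08, 0.02, 0.0, 0.0][pdi_cls]     # pdi_{0..3}
    score += [0.2, 0.0][toxic_cls]               # toxicity_{0,1}
    return score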

app/api.py

@ -23,23 +23,87 @@ from app.optimize import (
format_results,
AVAILABLE_ORGANS,
TARGET_BIODIST,
CompRanges,
ScoringWeights,
)
# ============ Pydantic Models ============
class CompRangesRequest(BaseModel):
"""Component range configuration."""
weight_ratio_min: float = Field(default=0.05, ge=0.01, le=0.50, description="Minimum cationic lipid/mRNA weight ratio")
weight_ratio_max: float = Field(default=0.30, ge=0.01, le=0.50, description="Maximum cationic lipid/mRNA weight ratio")
cationic_mol_min: float = Field(default=0.05, ge=0.00, le=1.00, description="Minimum cationic lipid mol ratio")
cationic_mol_max: float = Field(default=0.80, ge=0.00, le=1.00, description="Maximum cationic lipid mol ratio")
phospholipid_mol_min: float = Field(default=0.00, ge=0.00, le=1.00, description="Minimum phospholipid mol ratio")
phospholipid_mol_max: float = Field(default=0.80, ge=0.00, le=1.00, description="Maximum phospholipid mol ratio")
cholesterol_mol_min: float = Field(default=0.00, ge=0.00, le=1.00, description="Minimum cholesterol mol ratio")
cholesterol_mol_max: float = Field(default=0.80, ge=0.00, le=1.00, description="Maximum cholesterol mol ratio")
peg_mol_min: float = Field(default=0.00, ge=0.00, le=0.20, description="Minimum PEG lipid mol ratio")
peg_mol_max: float = Field(default=0.05, ge=0.00, le=0.20, description="Maximum PEG lipid mol ratio")
def to_comp_ranges(self) -> CompRanges:
"""Convert to a CompRanges object."""
return CompRanges(
weight_ratio_min=self.weight_ratio_min,
weight_ratio_max=self.weight_ratio_max,
cationic_mol_min=self.cationic_mol_min,
cationic_mol_max=self.cationic_mol_max,
phospholipid_mol_min=self.phospholipid_mol_min,
phospholipid_mol_max=self.phospholipid_mol_max,
cholesterol_mol_min=self.cholesterol_mol_min,
cholesterol_mol_max=self.cholesterol_mol_max,
peg_mol_min=self.peg_mol_min,
peg_mol_max=self.peg_mol_max,
)
class ScoringWeightsRequest(BaseModel):
"""Scoring weight configuration."""
biodist_weight: float = Field(default=1.0, ge=0.0, description="Weight for target-organ biodistribution")
delivery_weight: float = Field(default=0.0, ge=0.0, description="Weight for quantified delivery")
size_weight: float = Field(default=0.0, ge=0.0, description="Weight for particle size (60-150 nm)")
ee_class_weights: List[float] = Field(default=[0.0, 0.0, 0.0], description="EE class weights [class0, class1, class2]")
pdi_class_weights: List[float] = Field(default=[0.0, 0.0, 0.0, 0.0], description="PDI class weights [class0, class1, class2, class3]")
toxic_class_weights: List[float] = Field(default=[0.0, 0.0], description="Toxicity class weights [non-toxic, toxic]")
def to_scoring_weights(self) -> ScoringWeights:
"""Convert to a ScoringWeights object."""
return ScoringWeights(
biodist_weight=self.biodist_weight,
delivery_weight=self.delivery_weight,
size_weight=self.size_weight,
ee_class_weights=self.ee_class_weights,
pdi_class_weights=self.pdi_class_weights,
toxic_class_weights=self.toxic_class_weights,
)
class OptimizeRequest(BaseModel):
"""Optimization request."""
smiles: str = Field(..., description="Cationic lipid SMILES string")
organ: str = Field(..., description="Target organ for optimization")
top_k: int = Field(default=20, ge=1, le=100, description="Number of top formulations")
top_k: int = Field(default=20, ge=1, le=100, description="Number of top formulations to return")
num_seeds: Optional[int] = Field(default=None, ge=1, le=500, description="Number of seed points from first iteration (default: top_k * 5)")
top_per_seed: int = Field(default=1, ge=1, le=10, description="Number of local best to keep per seed in refinement")
step_sizes: Optional[List[float]] = Field(default=None, description="Step sizes for each iteration (default: [0.10, 0.02, 0.01])")
comp_ranges: Optional[CompRangesRequest] = Field(default=None, description="Component range configuration (defaults to the standard ranges)")
routes: Optional[List[str]] = Field(default=None, description="Routes of administration (default: ['intravenous', 'intramuscular'])")
scoring_weights: Optional[ScoringWeightsRequest] = Field(default=None, description="Scoring weight configuration (default: rank by biodist only)")
class Config:
json_schema_extra = {
"example": {
"smiles": "CC(C)NCCNC(C)C",
"organ": "liver",
"top_k": 20
"top_k": 20,
"num_seeds": None,
"top_per_seed": 1,
"step_sizes": None,
"comp_ranges": None,
"routes": None,
"scoring_weights": None
}
}
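
A minimal client call against this request schema (illustrative, not part of the commit; assumes the API is reachable at http://localhost:8000). Omitted optional fields fall back to their defaults:

import httpx

payload = {"smiles": "CC(C)NCCNC(C)C", "organ": "liver", "top_k": 20}
resp = httpx.post("http://localhost:8000/optimize", json=payload, timeout=600)
resp.raise_for_status()
for f in resp.json()["formulations"][:3]:
    print(f["rank"], f["target_biodist"], f.get("composite_score"))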
@ -48,6 +112,7 @@ class FormulationResult(BaseModel):
"""A single formulation result."""
rank: int
target_biodist: float
composite_score: Optional[float] = None # Composite score
cationic_lipid_to_mrna_ratio: float
cationic_lipid_mol_ratio: float
phospholipid_mol_ratio: float
@ -56,6 +121,12 @@ class FormulationResult(BaseModel):
helper_lipid: str
route: str
all_biodist: Dict[str, float]
# Additional predicted values
quantified_delivery: Optional[float] = None
size: Optional[float] = None
pdi_class: Optional[int] = None # PDI class (0: <0.2, 1: 0.2-0.3, 2: 0.3-0.4, 3: >0.4)
ee_class: Optional[int] = None # EE class (0: <50%, 1: 50-80%, 2: >80%)
toxic_class: Optional[int] = None # Toxicity class (0: non-toxic, 1: toxic)
class OptimizeResponse(BaseModel):
@ -187,25 +258,65 @@ async def optimize_formulation(request: OptimizeRequest):
if not request.smiles or len(request.smiles.strip()) == 0:
raise HTTPException(status_code=400, detail="SMILES string cannot be empty")
logger.info(f"Optimization request: organ={request.organ}, smiles={request.smiles[:50]}...")
# Validate routes
valid_routes = ["intravenous", "intramuscular"]
if request.routes is not None:
for r in request.routes:
if r not in valid_routes:
raise HTTPException(
status_code=400,
detail=f"Invalid route: {r}. Available: {valid_routes}"
)
if len(request.routes) == 0:
raise HTTPException(status_code=400, detail="At least one route must be specified")
logger.info(f"Optimization request: organ={request.organ}, routes={request.routes}, smiles={request.smiles[:50]}...")
# Build the component range configuration (validated outside the try block so invalid input returns 400 rather than 500)
comp_ranges = None
if request.comp_ranges is not None:
comp_ranges = request.comp_ranges.to_comp_ranges()
# Check that the ranges are feasible
validation_error = comp_ranges.get_validation_error()
if validation_error:
raise HTTPException(
status_code=400,
detail=f"Invalid component range configuration: {validation_error}"
)
# Build the scoring weight configuration
scoring_weights = None
if request.scoring_weights is not None:
scoring_weights = request.scoring_weights.to_scoring_weights()
try:
# Run optimization
# Run optimization (hierarchical search strategy)
results = optimize(
smiles=request.smiles,
organ=request.organ,
model=state.model,
device=state.device,
top_k=request.top_k,
num_seeds=request.num_seeds,
top_per_seed=request.top_per_seed,
step_sizes=request.step_sizes,
comp_ranges=comp_ranges,
routes=request.routes,
scoring_weights=scoring_weights,
batch_size=256,
)
# Weights used to compute the composite score
from app.optimize import compute_formulation_score, DEFAULT_SCORING_WEIGHTS
actual_scoring_weights = scoring_weights if scoring_weights is not None else DEFAULT_SCORING_WEIGHTS
# Convert results
formulations = []
for i, f in enumerate(results):
formulations.append(FormulationResult(
rank=i + 1,
target_biodist=f.get_biodist(request.organ),
composite_score=compute_formulation_score(f, request.organ, actual_scoring_weights),
cationic_lipid_to_mrna_ratio=f.cationic_lipid_to_mrna_ratio,
cationic_lipid_mol_ratio=f.cationic_lipid_mol_ratio,
phospholipid_mol_ratio=f.phospholipid_mol_ratio,
@ -217,6 +328,12 @@ async def optimize_formulation(request: OptimizeRequest):
col.replace("Biodistribution_", ""): f.biodist_predictions.get(col, 0.0)
for col in TARGET_BIODIST
},
# Additional predicted values
quantified_delivery=f.quantified_delivery,
size=f.size,
pdi_class=f.pdi_class,
ee_class=f.ee_class,
toxic_class=f.toxic_class,
))
logger.success(f"Optimization completed: {len(formulations)} formulations")

app/app.py

@ -3,9 +3,13 @@ Streamlit interactive formulation-optimization UI
Launch the app:
streamlit run app/app.py
Docker environment variables:
API_URL: API service address (default: http://localhost:8000)
"""
import io
import os
from datetime import datetime
import httpx
@ -14,7 +18,8 @@ import streamlit as st
# ============ Configuration ============
API_URL = "http://localhost:8000"
# Read the API address from an environment variable to support Docker environments
API_URL = os.environ.get("API_URL", "http://localhost:8000")
AVAILABLE_ORGANS = [
"liver",
@ -27,13 +32,23 @@ AVAILABLE_ORGANS = [
]
ORGAN_LABELS = {
"liver": "🫀 Liver",
"spleen": "🟣 Spleen",
"lung": "🫁 Lung",
"heart": "❤️ Heart",
"kidney": "🫘 Kidney",
"muscle": "💪 Muscle",
"lymph_nodes": "🔵 Lymph Nodes",
"liver": "Liver",
"spleen": "Spleen",
"lung": "Lung",
"heart": "Heart",
"kidney": "Kidney",
"muscle": "Muscle",
"lymph_nodes": "Lymph Nodes",
}
AVAILABLE_ROUTES = [
"intravenous",
"intramuscular",
]
ROUTE_LABELS = {
"intravenous": "Intravenous",
"intramuscular": "Intramuscular",
}
# ============ Page configuration ============
@ -122,43 +137,107 @@ def check_api_status() -> bool:
return False
def call_optimize_api(smiles: str, organ: str, top_k: int = 20) -> dict:
def call_optimize_api(
smiles: str,
organ: str,
top_k: int = 20,
num_seeds: int = None,
top_per_seed: int = 1,
step_sizes: list = None,
comp_ranges: dict = None,
routes: list = None,
scoring_weights: dict = None,
) -> dict:
"""Call the optimization API."""
with httpx.Client(timeout=300) as client: # 5-minute timeout
payload = {
"smiles": smiles,
"organ": organ,
"top_k": top_k,
"num_seeds": num_seeds,
"top_per_seed": top_per_seed,
"step_sizes": step_sizes,
"comp_ranges": comp_ranges,
"routes": routes,
"scoring_weights": scoring_weights,
}
with httpx.Client(timeout=600) as client: # 10-minute timeout (custom parameters can take longer)
response = client.post(
f"{API_URL}/optimize",
json={
"smiles": smiles,
"organ": organ,
"top_k": top_k,
},
json=payload,
)
response.raise_for_status()
return response.json()
def format_results_dataframe(results: dict) -> pd.DataFrame:
# PDI class labels
PDI_CLASS_LABELS = {
0: "<0.2 (excellent)",
1: "0.2-0.3 (good)",
2: "0.3-0.4 (fair)",
3: ">0.4 (poor)",
}
# EE class labels
EE_CLASS_LABELS = {
0: "<50% (low)",
1: "50-80% (medium)",
2: ">80% (high)",
}
# Toxicity class labels
TOXIC_CLASS_LABELS = {
0: "Non-toxic ✓",
1: "Toxic ⚠",
}
def format_results_dataframe(results: dict, smiles_label: str = None) -> pd.DataFrame:
"""Convert API results to a DataFrame."""
formulations = results["formulations"]
target_organ = results["target_organ"]
rows = []
for f in formulations:
row = {
row = {}
# If a SMILES label is given, add it as the first column
if smiles_label:
row["SMILES"] = smiles_label
row.update({
"Rank": f["rank"],
f"Biodist_{target_organ}": f"{f['target_biodist']:.4f}",
"Cationic lipid/mRNA": f["cationic_lipid_to_mrna_ratio"],
"Cationic lipid (mol)": f["cationic_lipid_mol_ratio"],
"Phospholipid (mol)": f["phospholipid_mol_ratio"],
"Cholesterol (mol)": f["cholesterol_mol_ratio"],
"PEG lipid (mol)": f["peg_lipid_mol_ratio"],
})
# If a composite score exists, show it right after the rank
if f.get("composite_score") is not None:
row["Composite score"] = f"{f['composite_score']:.4f}"
row.update({
f"{target_organ} biodist": f"{f['target_biodist']*100:.8f}%",
"Cationic lipid/mRNA ratio": f["cationic_lipid_to_mrna_ratio"],
"Cationic lipid mol ratio": f["cationic_lipid_mol_ratio"],
"Phospholipid mol ratio": f["phospholipid_mol_ratio"],
"Cholesterol mol ratio": f["cholesterol_mol_ratio"],
"PEG lipid mol ratio": f["peg_lipid_mol_ratio"],
"Helper lipid": f["helper_lipid"],
"Route": f["route"],
}
})
# Add additional predicted values
if f.get("quantified_delivery") is not None:
row["Quantified delivery"] = f"{f['quantified_delivery']:.4f}"
if f.get("size") is not None:
row["Size (nm)"] = f"{f['size']:.1f}"
if f.get("pdi_class") is not None:
row["PDI"] = PDI_CLASS_LABELS.get(f["pdi_class"], str(f["pdi_class"]))
if f.get("ee_class") is not None:
row["EE"] = EE_CLASS_LABELS.get(f["ee_class"], str(f["ee_class"]))
if f.get("toxic_class") is not None:
row["Toxicity"] = TOXIC_CLASS_LABELS.get(f["toxic_class"], str(f["toxic_class"]))
# Add biodist for the other organs
for organ, value in f["all_biodist"].items():
if organ != target_organ:
row[f"Biodist_{organ}"] = f"{value:.4f}"
row[f"{organ} biodist"] = f"{value*100:.2f}%"
rows.append(row)
return pd.DataFrame(rows)
@ -184,7 +263,7 @@ def main():
# ========== Sidebar ==========
with st.sidebar:
st.header("⚙️ Settings")
# st.header("⚙️ Settings")
# API status
if api_online:
@ -193,7 +272,7 @@ def main():
st.error("🔴 API service offline")
st.info("Start the API service first:\n```\nuvicorn app.api:app --port 8000\n```")
st.divider()
# st.divider()
# SMILES input
st.subheader("🔬 Molecular structure")
@ -201,23 +280,23 @@
"Enter cationic lipid SMILES",
value="",
height=100,
placeholder="e.g. CC(C)NCCNC(C)C",
help="Enter the SMILES string of the cationic lipid",
placeholder="e.g. CC(C)NCCNC(C)C\nSeparate multiple SMILES with commas: SMI1,SMI2,SMI3",
help="Enter the SMILES string of the cationic lipid. Multiple SMILES are supported, separated by commas (,)",
)
# Example SMILES
with st.expander("📋 Example SMILES"):
example_smiles = {
"DLin-MC3-DMA": "CC(C)=CCCC(C)=CCCC(C)=CCN(C)CCCCCCCCOC(=O)CCCCCCC/C=C\\CCCCCCCC",
"Simple amine": "CC(C)NCCNC(C)C",
"Long-chain amine": "CCCCCCCCCCCCNCCNCCCCCCCCCCCC",
}
for name, smi in example_smiles.items():
if st.button(f"Use {name}", key=f"example_{name}"):
st.session_state["smiles_input"] = smi
st.rerun()
# with st.expander("📋 Example SMILES"):
# example_smiles = {
# "DLin-MC3-DMA": "CC(C)=CCCC(C)=CCCC(C)=CCN(C)CCCCCCCCOC(=O)CCCCCCC/C=C\\CCCCCCCC",
# "Simple amine": "CC(C)NCCNC(C)C",
# "Long-chain amine": "CCCCCCCCCCCCNCCNCCCCCCCCCCCC",
# }
# for name, smi in example_smiles.items():
# if st.button(f"Use {name}", key=f"example_{name}"):
# st.session_state["smiles_input"] = smi
# st.rerun()
st.divider()
# st.divider()
# Target organ selection
st.subheader("🎯 Target organ")
@ -228,17 +307,226 @@ def main():
index=0,
)
st.divider()
# Route of administration selection
st.subheader("💉 Route of administration")
selected_routes = st.multiselect(
"Select routes of administration",
options=AVAILABLE_ROUTES,
default=AVAILABLE_ROUTES,
format_func=lambda x: ROUTE_LABELS.get(x, x),
help="Select the routes of administration to search; multiple selections allowed, at least one required.",
)
if not selected_routes:
st.warning("⚠️ Please select at least one route of administration")
# Advanced options
with st.expander("🔧 Advanced options"):
st.markdown("**Output settings**")
top_k = st.slider(
"Number of formulations to return",
"Number of formulations to return (top_k)",
min_value=5,
max_value=50,
max_value=100,
value=20,
step=5,
help="Number of top formulations to return",
)
st.markdown("**Search strategy**")
num_seeds = st.slider(
"Number of seed points (num_seeds)",
min_value=10,
max_value=200,
value=top_k * 5,
step=10,
help="Number of seed points kept after the first iteration; more seeds mean a broader search",
)
top_per_seed = st.slider(
"Local best per seed (top_per_seed)",
min_value=1,
max_value=5,
value=1,
step=1,
help="Number of local optima kept per seed neighborhood in later iterations",
)
st.markdown("**Iteration step sizes and rounds**")
use_custom_steps = st.checkbox(
"Custom iteration step sizes",
value=False,
help="Default step sizes are [0.10, 0.02, 0.01] (3 rounds of progressively finer search). Set a round's step size to 0 to run fewer rounds.",
)
if use_custom_steps:
col1, col2, col3 = st.columns(3)
with col1:
step1 = st.number_input(
"Round 1 step size",
min_value=0.01, max_value=0.20, value=0.10,
step=0.01, format="%.2f",
help="Round 1 is the global coarse search; its step size must be greater than 0",
)
with col2:
step2 = st.number_input(
"Round 2 step size",
min_value=0.00, max_value=0.10, value=0.02,
step=0.01, format="%.2f",
help="Set to 0 to run only 1 round",
)
with col3:
step3 = st.number_input(
"Round 3 step size",
min_value=0.00, max_value=0.05, value=0.01,
step=0.01, format="%.2f",
help="Set to 0 to run only 2 rounds",
)
# Build the actual step_sizes list from the entered values
# step2 == 0 → keep [step1] only (1 round)
# step3 == 0 → keep [step1, step2] (2 rounds)
# neither is 0 → [step1, step2, step3] (3 rounds)
if step2 == 0.0:
step_sizes = [step1]
elif step3 == 0.0:
step_sizes = [step1, step2]
else:
step_sizes = [step1, step2, step3]
# Show the actual number of rounds
st.caption(f"📌 Actual rounds: {len(step_sizes)}, step sizes: {step_sizes}")
else:
step_sizes = None # Use defaults
st.markdown("**Component range limits**")
use_custom_ranges = st.checkbox(
"Custom component ranges",
value=False,
help="Limit the range of each component; mol ratios must still sum to 100%",
)
if use_custom_ranges:
st.caption("Cationic lipid/mRNA weight ratio")
col1, col2 = st.columns(2)
with col1:
weight_ratio_min = st.number_input("Min", min_value=0.01, max_value=0.50, value=0.05, step=0.01, format="%.2f", key="wr_min")
with col2:
weight_ratio_max = st.number_input("Max", min_value=0.01, max_value=0.50, value=0.30, step=0.01, format="%.2f", key="wr_max")
st.caption("Cationic lipid mol ratio")
col1, col2 = st.columns(2)
with col1:
cationic_mol_min = st.number_input("Min", min_value=0.00, max_value=1.00, value=0.05, step=0.05, format="%.2f", key="cat_min")
with col2:
cationic_mol_max = st.number_input("Max", min_value=0.00, max_value=1.00, value=0.80, step=0.05, format="%.2f", key="cat_max")
st.caption("Phospholipid mol ratio")
col1, col2 = st.columns(2)
with col1:
phospholipid_mol_min = st.number_input("Min", min_value=0.00, max_value=1.00, value=0.00, step=0.05, format="%.2f", key="phos_min")
with col2:
phospholipid_mol_max = st.number_input("Max", min_value=0.00, max_value=1.00, value=0.80, step=0.05, format="%.2f", key="phos_max")
st.caption("Cholesterol mol ratio")
col1, col2 = st.columns(2)
with col1:
cholesterol_mol_min = st.number_input("Min", min_value=0.00, max_value=1.00, value=0.00, step=0.05, format="%.2f", key="chol_min")
with col2:
cholesterol_mol_max = st.number_input("Max", min_value=0.00, max_value=1.00, value=0.80, step=0.05, format="%.2f", key="chol_max")
st.caption("PEG lipid mol ratio")
col1, col2 = st.columns(2)
with col1:
peg_mol_min = st.number_input("Min", min_value=0.00, max_value=0.20, value=0.00, step=0.01, format="%.2f", key="peg_min")
with col2:
peg_mol_max = st.number_input("Max", min_value=0.00, max_value=0.20, value=0.05, step=0.01, format="%.2f", key="peg_max")
comp_ranges = {
"weight_ratio_min": weight_ratio_min,
"weight_ratio_max": weight_ratio_max,
"cationic_mol_min": cationic_mol_min,
"cationic_mol_max": cationic_mol_max,
"phospholipid_mol_min": phospholipid_mol_min,
"phospholipid_mol_max": phospholipid_mol_max,
"cholesterol_mol_min": cholesterol_mol_min,
"cholesterol_mol_max": cholesterol_mol_max,
"peg_mol_min": peg_mol_min,
"peg_mol_max": peg_mol_max,
}
# Basic validation
min_sum = cationic_mol_min + phospholipid_mol_min + cholesterol_mol_min + peg_mol_min
max_sum = cationic_mol_max + phospholipid_mol_max + cholesterol_mol_max + peg_mol_max
if min_sum > 1.0 or max_sum < 1.0:
st.warning("⚠️ The current ranges may not yield any valid formulation (mol ratios must sum to 100%)")
else:
comp_ranges = None # Use defaults
st.markdown("**Scoring/ranking weights**")
use_custom_scoring = st.checkbox(
"Custom scoring weights",
value=False,
help="By default, results are ranked by target-organ biodistribution only. Enable to define a multi-objective weighted score; total score = sum of the individual scores.",
)
if use_custom_scoring:
st.caption("**Regression task weights**")
sw_biodist = st.number_input(
"Organ biodistribution",
min_value=0.00, max_value=10.00, value=0.30,
step=0.05, format="%.2f", key="sw_biodist",
help="score = biodist_value × weight",
)
sw_delivery = st.number_input(
"Quantified delivery",
min_value=0.00, max_value=10.00, value=0.25,
step=0.05, format="%.2f", key="sw_delivery",
help="score = normalize(delivery, route) × weight",
)
sw_size = st.number_input(
"Particle size (60-150 nm)",
min_value=0.00, max_value=10.00, value=0.05,
step=0.05, format="%.2f", key="sw_size",
help="score = (1 if 60≤size≤150 else 0) × weight",
)
st.caption("**Encapsulation efficiency (EE) class weights**")
col1, col2, col3 = st.columns(3)
with col1:
sw_ee0 = st.number_input("<50% (low)", min_value=0.00, max_value=1.00, value=0.00, step=0.01, format="%.2f", key="sw_ee0")
with col2:
sw_ee1 = st.number_input("50-80% (medium)", min_value=0.00, max_value=1.00, value=0.02, step=0.01, format="%.2f", key="sw_ee1")
with col3:
sw_ee2 = st.number_input(">80% (high)", min_value=0.00, max_value=1.00, value=0.08, step=0.01, format="%.2f", key="sw_ee2")
st.caption("**PDI class weights**")
col1, col2, col3, col4 = st.columns(4)
with col1:
sw_pdi0 = st.number_input("<0.2 (excellent)", min_value=0.00, max_value=1.00, value=0.08, step=0.01, format="%.2f", key="sw_pdi0")
with col2:
sw_pdi1 = st.number_input("0.2-0.3 (good)", min_value=0.00, max_value=1.00, value=0.02, step=0.01, format="%.2f", key="sw_pdi1")
with col3:
sw_pdi2 = st.number_input("0.3-0.4 (fair)", min_value=0.00, max_value=1.00, value=0.00, step=0.01, format="%.2f", key="sw_pdi2")
with col4:
sw_pdi3 = st.number_input(">0.4 (poor)", min_value=0.00, max_value=1.00, value=0.00, step=0.01, format="%.2f", key="sw_pdi3")
st.caption("**Toxicity class weights**")
col1, col2 = st.columns(2)
with col1:
sw_toxic0 = st.number_input("Non-toxic", min_value=0.00, max_value=1.00, value=0.20, step=0.05, format="%.2f", key="sw_toxic0")
with col2:
sw_toxic1 = st.number_input("Toxic", min_value=0.00, max_value=1.00, value=0.00, step=0.05, format="%.2f", key="sw_toxic1")
scoring_weights = {
"biodist_weight": sw_biodist,
"delivery_weight": sw_delivery,
"size_weight": sw_size,
"ee_class_weights": [sw_ee0, sw_ee1, sw_ee2],
"pdi_class_weights": [sw_pdi0, sw_pdi1, sw_pdi2, sw_pdi3],
"toxic_class_weights": [sw_toxic0, sw_toxic1],
}
else:
scoring_weights = None # Use defaults (rank by biodist only)
st.divider()
@ -247,7 +535,7 @@ def main():
"🚀 Start formulation optimization",
type="primary",
use_container_width=True,
disabled=not api_online or not smiles_input.strip(),
disabled=not api_online or not smiles_input.strip() or not selected_routes,
)
# ========== Main content area ==========
@ -260,49 +548,125 @@ def main():
# Run optimization
if optimize_button and smiles_input.strip():
with st.spinner("🔄 Optimizing formulations, please wait..."):
try:
results = call_optimize_api(
smiles=smiles_input.strip(),
organ=selected_organ,
top_k=top_k,
)
st.session_state["results"] = results
st.session_state["results_df"] = format_results_dataframe(results)
st.session_state["smiles_used"] = smiles_input.strip()
# Parse multiple SMILES (comma-separated)
smiles_list = [s.strip() for s in smiles_input.split(",") if s.strip()]
if not smiles_list:
st.error("❌ Please enter a valid SMILES string")
else:
is_multi_smiles = len(smiles_list) > 1
all_results = []
all_dfs = []
errors = []
# Progress bar
progress_bar = st.progress(0)
status_text = st.empty()
for idx, smiles in enumerate(smiles_list):
status_text.text(f"🔄 Optimizing SMILES {idx + 1}/{len(smiles_list)}...")
progress_bar.progress((idx) / len(smiles_list))
try:
results = call_optimize_api(
smiles=smiles,
organ=selected_organ,
top_k=top_k,
num_seeds=num_seeds,
top_per_seed=top_per_seed,
step_sizes=step_sizes,
comp_ranges=comp_ranges,
routes=selected_routes,
scoring_weights=scoring_weights,
)
all_results.append({"smiles": smiles, "results": results})
# Add a SMILES label in multi-SMILES mode
smiles_label = smiles[:30] + "..." if len(smiles) > 30 else smiles
df = format_results_dataframe(results, smiles_label if is_multi_smiles else None)
all_dfs.append(df)
except httpx.HTTPStatusError as e:
try:
error_detail = e.response.json().get("detail", str(e))
except Exception:
error_detail = str(e)
errors.append(f"SMILES {idx + 1}: {error_detail}")
except httpx.RequestError as e:
errors.append(f"SMILES {idx + 1}: API connection failed - {e}")
except Exception as e:
errors.append(f"SMILES {idx + 1}: {e}")
progress_bar.progress(1.0)
status_text.empty()
progress_bar.empty()
# Show errors
for err in errors:
st.error(f"{err}")
# Save results
if all_results:
st.session_state["results"] = all_results[0]["results"] if len(all_results) == 1 else all_results
st.session_state["results_df"] = pd.concat(all_dfs, ignore_index=True) if all_dfs else None
st.session_state["smiles_used"] = smiles_list
st.session_state["organ_used"] = selected_organ
st.success("✅ Optimization complete!")
except httpx.RequestError as e:
st.error(f"❌ API request failed: {e}")
except Exception as e:
st.error(f"❌ An error occurred: {e}")
st.session_state["is_multi_smiles"] = is_multi_smiles
st.success(f"✅ Optimization complete! Successfully processed {len(all_results)}/{len(smiles_list)} SMILES")
# Display results
if st.session_state["results"] is not None:
if st.session_state["results"] is not None and st.session_state["results_df"] is not None:
results = st.session_state["results"]
df = st.session_state["results_df"]
is_multi_smiles = st.session_state.get("is_multi_smiles", False)
# Results overview
col1, col2, col3 = st.columns(3)
with col1:
st.metric(
"Target organ",
ORGAN_LABELS.get(results["target_organ"], results["target_organ"]).split(" ")[0],
)
with col2:
best_score = results["formulations"][0]["target_biodist"]
st.metric(
"Best biodistribution",
f"{best_score:.4f}",
)
with col3:
st.metric(
"Formulations returned",
len(results["formulations"]),
)
if is_multi_smiles:
# Multi-SMILES mode
col1, col2, col3 = st.columns(3)
with col1:
# Get target_organ (from the first result)
first_result = results[0]["results"] if isinstance(results, list) else results
target_organ = first_result["target_organ"]
st.metric(
"Target organ",
ORGAN_LABELS.get(target_organ, target_organ).split(" ")[0],
)
with col2:
st.metric(
"Number of SMILES",
len(results) if isinstance(results, list) else 1,
)
with col3:
st.metric(
"Total formulations",
len(df),
)
else:
# Single-SMILES mode
col1, col2, col3 = st.columns(3)
with col1:
st.metric(
"Target organ",
ORGAN_LABELS.get(results["target_organ"], results["target_organ"]).split(" ")[0],
)
with col2:
best_score = results["formulations"][0]["target_biodist"]
st.metric(
"Best biodist",
f"{best_score*100:.2f}%",
)
with col3:
st.metric(
"Formulations returned",
len(results["formulations"]),
)
st.divider()
@ -312,15 +676,26 @@ def main():
# Export button row
col_export, col_spacer = st.columns([1, 4])
with col_export:
smiles_used = st.session_state.get("smiles_used", "")
if isinstance(smiles_used, list):
smiles_used = ",".join(smiles_used)
csv_content = create_export_csv(
df,
st.session_state.get("smiles_used", ""),
smiles_used,
st.session_state.get("organ_used", ""),
)
# Get target_organ
if is_multi_smiles:
target_organ = results[0]["results"]["target_organ"] if isinstance(results, list) else results["target_organ"]
else:
target_organ = results["target_organ"]
st.download_button(
label="📥 Export CSV",
data=csv_content,
file_name=f"lnp_optimization_{results['target_organ']}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
file_name=f"lnp_optimization_{target_organ}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
mime="text/csv",
)
@ -333,61 +708,61 @@ def main():
)
# Details
with st.expander("🔍 View best formulation details"):
best = results["formulations"][0]
# with st.expander("🔍 View best formulation details"):
# best = results["formulations"][0]
col1, col2 = st.columns(2)
# col1, col2 = st.columns(2)
with col1:
st.markdown("**Formulation parameters**")
st.json({
"Cationic lipid/mRNA ratio": best["cationic_lipid_to_mrna_ratio"],
"Cationic lipid (mol%)": best["cationic_lipid_mol_ratio"],
"Phospholipid (mol%)": best["phospholipid_mol_ratio"],
"Cholesterol (mol%)": best["cholesterol_mol_ratio"],
"PEG lipid (mol%)": best["peg_lipid_mol_ratio"],
"Helper lipid": best["helper_lipid"],
"Route": best["route"],
})
# with col1:
# st.markdown("**Formulation parameters**")
# st.json({
# "Cationic lipid/mRNA ratio": best["cationic_lipid_to_mrna_ratio"],
# "Cationic lipid (mol%)": best["cationic_lipid_mol_ratio"],
# "Phospholipid (mol%)": best["phospholipid_mol_ratio"],
# "Cholesterol (mol%)": best["cholesterol_mol_ratio"],
# "PEG lipid (mol%)": best["peg_lipid_mol_ratio"],
# "Helper lipid": best["helper_lipid"],
# "Route": best["route"],
# })
with col2:
st.markdown("**Predicted biodistribution by organ**")
biodist_df = pd.DataFrame([
{"Organ": ORGAN_LABELS.get(k, k), "Biodistribution": f"{v:.4f}"}
for k, v in best["all_biodist"].items()
])
st.dataframe(biodist_df, hide_index=True, use_container_width=True)
# with col2:
# st.markdown("**Predicted biodistribution by organ**")
# biodist_df = pd.DataFrame([
# {"Organ": ORGAN_LABELS.get(k, k), "Biodistribution": f"{v:.4f}"}
# for k, v in best["all_biodist"].items()
# ])
# st.dataframe(biodist_df, hide_index=True, use_container_width=True)
else:
# Welcome message
st.info("👈 Enter a SMILES on the left, pick a target organ, then click 'Start formulation optimization'")
# Usage instructions
with st.expander("📖 Usage instructions"):
st.markdown("""
### How to use
# with st.expander("📖 Usage instructions"):
# st.markdown("""
# ### How to use
1. **Enter SMILES**: type the cationic lipid SMILES string into the input box on the left
2. **Select target organ**: choose the organ you want to target
3. **Click optimize**: the system automatically searches for the best formulation combinations
4. **View results**: the Top-20 formulations are shown on the right
5. **Export data**: click the export button to save the results as a CSV file
# 1. **Enter SMILES**: type the cationic lipid SMILES string into the input box on the left
# 2. **Select target organ**: choose the organ you want to target
# 3. **Click optimize**: the system automatically searches for the best formulation combinations
# 4. **View results**: the Top-20 formulations are shown on the right
# 5. **Export data**: click the export button to save the results as a CSV file
### Optimization parameters
# ### Optimization parameters
The system optimizes the following formulation parameters:
- **Cationic lipid/mRNA ratio**: 0.05 - 0.30
- **Cationic lipid mol ratio**: 0.05 - 0.80
- **Phospholipid mol ratio**: 0.00 - 0.80
- **Cholesterol mol ratio**: 0.00 - 0.80
- **PEG lipid mol ratio**: 0.00 - 0.05
- **Helper lipid**: DOPE / DSPC / DOTAP
- **Route of administration**: intravenous / intramuscular
# The system optimizes the following formulation parameters:
# - **Cationic lipid/mRNA ratio**: 0.05 - 0.30
# - **Cationic lipid mol ratio**: 0.05 - 0.80
# - **Phospholipid mol ratio**: 0.00 - 0.80
# - **Cholesterol mol ratio**: 0.00 - 0.80
# - **PEG lipid mol ratio**: 0.00 - 0.05
# - **Helper lipid**: DOPE / DSPC / DOTAP
# - **Route of administration**: intravenous / intramuscular
### Constraints
# ### Constraints
Sum of mol ratios = 1 (cationic lipid + phospholipid + cholesterol + PEG lipid)
""")
# Sum of mol ratios = 1 (cationic lipid + phospholipid + cholesterol + PEG lipid)
# """)
if __name__ == "__main__":

app/optimize.py

@ -41,14 +41,73 @@ app = typer.Typer()
# Available target organs
AVAILABLE_ORGANS = ["lymph_nodes", "heart", "liver", "spleen", "lung", "kidney", "muscle"]
# comp token parameter ranges
COMP_PARAM_RANGES = {
"Cationic_Lipid_to_mRNA_weight_ratio": (0.05, 0.30),
"Cationic_Lipid_Mol_Ratio": (0.05, 0.80),
"Phospholipid_Mol_Ratio": (0.00, 0.80),
"Cholesterol_Mol_Ratio": (0.00, 0.80),
"PEG_Lipid_Mol_Ratio": (0.00, 0.05),
}
@dataclass
class CompRanges:
"""Component parameter range configuration."""
# Cationic lipid/mRNA weight ratio
weight_ratio_min: float = 0.05
weight_ratio_max: float = 0.30
# Cationic lipid mol ratio
cationic_mol_min: float = 0.05
cationic_mol_max: float = 0.80
# Phospholipid mol ratio
phospholipid_mol_min: float = 0.00
phospholipid_mol_max: float = 0.80
# Cholesterol mol ratio
cholesterol_mol_min: float = 0.00
cholesterol_mol_max: float = 0.80
# PEG lipid mol ratio
peg_mol_min: float = 0.00
peg_mol_max: float = 0.05
def to_dict(self) -> Dict:
"""Convert to a dict."""
return {
"weight_ratio": (self.weight_ratio_min, self.weight_ratio_max),
"cationic_mol": (self.cationic_mol_min, self.cationic_mol_max),
"phospholipid_mol": (self.phospholipid_mol_min, self.phospholipid_mol_max),
"cholesterol_mol": (self.cholesterol_mol_min, self.cholesterol_mol_max),
"peg_mol": (self.peg_mol_min, self.peg_mol_max),
}
def get_validation_error(self) -> Optional[str]:
"""
Check that the ranges are feasible and return an error message if not.
Returns:
An error message string, or None if validation passes.
"""
# Check that each range is valid (min must not exceed max)
if self.weight_ratio_min > self.weight_ratio_max:
return f"Cationic lipid/mRNA weight ratio min ({self.weight_ratio_min}) cannot exceed max ({self.weight_ratio_max})"
if self.cationic_mol_min > self.cationic_mol_max:
return f"Cationic lipid mol ratio min ({self.cationic_mol_min}) cannot exceed max ({self.cationic_mol_max})"
if self.phospholipid_mol_min > self.phospholipid_mol_max:
return f"Phospholipid mol ratio min ({self.phospholipid_mol_min}) cannot exceed max ({self.phospholipid_mol_max})"
if self.cholesterol_mol_min > self.cholesterol_mol_max:
return f"Cholesterol mol ratio min ({self.cholesterol_mol_min}) cannot exceed max ({self.cholesterol_mol_max})"
if self.peg_mol_min > self.peg_mol_max:
return f"PEG lipid mol ratio min ({self.peg_mol_min}) cannot exceed max ({self.peg_mol_max})"
# Check whether the mol ratios can sum to 1
min_sum = self.cationic_mol_min + self.phospholipid_mol_min + self.cholesterol_mol_min + self.peg_mol_min
max_sum = self.cationic_mol_max + self.phospholipid_mol_max + self.cholesterol_mol_max + self.peg_mol_max
if min_sum > 1.0:
return f"Sum of mol ratio minimums ({min_sum:.2f}) exceeds 100%; no valid formulation is possible"
if max_sum < 1.0:
return f"Sum of mol ratio maximums ({max_sum:.2f}) is below 100%; no valid formulation is possible"
return None
def validate(self) -> bool:
"""Return True if the ranges admit at least one feasible solution."""
return self.get_validation_error() is None
# Default component ranges
DEFAULT_COMP_RANGES = CompRanges()
# Minimum step size
MIN_STEP_SIZE = 0.01
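
A quick sketch of how this validation behaves (illustrative, not part of the commit):

# Mol-ratio minimums 0.60 + 0.50 already exceed 1.0, so no feasible mix exists.
ranges = CompRanges(cationic_mol_min=0.60, phospholipid_mol_min=0.50)
assert not ranges.validate()
print(ranges.get_validation_error())  # reports that the minimums sum to 110%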
@ -56,12 +115,153 @@ MIN_STEP_SIZE = 0.01
# Iteration strategy: step_size for each iteration
ITERATION_STEP_SIZES = [0.10, 0.02, 0.01]
# Helper lipid options
HELPER_LIPID_OPTIONS = ["DOPE", "DSPC", "DOTAP"]
# Helper lipid options (DOTAP excluded)
HELPER_LIPID_OPTIONS = ["DOPE", "DSPC"]
# Route of administration options
ROUTE_OPTIONS = ["intravenous", "intramuscular"]
# quantified_delivery normalization constants (per route of administration)
DELIVERY_NORM = {
"intravenous": {"min": -0.798559291, "max": 4.497814051056962},
"intramuscular": {"min": -0.794912427, "max": 10.220042980012716},
}
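# Worked check of the constants above: an intravenous delivery prediction of 2.0
# normalizes to (2.0 - (-0.7986)) / (4.4978 - (-0.7986)) ≈ 0.53 before weighting.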
@dataclass
class ScoringWeights:
"""
Scoring weight configuration.
Total score = biodist_score + delivery_score + size_score + ee_score + pdi_score + toxic_score
See SCORE.md for how each term is computed.
The defaults rank by target-organ biodistribution only (backward compatible).
"""
# Regression task weights
biodist_weight: float = 1.0 # score = biodist_value * weight
delivery_weight: float = 0.0 # score = normalize(delivery, route) * weight
size_weight: float = 0.0 # score = (1 if 60<=size<=150 else 0) * weight
# Classification tasks: per-class weights (score = the weight of the predicted class)
ee_class_weights: List[float] = field(default_factory=lambda: [0.0, 0.0, 0.0]) # EE class 0, 1, 2
pdi_class_weights: List[float] = field(default_factory=lambda: [0.0, 0.0, 0.0, 0.0]) # PDI class 0, 1, 2, 3
toxic_class_weights: List[float] = field(default_factory=lambda: [0.0, 0.0]) # Toxic class 0, 1
# Default scoring weights (rank by biodist only)
DEFAULT_SCORING_WEIGHTS = ScoringWeights()
def compute_formulation_score(
f: 'Formulation',
organ: str,
weights: ScoringWeights,
) -> float:
"""
Compute the composite score for a single Formulation.
Args:
f: formulation object
organ: target organ
weights: scoring weights
Returns:
The composite score
"""
score = 0.0
# 1. biodistribution
score += f.get_biodist(organ) * weights.biodist_weight
# 2. quantified_delivery (normalized to [0, 1] per route of administration)
if f.quantified_delivery is not None and weights.delivery_weight != 0:
norm = DELIVERY_NORM.get(f.route, DELIVERY_NORM["intravenous"])
d_range = norm["max"] - norm["min"]
if d_range > 0:
delivery_normalized = (f.quantified_delivery - norm["min"]) / d_range
delivery_normalized = max(0.0, min(1.0, delivery_normalized))
else:
delivery_normalized = 0.0
score += delivery_normalized * weights.delivery_weight
# 3. size (60-150 nm is the ideal range)
if f.size is not None and weights.size_weight != 0:
if 60 <= f.size <= 150:
score += 1.0 * weights.size_weight
# 4. EE class
if f.ee_class is not None and 0 <= f.ee_class < len(weights.ee_class_weights):
score += weights.ee_class_weights[f.ee_class]
# 5. PDI class
if f.pdi_class is not None and 0 <= f.pdi_class < len(weights.pdi_class_weights):
score += weights.pdi_class_weights[f.pdi_class]
# 6. Toxicity class
if f.toxic_class is not None and 0 <= f.toxic_class < len(weights.toxic_class_weights):
score += weights.toxic_class_weights[f.toxic_class]
return score
def compute_df_score(
df: pd.DataFrame,
organ: str,
weights: ScoringWeights,
) -> pd.Series:
"""
Compute the composite score for every row of a DataFrame.
Args:
df: DataFrame containing prediction results
organ: target organ
weights: scoring weights
Returns:
A Series of scores
"""
score = pd.Series(0.0, index=df.index)
# 1. biodistribution
pred_col = f"pred_Biodistribution_{organ}"
if pred_col in df.columns:
score += df[pred_col] * weights.biodist_weight
# 2. quantified_delivery (normalized per route of administration)
if weights.delivery_weight != 0 and "pred_delivery" in df.columns:
for route_name, norm in DELIVERY_NORM.items():
mask = df["_route"] == route_name
if mask.any():
d_range = norm["max"] - norm["min"]
if d_range > 0:
delivery_normalized = (df.loc[mask, "pred_delivery"] - norm["min"]) / d_range
delivery_normalized = delivery_normalized.clip(0.0, 1.0)
score.loc[mask] += delivery_normalized * weights.delivery_weight
# 3. size
if weights.size_weight != 0 and "pred_size" in df.columns:
size_ok = (df["pred_size"] >= 60) & (df["pred_size"] <= 150)
score += size_ok.astype(float) * weights.size_weight
# 4. EE class
if "pred_ee_class" in df.columns:
for cls, w in enumerate(weights.ee_class_weights):
if w != 0:
score += (df["pred_ee_class"] == cls).astype(float) * w
# 5. PDI class
if "pred_pdi_class" in df.columns:
for cls, w in enumerate(weights.pdi_class_weights):
if w != 0:
score += (df["pred_pdi_class"] == cls).astype(float) * w
# 6. Toxicity class
if "pred_toxic_class" in df.columns:
for cls, w in enumerate(weights.toxic_class_weights):
if w != 0:
score += (df["pred_toxic_class"] == cls).astype(float) * w
return score
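# Usage sketch (illustrative): rank a prediction DataFrame by composite score,
# mirroring what select_top_k does below.
#   w = ScoringWeights(biodist_weight=0.3, delivery_weight=0.25, size_weight=0.05)
#   df["_composite_score"] = compute_df_score(df, "liver", w)
#   top20 = df.sort_values("_composite_score", ascending=False).head(20)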
@dataclass
class Formulation:
@ -77,6 +277,12 @@ class Formulation:
route: str = "intravenous"
# Prediction results (filled in after prediction)
biodist_predictions: Dict[str, float] = field(default_factory=dict)
# Additional predicted values
quantified_delivery: Optional[float] = None
size: Optional[float] = None
pdi_class: Optional[int] = None # PDI class (0-3)
ee_class: Optional[int] = None # EE class (0-2)
toxic_class: Optional[int] = None # Toxicity class (0: non-toxic, 1: toxic)
def to_dict(self) -> Dict:
"""Convert to a dict."""
@ -94,6 +300,18 @@
"""Get the predicted biodistribution for the given organ."""
col = f"Biodistribution_{organ}"
return self.biodist_predictions.get(col, 0.0)
def unique_key(self) -> tuple:
"""Generate a unique key used for deduplication."""
return (
round(self.cationic_lipid_to_mrna_ratio, 4),
round(self.cationic_lipid_mol_ratio, 4),
round(self.phospholipid_mol_ratio, 4),
round(self.cholesterol_mol_ratio, 4),
round(self.peg_lipid_mol_ratio, 4),
self.helper_lipid,
self.route,
)
def generate_grid_values(
@ -124,20 +342,38 @@ def generate_grid_values(
return sorted(set(values))
def generate_initial_grid(step_size: float) -> List[Tuple[float, float, float, float, float]]:
def generate_initial_grid(
step_size: float,
comp_ranges: CompRanges = None,
) -> List[Tuple[float, float, float, float, float]]:
"""
Generate the initial search grid, subject to the constraint that the mol ratios sum to 1.
Args:
step_size: search step size
comp_ranges: component range configuration (defaults to DEFAULT_COMP_RANGES)
Returns:
List of (cationic_ratio, cationic_mol, phospholipid_mol, cholesterol_mol, peg_mol)
"""
if comp_ranges is None:
comp_ranges = DEFAULT_COMP_RANGES
grid = []
# Cationic_Lipid_to_mRNA_weight_ratio
weight_ratios = np.arange(0.05, 0.31, step_size)
weight_ratios = np.arange(
comp_ranges.weight_ratio_min,
comp_ranges.weight_ratio_max + 0.001,
step_size
)
# PEG: handled separately; its range is small
peg_values = np.arange(0.00, 0.06, MIN_STEP_SIZE) # PEG always uses a 0.01 step
# PEG: handled separately; small range, always the minimum step size
peg_values = np.arange(
comp_ranges.peg_mol_min,
comp_ranges.peg_mol_max + 0.001,
MIN_STEP_SIZE
)
# The other three mol ratios must sum to 1 - PEG
mol_step = step_size
@ -146,11 +382,13 @@ def generate_initial_grid(step_size: float) -> List[Tuple[float, float, float, f
for peg in peg_values:
remaining = 1.0 - peg
# Generate combinations satisfying the constraints
for cationic_mol in np.arange(0.05, min(0.81, remaining + 0.001), mol_step):
for phospholipid_mol in np.arange(0.00, min(0.81, remaining - cationic_mol + 0.001), mol_step):
cationic_max = min(comp_ranges.cationic_mol_max, remaining) + 0.001
for cationic_mol in np.arange(comp_ranges.cationic_mol_min, cationic_max, mol_step):
phospholipid_max = min(comp_ranges.phospholipid_mol_max, remaining - cationic_mol) + 0.001
for phospholipid_mol in np.arange(comp_ranges.phospholipid_mol_min, phospholipid_max, mol_step):
cholesterol_mol = remaining - cationic_mol - phospholipid_mol
# Check constraints
if 0.00 <= cholesterol_mol <= 0.80:
if (comp_ranges.cholesterol_mol_min <= cholesterol_mol <= comp_ranges.cholesterol_mol_max):
grid.append((
round(weight_ratio, 4),
round(cationic_mol, 4),
@ -166,6 +404,7 @@ def generate_refined_grid(
seeds: List[Formulation],
step_size: float,
radius: int = 2,
comp_ranges: CompRanges = None,
) -> List[Tuple[float, float, float, float, float]]:
"""
Generate a refined grid around the seed points.
@ -174,29 +413,37 @@ def generate_refined_grid(
seeds: list of seed formulations
step_size: step size
radius: expansion radius
comp_ranges: component range configuration (defaults to DEFAULT_COMP_RANGES)
Returns:
List of new grid points
"""
if comp_ranges is None:
comp_ranges = DEFAULT_COMP_RANGES
grid_set = set()
for seed in seeds:
# Weight ratio
weight_ratios = generate_grid_values(
seed.cationic_lipid_to_mrna_ratio, step_size, 0.05, 0.30, radius
seed.cationic_lipid_to_mrna_ratio, step_size,
comp_ranges.weight_ratio_min, comp_ranges.weight_ratio_max, radius
)
# PEG (always the minimum step size)
peg_values = generate_grid_values(
seed.peg_lipid_mol_ratio, MIN_STEP_SIZE, 0.00, 0.05, radius
seed.peg_lipid_mol_ratio, MIN_STEP_SIZE,
comp_ranges.peg_mol_min, comp_ranges.peg_mol_max, radius
)
# Mol ratios
cationic_mols = generate_grid_values(
seed.cationic_lipid_mol_ratio, step_size, 0.05, 0.80, radius
seed.cationic_lipid_mol_ratio, step_size,
comp_ranges.cationic_mol_min, comp_ranges.cationic_mol_max, radius
)
phospholipid_mols = generate_grid_values(
seed.phospholipid_mol_ratio, step_size, 0.00, 0.80, radius
seed.phospholipid_mol_ratio, step_size,
comp_ranges.phospholipid_mol_min, comp_ranges.phospholipid_mol_max, radius
)
for weight_ratio in weight_ratios:
@ -206,10 +453,10 @@ def generate_refined_grid(
for phospholipid_mol in phospholipid_mols:
cholesterol_mol = remaining - cationic_mol - phospholipid_mol
# Check constraints
if (0.05 <= cationic_mol <= 0.80 and
0.00 <= phospholipid_mol <= 0.80 and
0.00 <= cholesterol_mol <= 0.80 and
0.00 <= peg <= 0.05):
if (comp_ranges.cationic_mol_min <= cationic_mol <= comp_ranges.cationic_mol_max and
comp_ranges.phospholipid_mol_min <= phospholipid_mol <= comp_ranges.phospholipid_mol_max and
comp_ranges.cholesterol_mol_min <= cholesterol_mol <= comp_ranges.cholesterol_mol_max and
comp_ranges.peg_mol_min <= peg <= comp_ranges.peg_mol_max):
grid_set.add((
round(weight_ratio, 4),
round(cationic_mol, 4),
@ -283,14 +530,14 @@ def create_dataframe_from_formulations(
return pd.DataFrame(rows)
def predict_biodist(
def predict_all(
model: torch.nn.Module,
df: pd.DataFrame,
device: torch.device,
batch_size: int = 256,
) -> pd.DataFrame:
"""
Predict biodistribution with the model
Predict all outputs with the model (biodistribution, size, delivery, PDI, EE, toxicity)
Returns:
The DataFrame with prediction columns added
@ -301,6 +548,11 @@ def predict_biodist(
)
all_biodist_preds = []
all_size_preds = []
all_delivery_preds = []
all_pdi_preds = []
all_ee_preds = []
all_toxic_preds = []
with torch.no_grad():
for batch in dataloader:
@ -312,20 +564,53 @@ def predict_biodist(
# the biodist output is a softmax probability distribution [B, 7]
biodist_pred = outputs["biodist"].cpu().numpy()
all_biodist_preds.append(biodist_pred)
# size and delivery are regression outputs
all_size_preds.append(outputs["size"].squeeze(-1).cpu().numpy())
all_delivery_preds.append(outputs["delivery"].squeeze(-1).cpu().numpy())
# PDI, EE, and toxic are classifications; take the argmax
all_pdi_preds.append(outputs["pdi"].argmax(dim=-1).cpu().numpy())
all_ee_preds.append(outputs["ee"].argmax(dim=-1).cpu().numpy())
all_toxic_preds.append(outputs["toxic"].argmax(dim=-1).cpu().numpy())
biodist_preds = np.concatenate(all_biodist_preds, axis=0)
size_preds = np.concatenate(all_size_preds, axis=0)
delivery_preds = np.concatenate(all_delivery_preds, axis=0)
pdi_preds = np.concatenate(all_pdi_preds, axis=0)
ee_preds = np.concatenate(all_ee_preds, axis=0)
toxic_preds = np.concatenate(all_toxic_preds, axis=0)
# Add to the DataFrame
for i, col in enumerate(TARGET_BIODIST):
df[f"pred_{col}"] = biodist_preds[:, i]
# the model predicts log(size); convert back to real particle size (nm)
df["pred_size"] = np.exp(size_preds)
df["pred_delivery"] = delivery_preds
df["pred_pdi_class"] = pdi_preds
df["pred_ee_class"] = ee_preds
df["pred_toxic_class"] = toxic_preds
return df
# Kept for backward compatibility
def predict_biodist(
model: torch.nn.Module,
df: pd.DataFrame,
device: torch.device,
batch_size: int = 256,
) -> pd.DataFrame:
"""Backward-compatible alias."""
return predict_all(model, df, device, batch_size)
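# Assumed model output contract, as consumed by predict_all above (per batch of size B):
#   outputs["biodist"]  -> [B, 7] softmax probabilities over TARGET_BIODIST organs
#   outputs["size"]     -> [B, 1] regression of log(size in nm)
#   outputs["delivery"] -> [B, 1] regression of quantified delivery
#   outputs["pdi"] [B, 4], outputs["ee"] [B, 3], outputs["toxic"] [B, 2] class logits (argmax taken)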
def select_top_k(
df: pd.DataFrame,
organ: str,
k: int = 20,
scoring_weights: Optional[ScoringWeights] = None,
) -> List[Formulation]:
"""
Select the top-k formulations.
@ -334,16 +619,18 @@ def select_top_k(
df: DataFrame containing prediction results
organ: target organ
k: number of formulations to select
scoring_weights: scoring weights (default: rank by biodist only)
Returns:
List of top-k formulations
"""
pred_col = f"pred_Biodistribution_{organ}"
if pred_col not in df.columns:
raise ValueError(f"Prediction column {pred_col} not found")
if scoring_weights is None:
scoring_weights = DEFAULT_SCORING_WEIGHTS
# Sort and deduplicate
df_sorted = df.sort_values(pred_col, ascending=False)
# Compute composite scores and sort
df = df.copy()
df["_composite_score"] = compute_df_score(df, organ, scoring_weights)
df_sorted = df.sort_values("_composite_score", ascending=False)
# Create formulation objects
formulations = []
@ -373,6 +660,12 @@ def select_top_k(
biodist_predictions={
col: row[f"pred_{col}"] for col in TARGET_BIODIST
},
# Additional predicted values
quantified_delivery=row.get("pred_delivery"),
size=row.get("pred_size"),
pdi_class=int(row.get("pred_pdi_class")) if row.get("pred_pdi_class") is not None else None,
ee_class=int(row.get("pred_ee_class")) if row.get("pred_ee_class") is not None else None,
toxic_class=int(row.get("pred_toxic_class")) if row.get("pred_toxic_class") is not None else None,
)
formulations.append(formulation)
@ -382,72 +675,242 @@ def select_top_k(
return formulations
def generate_single_seed_grid(
seed: Formulation,
step_size: float,
radius: int = 2,
comp_ranges: CompRanges = None,
) -> List[Tuple[float, float, float, float, float]]:
"""
Generate a neighborhood grid for a single seed point.
Args:
seed: seed formulation
step_size: step size
radius: expansion radius
comp_ranges: component range configuration (defaults to DEFAULT_COMP_RANGES)
Returns:
List of grid points
"""
if comp_ranges is None:
comp_ranges = DEFAULT_COMP_RANGES
grid_set = set()
# Weight ratio
weight_ratios = generate_grid_values(
seed.cationic_lipid_to_mrna_ratio, step_size,
comp_ranges.weight_ratio_min, comp_ranges.weight_ratio_max, radius
)
# PEG (always the minimum step size)
peg_values = generate_grid_values(
seed.peg_lipid_mol_ratio, MIN_STEP_SIZE,
comp_ranges.peg_mol_min, comp_ranges.peg_mol_max, radius
)
# Mol ratios
cationic_mols = generate_grid_values(
seed.cationic_lipid_mol_ratio, step_size,
comp_ranges.cationic_mol_min, comp_ranges.cationic_mol_max, radius
)
phospholipid_mols = generate_grid_values(
seed.phospholipid_mol_ratio, step_size,
comp_ranges.phospholipid_mol_min, comp_ranges.phospholipid_mol_max, radius
)
for weight_ratio in weight_ratios:
for peg in peg_values:
remaining = 1.0 - peg
for cationic_mol in cationic_mols:
for phospholipid_mol in phospholipid_mols:
cholesterol_mol = remaining - cationic_mol - phospholipid_mol
# Check constraints
if (comp_ranges.cationic_mol_min <= cationic_mol <= comp_ranges.cationic_mol_max and
comp_ranges.phospholipid_mol_min <= phospholipid_mol <= comp_ranges.phospholipid_mol_max and
comp_ranges.cholesterol_mol_min <= cholesterol_mol <= comp_ranges.cholesterol_mol_max and
comp_ranges.peg_mol_min <= peg <= comp_ranges.peg_mol_max):
grid_set.add((
round(weight_ratio, 4),
round(cationic_mol, 4),
round(phospholipid_mol, 4),
round(cholesterol_mol, 4),
round(peg, 4),
))
return list(grid_set)
def optimize(
smiles: str,
organ: str,
model: torch.nn.Module,
device: torch.device,
top_k: int = 20,
num_seeds: Optional[int] = None,
top_per_seed: int = 1,
step_sizes: Optional[List[float]] = None,
comp_ranges: Optional[CompRanges] = None,
routes: Optional[List[str]] = None,
scoring_weights: Optional[ScoringWeights] = None,
batch_size: int = 256,
) -> List[Formulation]:
"""
Run formulation optimization
Run formulation optimization (hierarchical search strategy)
The hierarchical search strategy:
1. First iteration: global sparse search; keep the top num_seeds dispersed seed points
2. Later iterations: search each seed's neighborhood separately, keeping top_per_seed local optima per seed
3. This preserves search diversity and keeps results from collapsing into a single region
Args:
smiles: SMILES string
organ: target organ
model: trained model
device: compute device
top_k: number of top formulations kept per round
top_k: number of top formulations to return
num_seeds: number of seed points kept after the first iteration (default: top_k * 5)
top_per_seed: number of local optima kept per seed neighborhood
step_sizes: step sizes per iteration (default: [0.10, 0.02, 0.01])
comp_ranges: component range configuration (defaults to DEFAULT_COMP_RANGES)
routes: routes of administration (defaults to ROUTE_OPTIONS)
scoring_weights: scoring weight configuration (default: rank by biodist only)
batch_size: prediction batch size
Returns:
Final list of top-k formulations
"""
# Default num_seeds is top_k * 5
if num_seeds is None:
num_seeds = top_k * 5
# Default step sizes
if step_sizes is None:
step_sizes = ITERATION_STEP_SIZES
# Default component ranges
if comp_ranges is None:
comp_ranges = DEFAULT_COMP_RANGES
# Default routes of administration
if routes is None:
routes = ROUTE_OPTIONS
# Default scoring weights
if scoring_weights is None:
scoring_weights = DEFAULT_SCORING_WEIGHTS
# Scoring function (used to sort Formulation objects)
def _score(f: Formulation) -> float:
return compute_formulation_score(f, organ, scoring_weights)
logger.info(f"Starting optimization for organ: {organ}")
logger.info(f"SMILES: {smiles}")
logger.info(f"Strategy: num_seeds={num_seeds}, top_per_seed={top_per_seed}, top_k={top_k}")
logger.info(f"Step sizes: {step_sizes}")
logger.info(f"Routes: {routes}")
logger.info(f"Scoring weights: biodist={scoring_weights.biodist_weight}, delivery={scoring_weights.delivery_weight}, size={scoring_weights.size_weight}")
logger.info(f"Comp ranges: {comp_ranges.to_dict()}")
seeds = None
for iteration, step_size in enumerate(ITERATION_STEP_SIZES):
for iteration, step_size in enumerate(step_sizes):
logger.info(f"\n{'='*60}")
logger.info(f"Iteration {iteration + 1}/{len(ITERATION_STEP_SIZES)}, step_size={step_size}")
logger.info(f"Iteration {iteration + 1}/{len(step_sizes)}, step_size={step_size}")
logger.info(f"{'='*60}")
# Generate the grid
if seeds is None:
# First iteration: generate the full initial grid
logger.info("Generating initial grid...")
grid = generate_initial_grid(step_size)
# ==================== First iteration: global sparse search ====================
logger.info("Generating initial grid (global sparse search)...")
grid = generate_initial_grid(step_size, comp_ranges)
logger.info(f"Grid size: {len(grid)} comp combinations")
# Expand across all helper lipid and route combinations
total_combinations = len(grid) * len(HELPER_LIPID_OPTIONS) * len(routes)
logger.info(f"Total combinations: {total_combinations}")
# Build the DataFrame
df = create_dataframe_from_formulations(
smiles, grid, HELPER_LIPID_OPTIONS, routes
)
# Predict
logger.info("Running predictions...")
df = predict_biodist(model, df, device, batch_size)
# Select the top num_seeds seed points
seeds = select_top_k(df, organ, num_seeds, scoring_weights)
logger.info(f"Selected {len(seeds)} seeds for next iteration")
else:
# Later iterations: refine around the seed points
logger.info(f"Generating refined grid around {len(seeds)} seeds...")
grid = generate_refined_grid(seeds, step_size, radius=2)
logger.info(f"Grid size: {len(grid)} comp combinations")
# Expand across all helper lipid and route combinations
total_combinations = len(grid) * len(HELPER_LIPID_OPTIONS) * len(ROUTE_OPTIONS)
logger.info(f"Total combinations: {total_combinations}")
# Build the DataFrame
df = create_dataframe_from_formulations(
smiles, grid, HELPER_LIPID_OPTIONS, ROUTE_OPTIONS
)
# Predict
logger.info("Running predictions...")
df = predict_biodist(model, df, device, batch_size)
# Select top-k
seeds = select_top_k(df, organ, top_k)
# ==================== Later iterations: hierarchical local search ====================
# Search each seed separately, keeping its own local optima
logger.info(f"Hierarchical local search around {len(seeds)} seeds...")
all_local_best = []
for seed_idx, seed in enumerate(seeds):
# Generate the neighborhood grid for the current seed
local_grid = generate_single_seed_grid(seed, step_size, radius=2, comp_ranges=comp_ranges)
if len(local_grid) == 0:
# If there are no new grid points, keep the original seed
all_local_best.append(seed)
continue
# Build the DataFrame
df = create_dataframe_from_formulations(
smiles, local_grid, [seed.helper_lipid], [seed.route]
)
# Predict
df = predict_biodist(model, df, device, batch_size)
# Keep the top top_per_seed local optima within this seed's neighborhood
local_top = select_top_k(df, organ, top_per_seed, scoring_weights)
all_local_best.extend(local_top)
if seed_idx == 0 or (seed_idx + 1) % 5 == 0:
logger.info(f" Seed {seed_idx + 1}/{len(seeds)}: local grid size={len(local_grid)}, "
f"local best score={_score(local_top[0]):.4f}")
# Update seeds to all local optima (deduplicated)
seen_keys = set()
unique_local_best = []
# Sort by composite score first so the best are kept
all_local_best_sorted = sorted(all_local_best, key=_score, reverse=True)
for f in all_local_best_sorted:
key = f.unique_key()
if key not in seen_keys:
seen_keys.add(key)
unique_local_best.append(f)
seeds = unique_local_best
logger.info(f"Collected {len(seeds)} unique local best formulations (from {len(all_local_best)} candidates)")
# Show the current best
best = seeds[0]
logger.info(f"Current best Biodistribution_{organ}: {best.get_biodist(organ):.4f}")
best = max(seeds, key=_score)
logger.info(f"Current best score: {_score(best):.4f} (biodist_{organ}={best.get_biodist(organ):.4f})")
logger.info(f"Best formulation: {best.to_dict()}")
return seeds
# Final dedup, sort by composite score, and return top_k
seeds_sorted = sorted(seeds, key=_score, reverse=True)
# Dedup: keep the highest-scoring instance of each unique formulation (already sorted, so the first occurrence wins)
seen_keys = set()
unique_results = []
for f in seeds_sorted:
key = f.unique_key()
if key not in seen_keys:
seen_keys.add(key)
unique_results.append(f)
logger.info(f"Final results: {len(unique_results)} unique formulations (from {len(seeds)} candidates)")
return unique_results[:top_k]
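# End-to-end usage sketch (illustrative; the model path and device are assumptions):
#   device = torch.device("cpu")
#   model = load_model("models/final/model.pt", device)
#   top = optimize(smiles="CC(C)NCCNC(C)C", organ="liver", model=model,
#                  device=device, top_k=10)
#   print(format_results(top, "liver").head())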
def format_results(formulations: List[Formulation], organ: str) -> pd.DataFrame:
@ -481,33 +944,54 @@ def main(
help="Output CSV path (optional)"
),
top_k: int = typer.Option(20, "--top-k", "-k", help="Number of top formulations to return"),
num_seeds: Optional[int] = typer.Option(None, "--num-seeds", "-n", help="Number of seed points from first iteration (default: top_k * 5)"),
top_per_seed: int = typer.Option(1, "--top-per-seed", "-t", help="Number of local best to keep per seed"),
step_sizes: Optional[str] = typer.Option(None, "--step-sizes", "-S", help="Comma-separated step sizes (e.g., '0.10,0.02,0.01')"),
batch_size: int = typer.Option(256, "--batch-size", "-b", help="Prediction batch size"),
device: str = typer.Option("cuda" if torch.cuda.is_available() else "cpu", "--device", "-d", help="Device"),
):
"""
Formulation optimization: find the optimal LNP formulations that maximize the target organ's biodistribution.
The hierarchical search strategy:
1. First iteration: global sparse search; keep the top num_seeds dispersed seed points
2. Later iterations: search each seed's neighborhood separately, keeping top_per_seed local optima per seed
3. This preserves search diversity and keeps results from collapsing into a single region
Examples:
python -m app.optimize --smiles "CC(C)..." --organ liver
python -m app.optimize -s "CC(C)..." -o spleen -k 10
python -m app.optimize -s "CC(C)..." -o spleen -k 10 -n 30 -t 2
python -m app.optimize -s "CC(C)..." -o liver -S "0.10,0.05,0.02"
"""
# Validate the organ
if organ not in AVAILABLE_ORGANS:
logger.error(f"Invalid organ: {organ}. Available: {AVAILABLE_ORGANS}")
raise typer.Exit(1)
# Parse step sizes
parsed_step_sizes = None
if step_sizes:
try:
parsed_step_sizes = [float(s.strip()) for s in step_sizes.split(",")]
except ValueError:
logger.error(f"Invalid step sizes format: {step_sizes}")
raise typer.Exit(1)
# Load the model
logger.info(f"Loading model from {model_path}")
device = torch.device(device)
model = load_model(model_path, device)
# Run optimization
# Run optimization (hierarchical search strategy)
results = optimize(
smiles=smiles,
organ=organ,
model=model,
device=device,
top_k=top_k,
num_seeds=num_seeds,
top_per_seed=top_per_seed,
step_sizes=parsed_step_sizes,
batch_size=batch_size,
)

Binary file not shown.

docker-compose-gpu.yml Normal file

@ -0,0 +1,54 @@
services:
# FastAPI backend service
api:
build:
context: .
dockerfile: Dockerfile
target: api
container_name: lnp-api
environment:
- MODEL_PATH=/app/models/final/model.pt
volumes:
# Mount the model directory so models can be updated
- ./models/final:/app/models/final:ro
- ./models/mpnn:/app/models/mpnn:ro
restart: unless-stopped
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/"]
interval: 30s
timeout: 10s
retries: 3
start_period: 60s
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
# Streamlit frontend service
streamlit:
build:
context: .
dockerfile: Dockerfile
target: streamlit
container_name: lnp-streamlit
ports:
- "8501:8501"
environment:
- API_URL=http://api:8000
depends_on:
api:
condition: service_started
restart: unless-stopped
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8501/_stcore/health"]
interval: 30s
timeout: 10s
retries: 3
start_period: 30s
networks:
default:
name: lnp-network
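Usage note: the GPU variant reserves one NVIDIA device via deploy.resources.reservations, which typically requires NVIDIA drivers plus the NVIDIA Container Toolkit on the host; a typical invocation is `docker compose -f docker-compose-gpu.yml up -d --build`, though the exact flags depend on your Compose setup.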

49
docker-compose.yml Normal file
View File

@ -0,0 +1,49 @@
services:
# FastAPI backend service
api:
build:
context: .
dockerfile: Dockerfile
target: api
container_name: lnp-api
environment:
- MODEL_PATH=/app/models/final/model.pt
- DEVICE=cpu
volumes:
# Mount the model directories so models can be updated
- ./models/final:/app/models/final:ro
- ./models/mpnn:/app/models/mpnn:ro
restart: unless-stopped
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/"]
interval: 30s
timeout: 10s
retries: 3
start_period: 60s
# Streamlit frontend service
streamlit:
build:
context: .
dockerfile: Dockerfile
target: streamlit
container_name: lnp-streamlit
ports:
- "8501:8501"
environment:
- API_URL=http://api:8000
depends_on:
api:
condition: service_started
restart: unless-stopped
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8501/_stcore/health"]
interval: 30s
timeout: 10s
retries: 3
start_period: 30s
networks:
default:
name: lnp-network
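Note: in both compose files only Streamlit publishes a host port (8501); the api container is reachable solely on the internal lnp-network, where Streamlit addresses it through API_URL=http://api:8000. If the API needs to be reachable from the host, a ports mapping such as "8000:8000" would have to be added to the api service.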

612
lnp_ml/modeling/final_train_optuna_cv.py Normal file
View File

@ -0,0 +1,612 @@
"""
最终训练脚本3-fold Optuna 调参 + 全量数据训练
1. 使用全量数据做 3-fold StratifiedKFold Optuna 超参搜索
2. 固定最优超参后使用全量数据训练不使用 early-stopping
3. 使用 CosineAnnealingLR + 可选 SWA 防止过拟合
使用方法:
python -m lnp_ml.modeling.final_train_optuna_cv
或通过 Makefile:
make final_optuna DEVICE=cuda
"""
import json
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from loguru import logger
import typer
try:
import optuna
from optuna.samplers import TPESampler
except ImportError:
optuna = None
TPESampler = None
from lnp_ml.config import MODELS_DIR, INTERIM_DATA_DIR
from lnp_ml.dataset import (
LNPDataset,
collate_fn,
process_dataframe,
TARGET_CLASSIFICATION_PDI,
TARGET_CLASSIFICATION_EE,
TARGET_TOXIC,
)
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
from lnp_ml.modeling.trainer_balanced import (
ClassWeights,
LossWeightsBalanced,
compute_class_weights_from_loader,
train_with_early_stopping,
train_fixed_epochs,
)
from lnp_ml.modeling.visualization import plot_multitask_loss_curves
# Default MPNN ensemble path
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
app = typer.Typer()
# ============ CompositeStrata: composite stratification labels (reuses the nested_cv_optuna logic) ============
def build_composite_strata(
df: pd.DataFrame,
min_stratum_count: int = 5,
) -> Tuple[np.ndarray, Dict]:
"""
构建复合分层标签toxic × PDI × EE
Args:
df: 处理后的 DataFrame
min_stratum_count: 每个 stratum 最少样本数低于此值合并为 RARE
Returns:
(strata_array, strata_info)
"""
n = len(df)
strata_labels = []
for i in range(n):
# Toxic stratum
if TARGET_TOXIC in df.columns:
toxic_val = df[TARGET_TOXIC].iloc[i]
if pd.notna(toxic_val) and toxic_val >= 0:
toxic_str = str(int(toxic_val))
else:
toxic_str = "NA"
else:
toxic_str = "NA"
# PDI stratum
if all(col in df.columns for col in TARGET_CLASSIFICATION_PDI):
pdi_vals = df[TARGET_CLASSIFICATION_PDI].iloc[i].values
if pdi_vals.sum() > 0:
pdi_str = str(int(np.argmax(pdi_vals)))
else:
pdi_str = "NA"
else:
pdi_str = "NA"
# EE stratum
if all(col in df.columns for col in TARGET_CLASSIFICATION_EE):
ee_vals = df[TARGET_CLASSIFICATION_EE].iloc[i].values
if ee_vals.sum() > 0:
ee_str = str(int(np.argmax(ee_vals)))
else:
ee_str = "NA"
else:
ee_str = "NA"
strata_labels.append(f"T{toxic_str}|P{pdi_str}|E{ee_str}")
# Count samples per stratum
unique_strata, counts = np.unique(strata_labels, return_counts=True)
strata_counts = dict(zip(unique_strata, counts))
# Merge sparse strata into RARE
rare_strata = [s for s, c in strata_counts.items() if c < min_stratum_count]
final_labels = []
for label in strata_labels:
if label in rare_strata:
final_labels.append("RARE")
else:
final_labels.append(label)
# Encode as integers
unique_final, encoded = np.unique(final_labels, return_inverse=True)
strata_info = {
"original_strata_counts": strata_counts,
"rare_strata": rare_strata,
"final_strata": list(unique_final),
"final_strata_counts": dict(zip(*np.unique(final_labels, return_counts=True))),
"n_rare_merged": sum(strata_counts[s] for s in rare_strata) if rare_strata else 0,
}
logger.info(f"Built composite strata: {len(unique_final)} unique strata")
logger.info(f" Rare strata merged: {len(rare_strata)} types, {strata_info['n_rare_merged']} samples")
return encoded.astype(np.int64), strata_info
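# Example (illustrative): a sample with toxic=1, PDI argmax=2 and no EE label
# gets the stratum string "T1|P2|ENA"; any stratum with fewer than
# min_stratum_count samples is collapsed into the shared "RARE" stratum
# before the labels are integer-encoded for StratifiedKFold.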
# ============ Model creation ============
def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
"""自动查找 MPNN ensemble 的 model.pt 文件。"""
model_paths = sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt"))
if not model_paths:
raise FileNotFoundError(f"No model.pt files found in {base_dir}")
return [str(p) for p in model_paths]
def create_model(
d_model: int = 256,
num_heads: int = 8,
n_attn_layers: int = 4,
fusion_strategy: str = "attention",
head_hidden_dim: int = 128,
dropout: float = 0.1,
use_mpnn: bool = False,
mpnn_device: str = "cpu",
) -> Union[LNPModel, LNPModelWithoutMPNN]:
"""创建模型"""
if use_mpnn:
ensemble_paths = find_mpnn_ensemble_paths()
return LNPModel(
d_model=d_model,
num_heads=num_heads,
n_attn_layers=n_attn_layers,
fusion_strategy=fusion_strategy,
head_hidden_dim=head_hidden_dim,
dropout=dropout,
mpnn_ensemble_paths=ensemble_paths,
mpnn_device=mpnn_device,
)
else:
return LNPModelWithoutMPNN(
d_model=d_model,
num_heads=num_heads,
n_attn_layers=n_attn_layers,
fusion_strategy=fusion_strategy,
head_hidden_dim=head_hidden_dim,
dropout=dropout,
)
# ============ Pretrained-weight loading ============
def load_pretrain_weights_to_model(
model: Union[LNPModel, LNPModelWithoutMPNN],
pretrain_state_dict: Dict,
d_model: int,
pretrain_config: Dict,
load_delivery_head: bool = True,
) -> bool:
"""
加载预训练权重到模型
Returns:
是否成功加载
"""
if pretrain_config.get("d_model") != d_model:
logger.warning(
f"d_model mismatch: pretrain={pretrain_config.get('d_model')}, "
f"current={d_model}. Skipping pretrain loading."
)
return False
model.load_pretrain_weights(
pretrain_state_dict=pretrain_state_dict,
load_delivery_head=load_delivery_head,
strict=False,
)
return True
# ============ 3-fold Optuna tuning ============
def run_optuna_cv(
full_dataset: LNPDataset,
strata: np.ndarray,
device: torch.device,
n_trials: int = 20,
epochs_per_trial: int = 30,
patience: int = 10,
batch_size: int = 32,
n_folds: int = 3,
use_mpnn: bool = False,
seed: int = 42,
study_path: Optional[Path] = None,
pretrain_state_dict: Optional[Dict] = None,
pretrain_config: Optional[Dict] = None,
load_delivery_head: bool = True,
) -> Tuple[Dict, int, "optuna.Study"]:  # quoted so the module still imports when optuna is missing
"""
使用全量数据做 3-fold CV Optuna 超参搜索
Args:
full_dataset: 完整数据集
strata: 每个样本的分层标签
device: 设备
n_trials: Optuna 试验数
epochs_per_trial: 每个试验的最大 epoch
patience: 早停耐心值
batch_size: 批次大小
n_folds: 折数
use_mpnn: 是否使用 MPNN
seed: 随机种子
study_path: 可选的 study 持久化路径
pretrain_state_dict: 预训练权重
pretrain_config: 预训练配置
load_delivery_head: 是否加载 delivery head 权重
Returns:
(best_params, epoch_mean, study)
"""
if optuna is None:
raise ImportError("Optuna not installed. Run: pip install optuna")
n_samples = len(full_dataset)
indices = np.arange(n_samples)
def objective(trial: optuna.Trial) -> float:
# Sample hyperparameters
d_model = trial.suggest_categorical("d_model", [128, 256, 512])
num_heads = trial.suggest_categorical("num_heads", [4, 8])
n_attn_layers = trial.suggest_int("n_attn_layers", 2, 6)
fusion_strategy = trial.suggest_categorical(
"fusion_strategy", ["attention", "avg", "max"]
)
head_hidden_dim = trial.suggest_categorical("head_hidden_dim", [64, 128, 256])
dropout = trial.suggest_float("dropout", 0.05, 0.3)
lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True)
# 3-fold CV
cv = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
fold_val_losses = []
fold_best_epochs = []
for fold, (train_idx, val_idx) in enumerate(cv.split(indices, strata)):
# Create DataLoaders
train_subset = Subset(full_dataset, train_idx.tolist())
val_subset = Subset(full_dataset, val_idx.tolist())
train_loader = DataLoader(
train_subset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
)
val_loader = DataLoader(
val_subset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
)
# Compute class weights
class_weights = compute_class_weights_from_loader(train_loader)
# Create the model
model = create_model(
d_model=d_model,
num_heads=num_heads,
n_attn_layers=n_attn_layers,
fusion_strategy=fusion_strategy,
head_hidden_dim=head_hidden_dim,
dropout=dropout,
use_mpnn=use_mpnn,
mpnn_device=device.type,
)
# Load pretrained weights
if pretrain_state_dict is not None and pretrain_config is not None:
load_pretrain_weights_to_model(
model, pretrain_state_dict, d_model, pretrain_config, load_delivery_head
)
# Train (with early stopping)
result = train_with_early_stopping(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
lr=lr,
weight_decay=weight_decay,
epochs=epochs_per_trial,
patience=patience,
class_weights=class_weights,
)
fold_val_losses.append(result["best_val_loss"])
fold_best_epochs.append(result["best_epoch"])
# Record epoch_mean on the trial
epoch_mean = int(round(np.mean(fold_best_epochs)))
trial.set_user_attr("epoch_mean", epoch_mean)
trial.set_user_attr("fold_best_epochs", fold_best_epochs)
trial.set_user_attr("fold_val_losses", fold_val_losses)
return np.mean(fold_val_losses)
# Create the study
storage = None
if study_path is not None:
storage = f"sqlite:///{study_path}"
study = optuna.create_study(
direction="minimize",
sampler=TPESampler(seed=seed),
storage=storage,
study_name="final_optuna_cv",
load_if_exists=True,
)
study.optimize(objective, n_trials=n_trials, show_progress_bar=True)
best_params = study.best_trial.params
epoch_mean = study.best_trial.user_attrs.get("epoch_mean", epochs_per_trial)
logger.info(f"Best trial: {study.best_trial.number}")
logger.info(f"Best val_loss: {study.best_trial.value:.4f}")
logger.info(f"Best params: {best_params}")
logger.info(f"Epoch mean from best trial: {epoch_mean}")
return best_params, epoch_mean, study
# ============ Main pipeline ============
@app.command()
def main(
input_path: Path = INTERIM_DATA_DIR / "internal_corrected.csv",
output_dir: Path = MODELS_DIR / "final_optuna",
# CV parameters
n_folds: int = 3,
min_stratum_count: int = 5,
seed: int = 42,
# Optuna parameters
n_trials: int = 20,
epochs_per_trial: int = 30,
patience: int = 10,
# Training parameters
batch_size: int = 32,
# Final-training parameters
use_swa: bool = False,
swa_start_ratio: float = 0.75,
# Pretrained weights
init_from_pretrain: Optional[Path] = None,
load_delivery_head: bool = True,
# MPNN
use_mpnn: bool = False,
# Device
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
"""
最终训练3-fold Optuna 调参 + 全量数据训练
1. 使用全量数据做 3-fold StratifiedKFold Optuna 超参搜索
2. 固定最优超参后使用全量数据训练不使用 early-stopping
3. epoch 数使用 3-fold CV best trial epoch_mean
4. 使用 CosineAnnealingLR可选 SWA
使用 --init-from-pretrain 从预训练 checkpoint 初始化模型权重
"""
if optuna is None:
logger.error("Optuna not installed. Run: pip install optuna")
raise typer.Exit(1)
logger.info(f"Using device: {device}")
device = torch.device(device)
# Load pretrained weights (if specified)
pretrain_state_dict = None
pretrain_config = None
if init_from_pretrain is not None:
if init_from_pretrain.exists():
logger.info(f"Loading pretrain weights from {init_from_pretrain}")
checkpoint = torch.load(init_from_pretrain, map_location="cpu")
pretrain_state_dict = checkpoint["model_state_dict"]
pretrain_config = checkpoint.get("config", {})
logger.success(f"Loaded pretrain checkpoint (d_model={pretrain_config.get('d_model')})")
else:
logger.warning(f"Pretrain checkpoint not found: {init_from_pretrain}, skipping")
# Create the output directory
output_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Output directory: {output_dir}")
# Load data
logger.info(f"Loading data from {input_path}")
df = pd.read_csv(input_path)
logger.info(f"Loaded {len(df)} samples")
# Process data
logger.info("Processing dataframe...")
df = process_dataframe(df)
# Build composite stratification labels
logger.info("Building composite strata...")
strata, strata_info = build_composite_strata(df, min_stratum_count)
# Save strata info
with open(output_dir / "strata_info.json", "w") as f:
json.dump(strata_info, f, indent=2, default=str)
# Create the full dataset
full_dataset = LNPDataset(df)
# Run Optuna tuning
logger.info(f"\nRunning {n_folds}-fold Optuna with {n_trials} trials...")
study_path = output_dir / "optuna_study.sqlite3"
best_params, epoch_mean, study = run_optuna_cv(
full_dataset=full_dataset,
strata=strata,
device=device,
n_trials=n_trials,
epochs_per_trial=epochs_per_trial,
patience=patience,
batch_size=batch_size,
n_folds=n_folds,
use_mpnn=use_mpnn,
seed=seed,
study_path=study_path,
pretrain_state_dict=pretrain_state_dict,
pretrain_config=pretrain_config,
load_delivery_head=load_delivery_head,
)
# Save the best parameters
with open(output_dir / "best_params.json", "w") as f:
json.dump(best_params, f, indent=2)
with open(output_dir / "epoch_mean.json", "w") as f:
json.dump({"epoch_mean": epoch_mean}, f)
# Save the Optuna trial history
trials_history = []
for trial in study.trials:
trials_history.append({
"number": trial.number,
"value": trial.value,
"params": trial.params,
"user_attrs": trial.user_attrs,
"state": str(trial.state),
})
with open(output_dir / "optuna_trials.json", "w") as f:
json.dump(trials_history, f, indent=2)
# Full-data training
logger.info(f"\n{'='*60}")
logger.info("FINAL TRAINING ON FULL DATA")
logger.info(f"{'='*60}")
logger.info(f"Using best params with epochs={epoch_mean}")
logger.info(f"SWA: {use_swa}")
# Create the full-data DataLoader
full_loader = DataLoader(
full_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
)
# Compute class weights
class_weights = compute_class_weights_from_loader(full_loader)
# Save class-weight info
class_weights_info = {
"pdi": class_weights.pdi.tolist() if class_weights.pdi is not None else None,
"ee": class_weights.ee.tolist() if class_weights.ee is not None else None,
"toxic": class_weights.toxic.tolist() if class_weights.toxic is not None else None,
}
with open(output_dir / "class_weights.json", "w") as f:
json.dump(class_weights_info, f, indent=2)
# Create the model
model = create_model(
d_model=best_params["d_model"],
num_heads=best_params["num_heads"],
n_attn_layers=best_params["n_attn_layers"],
fusion_strategy=best_params["fusion_strategy"],
head_hidden_dim=best_params["head_hidden_dim"],
dropout=best_params["dropout"],
use_mpnn=use_mpnn,
mpnn_device=device.type,
)
# Load pretrained weights
if pretrain_state_dict is not None and pretrain_config is not None:
loaded = load_pretrain_weights_to_model(
model, pretrain_state_dict, best_params["d_model"],
pretrain_config, load_delivery_head
)
if loaded:
logger.info("Loaded pretrain weights for final training")
# Print model info
n_params_total = sum(p.numel() for p in model.parameters())
n_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"Model parameters: {n_params_total:,} total, {n_params_trainable:,} trainable")
# Train (fixed epochs, no early stopping)
swa_start = int(epoch_mean * swa_start_ratio) if use_swa else None
train_result = train_fixed_epochs(
model=model,
train_loader=full_loader,
val_loader=None,  # full-data training, no validation set
device=device,
lr=best_params["lr"],
weight_decay=best_params["weight_decay"],
epochs=epoch_mean,
class_weights=class_weights,
use_cosine_annealing=True,
use_swa=use_swa,
swa_start_epoch=swa_start,
)
# Load the final weights
model.load_state_dict(train_result["final_state"])
# Save the model
config = {
"d_model": best_params["d_model"],
"num_heads": best_params["num_heads"],
"n_attn_layers": best_params["n_attn_layers"],
"fusion_strategy": best_params["fusion_strategy"],
"head_hidden_dim": best_params["head_hidden_dim"],
"dropout": best_params["dropout"],
"use_mpnn": use_mpnn,
}
torch.save({
"model_state_dict": train_result["final_state"],
"config": config,
"best_params": best_params,
"epoch_mean": epoch_mean,
"use_swa": use_swa,
}, output_dir / "model.pt")
logger.success(f"Saved model to {output_dir / 'model.pt'}")
# Save the training history
with open(output_dir / "history.json", "w") as f:
json.dump(train_result["history"], f, indent=2)
# Plot loss curves
if train_result["history"]["train"]:
try:
plot_multitask_loss_curves(
history=train_result["history"],
output_path=output_dir / "loss_curves.png",
title="Final Training Loss Curves",
)
logger.info(f"Saved loss curves to {output_dir / 'loss_curves.png'}")
except Exception as e:
logger.warning(f"Failed to plot loss curves: {e}")
# Print the final summary
logger.info(f"\n{'='*60}")
logger.info("FINAL TRAINING COMPLETE")
logger.info(f"{'='*60}")
logger.info(f"Best params: {best_params}")
logger.info(f"Epochs trained: {epoch_mean}")
logger.info(f"Output directory: {output_dir}")
# Show how to use the trained model
logger.info("\n[How to use the trained model]")
logger.info(f" 1. Set environment variable: MODEL_PATH={output_dir / 'model.pt'}")
logger.info(" 2. Or copy to default location: cp {output_dir}/model.pt models/final/model.pt")
logger.info(" 3. Start API: make api")
if __name__ == "__main__":
app()
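A minimal reload sketch, assuming the checkpoint written by this script (the path is whatever output_dir was used; config mirrors the create_model kwargs saved above):

import torch
from lnp_ml.modeling.final_train_optuna_cv import create_model

ckpt = torch.load("models/final_optuna/model.pt", map_location="cpu")
model = create_model(**ckpt["config"])            # d_model, num_heads, ..., use_mpnn
model.load_state_dict(ckpt["model_state_dict"])   # weights from full-data training
model.eval()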

774
lnp_ml/modeling/nested_cv_optuna.py Normal file
View File

@ -0,0 +1,774 @@
"""
嵌套交叉验证 + Optuna 超参调优
外层 5-fold StratifiedKFold20% test / 80% train
内层 3-fold StratifiedKFold 80% 上做 Optuna 超参搜索
使用方法:
python -m lnp_ml.modeling.nested_cv_optuna
或通过 Makefile:
make nested_cv_tune DEVICE=cuda
"""
import json
import math
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from loguru import logger
import typer
try:
import optuna
from optuna.samplers import TPESampler
except ImportError:
optuna = None
TPESampler = None
from lnp_ml.config import MODELS_DIR, INTERIM_DATA_DIR
from lnp_ml.dataset import (
LNPDataset,
collate_fn,
process_dataframe,
TARGET_CLASSIFICATION_PDI,
TARGET_CLASSIFICATION_EE,
TARGET_TOXIC,
)
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
from lnp_ml.modeling.trainer_balanced import (
ClassWeights,
LossWeightsBalanced,
compute_class_weights_from_loader,
train_with_early_stopping,
train_fixed_epochs,
validate_balanced,
)
# Default MPNN ensemble path
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
app = typer.Typer()
# ============ CompositeStrata: composite stratification labels ============
def build_composite_strata(
df: pd.DataFrame,
min_stratum_count: int = 5,
) -> Tuple[np.ndarray, Dict]:
"""
构建复合分层标签toxic × PDI × EE
Args:
df: 处理后的 DataFrame
min_stratum_count: 每个 stratum 最少样本数低于此值合并为 RARE
Returns:
(strata_array, strata_info)
- strata_array: 每个样本的 stratum 编码整数
- strata_info: 统计信息
"""
n = len(df)
strata_labels = []
for i in range(n):
# Toxic stratum
if TARGET_TOXIC in df.columns:
toxic_val = df[TARGET_TOXIC].iloc[i]
if pd.notna(toxic_val) and toxic_val >= 0:
toxic_str = str(int(toxic_val))
else:
toxic_str = "NA"
else:
toxic_str = "NA"
# PDI stratum
if all(col in df.columns for col in TARGET_CLASSIFICATION_PDI):
pdi_vals = df[TARGET_CLASSIFICATION_PDI].iloc[i].values
if pdi_vals.sum() > 0:
pdi_str = str(int(np.argmax(pdi_vals)))
else:
pdi_str = "NA"
else:
pdi_str = "NA"
# EE stratum
if all(col in df.columns for col in TARGET_CLASSIFICATION_EE):
ee_vals = df[TARGET_CLASSIFICATION_EE].iloc[i].values
if ee_vals.sum() > 0:
ee_str = str(int(np.argmax(ee_vals)))
else:
ee_str = "NA"
else:
ee_str = "NA"
strata_labels.append(f"T{toxic_str}|P{pdi_str}|E{ee_str}")
# Count samples per stratum
unique_strata, counts = np.unique(strata_labels, return_counts=True)
strata_counts = dict(zip(unique_strata, counts))
# Merge sparse strata into RARE
rare_strata = [s for s, c in strata_counts.items() if c < min_stratum_count]
final_labels = []
for label in strata_labels:
if label in rare_strata:
final_labels.append("RARE")
else:
final_labels.append(label)
# Encode as integers
unique_final, encoded = np.unique(final_labels, return_inverse=True)
strata_info = {
"original_strata_counts": strata_counts,
"rare_strata": rare_strata,
"final_strata": list(unique_final),
"final_strata_counts": dict(zip(*np.unique(final_labels, return_counts=True))),
"n_rare_merged": sum(strata_counts[s] for s in rare_strata) if rare_strata else 0,
}
logger.info(f"Built composite strata: {len(unique_final)} unique strata")
logger.info(f" Rare strata merged: {len(rare_strata)} types, {strata_info['n_rare_merged']} samples")
return encoded.astype(np.int64), strata_info
# ============ Model creation ============
def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
"""自动查找 MPNN ensemble 的 model.pt 文件。"""
model_paths = sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt"))
if not model_paths:
raise FileNotFoundError(f"No model.pt files found in {base_dir}")
return [str(p) for p in model_paths]
def create_model(
d_model: int = 256,
num_heads: int = 8,
n_attn_layers: int = 4,
fusion_strategy: str = "attention",
head_hidden_dim: int = 128,
dropout: float = 0.1,
use_mpnn: bool = False,
mpnn_device: str = "cpu",
) -> Union[LNPModel, LNPModelWithoutMPNN]:
"""创建模型"""
if use_mpnn:
ensemble_paths = find_mpnn_ensemble_paths()
return LNPModel(
d_model=d_model,
num_heads=num_heads,
n_attn_layers=n_attn_layers,
fusion_strategy=fusion_strategy,
head_hidden_dim=head_hidden_dim,
dropout=dropout,
mpnn_ensemble_paths=ensemble_paths,
mpnn_device=mpnn_device,
)
else:
return LNPModelWithoutMPNN(
d_model=d_model,
num_heads=num_heads,
n_attn_layers=n_attn_layers,
fusion_strategy=fusion_strategy,
head_hidden_dim=head_hidden_dim,
dropout=dropout,
)
# ============ Evaluation metrics ============
def evaluate_on_test(
model: torch.nn.Module,
test_loader: DataLoader,
device: torch.device,
) -> Dict:
"""在测试集上评估模型"""
from scipy.special import rel_entr
from sklearn.metrics import (
mean_squared_error,
mean_absolute_error,
r2_score,
accuracy_score,
precision_score,
recall_score,
f1_score,
)
model.eval()
preds = {
"size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [], "biodist": []
}
targets = {
"size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [], "biodist": []
}
with torch.no_grad():
for batch in test_loader:
smiles = batch["smiles"]
tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
tgts = batch["targets"]
masks = batch["mask"]
outputs = model(smiles, tabular)
# Collect predictions and ground truth
for task in ["size", "delivery"]:
if task in masks and masks[task].any():
m = masks[task]
preds[task].extend(outputs[task].squeeze(-1)[m].cpu().numpy().tolist())
targets[task].extend(tgts[task][m].cpu().numpy().tolist())
for task in ["pdi", "ee", "toxic"]:
if task in masks and masks[task].any():
m = masks[task]
preds[task].extend(outputs[task][m].argmax(dim=-1).cpu().numpy().tolist())
targets[task].extend(tgts[task][m].cpu().numpy().tolist())
if "biodist" in masks and masks["biodist"].any():
m = masks["biodist"]
preds["biodist"].extend(outputs["biodist"][m].cpu().numpy().tolist())
targets["biodist"].extend(tgts["biodist"][m].cpu().numpy().tolist())
# Compute metrics
results = {}
# Regression tasks
for task in ["size", "delivery"]:
if preds[task]:
p = np.array(preds[task])
t = np.array(targets[task])
results[task] = {
"n_samples": len(p),
"mse": float(mean_squared_error(t, p)),
"rmse": float(np.sqrt(mean_squared_error(t, p))),
"mae": float(mean_absolute_error(t, p)),
"r2": float(r2_score(t, p)),
}
# Classification tasks
for task in ["pdi", "ee", "toxic"]:
if preds[task]:
p = np.array(preds[task])
t = np.array(targets[task])
results[task] = {
"n_samples": len(p),
"accuracy": float(accuracy_score(t, p)),
"precision": float(precision_score(t, p, average="macro", zero_division=0)),
"recall": float(recall_score(t, p, average="macro", zero_division=0)),
"f1": float(f1_score(t, p, average="macro", zero_division=0)),
}
# Distribution task
if preds["biodist"]:
p = np.array(preds["biodist"])
t = np.array(targets["biodist"])
def kl_divergence(p_arr, q_arr, eps=1e-10):
p_arr = np.clip(p_arr, eps, 1.0)
q_arr = np.clip(q_arr, eps, 1.0)
return float(np.sum(rel_entr(p_arr, q_arr), axis=-1).mean())
def js_divergence(p_arr, q_arr, eps=1e-10):
p_arr = np.clip(p_arr, eps, 1.0)
q_arr = np.clip(q_arr, eps, 1.0)
m = 0.5 * (p_arr + q_arr)
return float(0.5 * (np.sum(rel_entr(p_arr, m), axis=-1) + np.sum(rel_entr(q_arr, m), axis=-1)).mean())
results["biodist"] = {
"n_samples": len(p),
"kl_divergence": kl_divergence(t, p),
"js_divergence": js_divergence(t, p),
}
return results
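# Note: rel_entr uses the natural log, so the JS divergence reported above is
# the symmetrised variant and is bounded by ln(2) ≈ 0.693, which makes the
# biodist numbers directly comparable across folds.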
# ============ Pretrained-weight loading ============
def load_pretrain_weights_to_model(
model: Union[LNPModel, LNPModelWithoutMPNN],
pretrain_state_dict: Dict,
d_model: int,
pretrain_config: Dict,
load_delivery_head: bool = True,
) -> bool:
"""
加载预训练权重到模型
Returns:
是否成功加载
"""
if pretrain_config.get("d_model") != d_model:
logger.warning(
f"d_model mismatch: pretrain={pretrain_config.get('d_model')}, "
f"current={d_model}. Skipping pretrain loading."
)
return False
model.load_pretrain_weights(
pretrain_state_dict=pretrain_state_dict,
load_delivery_head=load_delivery_head,
strict=False,
)
return True
# ============ Inner Optuna tuning ============
def run_inner_optuna(
full_dataset: LNPDataset,
inner_train_indices: np.ndarray,
strata: np.ndarray,
device: torch.device,
n_trials: int = 20,
epochs_per_trial: int = 30,
patience: int = 10,
batch_size: int = 32,
n_inner_folds: int = 3,
use_mpnn: bool = False,
seed: int = 42,
study_path: Optional[Path] = None,
pretrain_state_dict: Optional[Dict] = None,
pretrain_config: Optional[Dict] = None,
load_delivery_head: bool = True,
) -> Tuple[Dict, int, "optuna.Study"]:  # quoted so the module still imports when optuna is missing
"""
在内层数据上运行 Optuna 超参搜索
Args:
full_dataset: 完整数据集
inner_train_indices: 内层训练数据的索引相对于 full_dataset
strata: 每个样本的分层标签
device: 设备
n_trials: Optuna 试验数
epochs_per_trial: 每个试验的最大 epoch
patience: 早停耐心值
batch_size: 批次大小
n_inner_folds: 内层折数
use_mpnn: 是否使用 MPNN
seed: 随机种子
study_path: 可选的 study 持久化路径
pretrain_state_dict: 预训练权重
pretrain_config: 预训练配置
load_delivery_head: 是否加载 delivery head 权重
Returns:
(best_params, epoch_mean, study)
"""
if optuna is None:
raise ImportError("Optuna not installed. Run: pip install optuna")
inner_strata = strata[inner_train_indices]
def objective(trial: optuna.Trial) -> float:
# Sample hyperparameters
d_model = trial.suggest_categorical("d_model", [128, 256, 512])
num_heads = trial.suggest_categorical("num_heads", [4, 8])
n_attn_layers = trial.suggest_int("n_attn_layers", 2, 6)
fusion_strategy = trial.suggest_categorical(
"fusion_strategy", ["attention", "avg", "max"]
)
head_hidden_dim = trial.suggest_categorical("head_hidden_dim", [64, 128, 256])
dropout = trial.suggest_float("dropout", 0.05, 0.3)
lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True)
# Inner 3-fold CV
inner_cv = StratifiedKFold(
n_splits=n_inner_folds, shuffle=True, random_state=seed
)
fold_val_losses = []
fold_best_epochs = []
for inner_fold, (inner_train_idx, inner_val_idx) in enumerate(
inner_cv.split(inner_train_indices, inner_strata)
):
# Map back to actual dataset indices
actual_train_idx = inner_train_indices[inner_train_idx]
actual_val_idx = inner_train_indices[inner_val_idx]
# Create DataLoaders
train_subset = Subset(full_dataset, actual_train_idx.tolist())
val_subset = Subset(full_dataset, actual_val_idx.tolist())
train_loader = DataLoader(
train_subset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
)
val_loader = DataLoader(
val_subset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
)
# Compute class weights
class_weights = compute_class_weights_from_loader(train_loader)
# Create the model
model = create_model(
d_model=d_model,
num_heads=num_heads,
n_attn_layers=n_attn_layers,
fusion_strategy=fusion_strategy,
head_hidden_dim=head_hidden_dim,
dropout=dropout,
use_mpnn=use_mpnn,
mpnn_device=device.type,
)
# Load pretrained weights
if pretrain_state_dict is not None and pretrain_config is not None:
load_pretrain_weights_to_model(
model, pretrain_state_dict, d_model, pretrain_config, load_delivery_head
)
# Train (with early stopping)
result = train_with_early_stopping(
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
lr=lr,
weight_decay=weight_decay,
epochs=epochs_per_trial,
patience=patience,
class_weights=class_weights,
)
fold_val_losses.append(result["best_val_loss"])
fold_best_epochs.append(result["best_epoch"])
# Record epoch_mean on the trial
epoch_mean = int(round(np.mean(fold_best_epochs)))
trial.set_user_attr("epoch_mean", epoch_mean)
trial.set_user_attr("fold_best_epochs", fold_best_epochs)
return np.mean(fold_val_losses)
# Create the study
storage = None
if study_path is not None:
storage = f"sqlite:///{study_path}"
study = optuna.create_study(
direction="minimize",
sampler=TPESampler(seed=seed),
storage=storage,
study_name="inner_optuna",
load_if_exists=True,
)
study.optimize(objective, n_trials=n_trials, show_progress_bar=True)
best_params = study.best_trial.params
epoch_mean = study.best_trial.user_attrs.get("epoch_mean", epochs_per_trial)
logger.info(f"Best trial: {study.best_trial.number}")
logger.info(f"Best val_loss: {study.best_trial.value:.4f}")
logger.info(f"Best params: {best_params}")
logger.info(f"Epoch mean: {epoch_mean}")
return best_params, epoch_mean, study
# ============ Main pipeline ============
@app.command()
def main(
input_path: Path = INTERIM_DATA_DIR / "internal_corrected.csv",
output_dir: Path = MODELS_DIR / "nested_cv",
# CV parameters
n_outer_folds: int = 5,
n_inner_folds: int = 3,
min_stratum_count: int = 5,
seed: int = 42,
# Optuna parameters
n_trials: int = 20,
epochs_per_trial: int = 30,
inner_patience: int = 10,
# Training parameters
batch_size: int = 32,
# Pretrained weights
init_from_pretrain: Optional[Path] = None,
load_delivery_head: bool = True,
# MPNN
use_mpnn: bool = False,
# Device
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
"""
嵌套交叉验证 + Optuna 超参调优
外层 5-fold20% test / 80% train内层 3-fold Optuna 调参
外层训练不使用 early-stoppingepoch 数使用内层 best trial epoch_mean
使用 --init-from-pretrain 从预训练 checkpoint 初始化模型权重
"""
if optuna is None:
logger.error("Optuna not installed. Run: pip install optuna")
raise typer.Exit(1)
logger.info(f"Using device: {device}")
device = torch.device(device)
# Load pretrained weights (if specified)
pretrain_state_dict = None
pretrain_config = None
if init_from_pretrain is not None:
if init_from_pretrain.exists():
logger.info(f"Loading pretrain weights from {init_from_pretrain}")
checkpoint = torch.load(init_from_pretrain, map_location="cpu")
pretrain_state_dict = checkpoint["model_state_dict"]
pretrain_config = checkpoint.get("config", {})
logger.success(f"Loaded pretrain checkpoint (d_model={pretrain_config.get('d_model')})")
else:
logger.warning(f"Pretrain checkpoint not found: {init_from_pretrain}, skipping")
# Create the output directory (timestamped)
run_name = datetime.now().strftime("%Y%m%d_%H%M%S")
run_dir = output_dir / run_name
run_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Output directory: {run_dir}")
# Load data
logger.info(f"Loading data from {input_path}")
df = pd.read_csv(input_path)
logger.info(f"Loaded {len(df)} samples")
# Process data
logger.info("Processing dataframe...")
df = process_dataframe(df)
# Build composite stratification labels
logger.info("Building composite strata...")
strata, strata_info = build_composite_strata(df, min_stratum_count)
# Save strata info
with open(run_dir / "strata_info.json", "w") as f:
json.dump(strata_info, f, indent=2, default=str)
# Create the full dataset
full_dataset = LNPDataset(df)
n_samples = len(full_dataset)
# Outer CV
outer_cv = StratifiedKFold(
n_splits=n_outer_folds, shuffle=True, random_state=seed
)
outer_results = []
for outer_fold, (outer_train_idx, outer_test_idx) in enumerate(
outer_cv.split(np.arange(n_samples), strata)
):
logger.info(f"\n{'='*60}")
logger.info(f"OUTER FOLD {outer_fold}")
logger.info(f"{'='*60}")
logger.info(f"Train: {len(outer_train_idx)}, Test: {len(outer_test_idx)}")
fold_dir = run_dir / f"outer_fold_{outer_fold}"
fold_dir.mkdir(parents=True, exist_ok=True)
# Save the split indices
splits = {
"outer_train_idx": outer_train_idx.tolist(),
"outer_test_idx": outer_test_idx.tolist(),
}
with open(fold_dir / "splits.json", "w") as f:
json.dump(splits, f)
# Inner Optuna tuning
logger.info(f"\nRunning inner Optuna with {n_trials} trials...")
study_path = fold_dir / "optuna_study.sqlite3"
best_params, epoch_mean, study = run_inner_optuna(
full_dataset=full_dataset,
inner_train_indices=outer_train_idx,
strata=strata,
device=device,
n_trials=n_trials,
epochs_per_trial=epochs_per_trial,
patience=inner_patience,
batch_size=batch_size,
n_inner_folds=n_inner_folds,
use_mpnn=use_mpnn,
seed=seed + outer_fold,
study_path=study_path,
pretrain_state_dict=pretrain_state_dict,
pretrain_config=pretrain_config,
load_delivery_head=load_delivery_head,
)
# Save the best parameters
with open(fold_dir / "best_params.json", "w") as f:
json.dump(best_params, f, indent=2)
with open(fold_dir / "epoch_mean.json", "w") as f:
json.dump({"epoch_mean": epoch_mean}, f)
# Outer training (best hyperparameters, fixed epoch count, no early stopping)
logger.info(f"\nTraining outer fold with best params, epochs={epoch_mean}...")
# Create DataLoaders
train_subset = Subset(full_dataset, outer_train_idx.tolist())
test_subset = Subset(full_dataset, outer_test_idx.tolist())
train_loader = DataLoader(
train_subset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
)
test_loader = DataLoader(
test_subset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
)
# Compute class weights
class_weights = compute_class_weights_from_loader(train_loader)
# Create the model
model = create_model(
d_model=best_params["d_model"],
num_heads=best_params["num_heads"],
n_attn_layers=best_params["n_attn_layers"],
fusion_strategy=best_params["fusion_strategy"],
head_hidden_dim=best_params["head_hidden_dim"],
dropout=best_params["dropout"],
use_mpnn=use_mpnn,
mpnn_device=device.type,
)
# Load pretrained weights
if pretrain_state_dict is not None and pretrain_config is not None:
loaded = load_pretrain_weights_to_model(
model, pretrain_state_dict, best_params["d_model"],
pretrain_config, load_delivery_head
)
if loaded:
logger.info(f"Loaded pretrain weights for outer fold {outer_fold}")
# Train (fixed epochs, no early stopping)
train_result = train_fixed_epochs(
model=model,
train_loader=train_loader,
val_loader=None,  # no validation set in the outer loop
device=device,
lr=best_params["lr"],
weight_decay=best_params["weight_decay"],
epochs=epoch_mean,
class_weights=class_weights,
use_cosine_annealing=True,
)
# Load the final weights
model.load_state_dict(train_result["final_state"])
model = model.to(device)
# Save the model
config = {
"d_model": best_params["d_model"],
"num_heads": best_params["num_heads"],
"n_attn_layers": best_params["n_attn_layers"],
"fusion_strategy": best_params["fusion_strategy"],
"head_hidden_dim": best_params["head_hidden_dim"],
"dropout": best_params["dropout"],
"use_mpnn": use_mpnn,
}
torch.save({
"model_state_dict": train_result["final_state"],
"config": config,
"epoch_mean": epoch_mean,
"best_params": best_params,
}, fold_dir / "model.pt")
# Save the training history
with open(fold_dir / "history.json", "w") as f:
json.dump(train_result["history"], f, indent=2)
# Evaluate on the test set
logger.info("Evaluating on outer test set...")
test_metrics = evaluate_on_test(model, test_loader, device)
with open(fold_dir / "test_metrics.json", "w") as f:
json.dump(test_metrics, f, indent=2)
# Print the test results
logger.info(f"\nOuter Fold {outer_fold} Test Results:")
for task, metrics in test_metrics.items():
if "rmse" in metrics:
logger.info(f" {task}: RMSE={metrics['rmse']:.4f}, R²={metrics['r2']:.4f}")
elif "accuracy" in metrics:
logger.info(f" {task}: Acc={metrics['accuracy']:.4f}, F1={metrics['f1']:.4f}")
elif "kl_divergence" in metrics:
logger.info(f" {task}: KL={metrics['kl_divergence']:.4f}, JS={metrics['js_divergence']:.4f}")
outer_results.append({
"fold": outer_fold,
"best_params": best_params,
"epoch_mean": epoch_mean,
"test_metrics": test_metrics,
})
# Aggregate results
logger.info("\n" + "=" * 60)
logger.info("NESTED CV COMPLETE")
logger.info("=" * 60)
# Compute summary statistics
summary = {"fold_results": outer_results}
# Compute mean and std for each task
tasks_with_metrics = {}
for result in outer_results:
for task, metrics in result["test_metrics"].items():
if task not in tasks_with_metrics:
tasks_with_metrics[task] = {k: [] for k in metrics.keys() if k != "n_samples"}
for k, v in metrics.items():
if k != "n_samples":
tasks_with_metrics[task][k].append(v)
summary["summary_stats"] = {}
for task, metrics_dict in tasks_with_metrics.items():
summary["summary_stats"][task] = {}
for metric_name, values in metrics_dict.items():
summary["summary_stats"][task][f"{metric_name}_mean"] = float(np.mean(values))
summary["summary_stats"][task][f"{metric_name}_std"] = float(np.std(values))
# Print the summary
logger.info("\n[Summary Statistics]")
for task, stats in summary["summary_stats"].items():
if "rmse_mean" in stats:
logger.info(
f" {task}: RMSE={stats['rmse_mean']:.4f}±{stats['rmse_std']:.4f}, "
f"R²={stats['r2_mean']:.4f}±{stats['r2_std']:.4f}"
)
elif "accuracy_mean" in stats:
logger.info(
f" {task}: Acc={stats['accuracy_mean']:.4f}±{stats['accuracy_std']:.4f}, "
f"F1={stats['f1_mean']:.4f}±{stats['f1_std']:.4f}"
)
elif "kl_divergence_mean" in stats:
logger.info(
f" {task}: KL={stats['kl_divergence_mean']:.4f}±{stats['kl_divergence_std']:.4f}, "
f"JS={stats['js_divergence_mean']:.4f}±{stats['js_divergence_std']:.4f}"
)
# Save the summary
with open(run_dir / "summary.json", "w") as f:
json.dump(summary, f, indent=2)
logger.success(f"\nAll results saved to {run_dir}")
if __name__ == "__main__":
app()
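A small sketch, assuming a finished run directory (the timestamped name below is hypothetical), for reading the aggregated metrics back out of summary.json:

import json
from pathlib import Path

run_dir = Path("models/nested_cv/20260211_120000")  # hypothetical run name
summary = json.loads((run_dir / "summary.json").read_text())
for task, stats in summary["summary_stats"].items():
    # each metric is stored as <name>_mean / <name>_std across the outer folds
    print(task, {k: round(v, 4) for k, v in stats.items()})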

474
lnp_ml/modeling/trainer_balanced.py Normal file
View File

@ -0,0 +1,474 @@
"""带类权重的训练器:处理分类任务的数据不均衡问题"""
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
@dataclass
class ClassWeights:
"""分类任务的类权重"""
pdi: Optional[torch.Tensor] = None # [4] for 4 PDI classes
ee: Optional[torch.Tensor] = None # [3] for 3 EE classes
toxic: Optional[torch.Tensor] = None # [2] for binary toxic
@dataclass
class LossWeightsBalanced:
"""各任务的损失权重(与 trainer.py 的 LossWeights 兼容)"""
size: float = 1.0
pdi: float = 1.0
ee: float = 1.0
delivery: float = 1.0
biodist: float = 1.0
toxic: float = 1.0
def compute_class_weights_from_loader(
loader: DataLoader,
n_pdi_classes: int = 4,
n_ee_classes: int = 3,
n_toxic_classes: int = 2,
smoothing: float = 0.1,
) -> ClassWeights:
"""
DataLoader 统计类别频次并计算类权重
使用 inverse frequency 方式weight_c = N / (n_classes * count_c)
smoothing 避免极端权重
Args:
loader: DataLoader需要遍历一次
n_pdi_classes: PDI 类别数
n_ee_classes: EE 类别数
n_toxic_classes: toxic 类别数
smoothing: 平滑系数防止除零和极端权重
Returns:
ClassWeights 对象包含各分类任务的类权重张量
"""
pdi_counts = torch.zeros(n_pdi_classes)
ee_counts = torch.zeros(n_ee_classes)
toxic_counts = torch.zeros(n_toxic_classes)
for batch in loader:
targets = batch["targets"]
mask = batch["mask"]
# PDI
if "pdi" in targets and "pdi" in mask:
m = mask["pdi"]
if m.any():
labels = targets["pdi"][m]
for c in range(n_pdi_classes):
pdi_counts[c] += (labels == c).sum().item()
# EE
if "ee" in targets and "ee" in mask:
m = mask["ee"]
if m.any():
labels = targets["ee"][m]
for c in range(n_ee_classes):
ee_counts[c] += (labels == c).sum().item()
# Toxic
if "toxic" in targets and "toxic" in mask:
m = mask["toxic"]
if m.any():
labels = targets["toxic"][m]
for c in range(n_toxic_classes):
toxic_counts[c] += (labels == c).sum().item()
def counts_to_weights(counts: torch.Tensor, n_classes: int) -> Optional[torch.Tensor]:
"""将计数转换为类权重"""
total = counts.sum().item()
if total == 0:
return None
# Inverse frequency with smoothing
counts = counts + smoothing
weights = total / (n_classes * counts)
# Normalize to mean=1
weights = weights / weights.mean()
return weights
return ClassWeights(
pdi=counts_to_weights(pdi_counts, n_pdi_classes),
ee=counts_to_weights(ee_counts, n_ee_classes),
toxic=counts_to_weights(toxic_counts, n_toxic_classes),
)
def compute_multitask_loss_balanced(
outputs: Dict[str, torch.Tensor],
targets: Dict[str, torch.Tensor],
mask: Dict[str, torch.Tensor],
task_weights: Optional[LossWeightsBalanced] = None,
class_weights: Optional[ClassWeights] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
计算带类权重的多任务损失
Args:
outputs: 模型输出
targets: 真实标签
mask: 有效样本掩码
task_weights: 各任务权重
class_weights: 分类任务的类权重
Returns:
(total_loss, loss_dict) 总损失和各任务损失
"""
task_weights = task_weights or LossWeightsBalanced()
class_weights = class_weights or ClassWeights()
losses = {}
device = next(iter(outputs.values())).device
total_loss = torch.tensor(0.0, device=device)
# size: MSE loss (regression task, no class weights needed)
if "size" in targets and mask["size"].any():
m = mask["size"]
pred = outputs["size"][m].squeeze(-1)
tgt = targets["size"][m]
losses["size"] = F.mse_loss(pred, tgt)
total_loss = total_loss + task_weights.size * losses["size"]
# delivery: MSE loss (regression task, no class weights needed)
if "delivery" in targets and mask["delivery"].any():
m = mask["delivery"]
pred = outputs["delivery"][m].squeeze(-1)
tgt = targets["delivery"][m]
losses["delivery"] = F.mse_loss(pred, tgt)
total_loss = total_loss + task_weights.delivery * losses["delivery"]
# pdi: CrossEntropy with class weights
if "pdi" in targets and mask["pdi"].any():
m = mask["pdi"]
pred = outputs["pdi"][m]
tgt = targets["pdi"][m]
weight = class_weights.pdi.to(device) if class_weights.pdi is not None else None
losses["pdi"] = F.cross_entropy(pred, tgt, weight=weight)
total_loss = total_loss + task_weights.pdi * losses["pdi"]
# ee: CrossEntropy with class weights
if "ee" in targets and mask["ee"].any():
m = mask["ee"]
pred = outputs["ee"][m]
tgt = targets["ee"][m]
weight = class_weights.ee.to(device) if class_weights.ee is not None else None
losses["ee"] = F.cross_entropy(pred, tgt, weight=weight)
total_loss = total_loss + task_weights.ee * losses["ee"]
# toxic: CrossEntropy with class weights
if "toxic" in targets and mask["toxic"].any():
m = mask["toxic"]
pred = outputs["toxic"][m]
tgt = targets["toxic"][m]
weight = class_weights.toxic.to(device) if class_weights.toxic is not None else None
losses["toxic"] = F.cross_entropy(pred, tgt, weight=weight)
total_loss = total_loss + task_weights.toxic * losses["toxic"]
# biodist: KL divergence (distribution task, no class weights needed)
if "biodist" in targets and mask["biodist"].any():
m = mask["biodist"]
pred = outputs["biodist"][m]
tgt = targets["biodist"][m]
losses["biodist"] = F.kl_div(
pred.log().clamp(min=-100),
tgt,
reduction="batchmean",
)
total_loss = total_loss + task_weights.biodist * losses["biodist"]
return total_loss, losses
def train_epoch_balanced(
model: nn.Module,
loader: DataLoader,
optimizer: torch.optim.Optimizer,
device: torch.device,
task_weights: Optional[LossWeightsBalanced] = None,
class_weights: Optional[ClassWeights] = None,
) -> Dict[str, float]:
"""带类权重的训练一个 epoch"""
model.train()
total_loss = 0.0
task_losses = {k: 0.0 for k in ["size", "pdi", "ee", "delivery", "biodist", "toxic"]}
n_batches = 0
for batch in tqdm(loader, desc="Training", leave=False):
smiles = batch["smiles"]
tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
targets = {k: v.to(device) for k, v in batch["targets"].items()}
mask = {k: v.to(device) for k, v in batch["mask"].items()}
optimizer.zero_grad()
outputs = model(smiles, tabular)
loss, losses = compute_multitask_loss_balanced(
outputs, targets, mask, task_weights, class_weights
)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
total_loss += loss.item()
for k, v in losses.items():
task_losses[k] += v.item()
n_batches += 1
return {
"loss": total_loss / n_batches,
**{f"loss_{k}": v / n_batches for k, v in task_losses.items() if v > 0},
}
@torch.no_grad()
def validate_balanced(
model: nn.Module,
loader: DataLoader,
device: torch.device,
task_weights: Optional[LossWeightsBalanced] = None,
class_weights: Optional[ClassWeights] = None,
) -> Dict[str, float]:
"""带类权重的验证"""
model.eval()
total_loss = 0.0
task_losses = {k: 0.0 for k in ["size", "pdi", "ee", "delivery", "biodist", "toxic"]}
n_batches = 0
# For accuracy computation
correct = {k: 0 for k in ["pdi", "ee", "toxic"]}
total = {k: 0 for k in ["pdi", "ee", "toxic"]}
for batch in tqdm(loader, desc="Validating", leave=False):
smiles = batch["smiles"]
tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
targets = {k: v.to(device) for k, v in batch["targets"].items()}
mask = {k: v.to(device) for k, v in batch["mask"].items()}
outputs = model(smiles, tabular)
loss, losses = compute_multitask_loss_balanced(
outputs, targets, mask, task_weights, class_weights
)
total_loss += loss.item()
for k, v in losses.items():
task_losses[k] += v.item()
n_batches += 1
# Compute classification accuracy
for k in ["pdi", "ee", "toxic"]:
if k in targets and mask[k].any():
m = mask[k]
pred = outputs[k][m].argmax(dim=-1)
tgt = targets[k][m]
correct[k] += (pred == tgt).sum().item()
total[k] += m.sum().item()
metrics = {
"loss": total_loss / n_batches,
**{f"loss_{k}": v / n_batches for k, v in task_losses.items() if v > 0},
}
# Add accuracies
for k in ["pdi", "ee", "toxic"]:
if total[k] > 0:
metrics[f"acc_{k}"] = correct[k] / total[k]
return metrics
class EarlyStoppingBalanced:
"""早停机制(与 trainer.py 的 EarlyStopping 兼容)"""
def __init__(self, patience: int = 10, min_delta: float = 0.0):
self.patience = patience
self.min_delta = min_delta
self.counter = 0
self.best_loss = float("inf")
self.best_epoch = 0
self.should_stop = False
def __call__(self, val_loss: float, epoch: int = 0) -> bool:
if val_loss < self.best_loss - self.min_delta:
self.best_loss = val_loss
self.best_epoch = epoch
self.counter = 0
else:
self.counter += 1
if self.counter >= self.patience:
self.should_stop = True
return self.should_stop
def get_best_epoch(self) -> int:
"""获取最佳 epoch1-indexed"""
return self.best_epoch + 1
def train_with_early_stopping(
model: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
device: torch.device,
lr: float = 1e-4,
weight_decay: float = 1e-5,
epochs: int = 100,
patience: int = 15,
task_weights: Optional[LossWeightsBalanced] = None,
class_weights: Optional[ClassWeights] = None,
) -> Dict:
"""
带早停的完整训练流程
Returns:
Dict with keys: history, best_val_loss, best_epoch, best_state
"""
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", factor=0.5, patience=5
)
early_stopping = EarlyStoppingBalanced(patience=patience)
history = {"train": [], "val": []}
best_val_loss = float("inf")
best_state = None
for epoch in range(epochs):
# Train
train_metrics = train_epoch_balanced(
model, train_loader, optimizer, device, task_weights, class_weights
)
# Validate
val_metrics = validate_balanced(
model, val_loader, device, task_weights, class_weights
)
history["train"].append(train_metrics)
history["val"].append(val_metrics)
# Learning rate scheduling
scheduler.step(val_metrics["loss"])
# Save best model
if val_metrics["loss"] < best_val_loss:
best_val_loss = val_metrics["loss"]
best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
# Early stopping
if early_stopping(val_metrics["loss"], epoch):
break
# Restore best model
if best_state is not None:
model.load_state_dict(best_state)
return {
"history": history,
"best_val_loss": best_val_loss,
"best_epoch": early_stopping.get_best_epoch(),
"best_state": best_state,
"epochs_trained": len(history["train"]),
}
def train_fixed_epochs(
model: nn.Module,
train_loader: DataLoader,
val_loader: Optional[DataLoader],
device: torch.device,
lr: float = 1e-4,
weight_decay: float = 1e-5,
epochs: int = 50,
task_weights: Optional[LossWeightsBalanced] = None,
class_weights: Optional[ClassWeights] = None,
use_cosine_annealing: bool = True,
use_swa: bool = False,
swa_start_epoch: Optional[int] = None,
) -> Dict:
"""
固定 epoch 数的训练不使用 early stopping
用于外层 CV 训练和最终训练
Args:
model: 模型
train_loader: 训练数据
val_loader: 验证数据可选仅用于监控
device: 设备
lr: 学习率
weight_decay: 权重衰减
epochs: 训练轮数
task_weights: 任务权重
class_weights: 类权重
use_cosine_annealing: 是否使用 CosineAnnealingLR
use_swa: 是否使用 SWA
swa_start_epoch: SWA 开始的 epoch默认为 epochs * 0.75
Returns:
Dict with keys: history, final_state
"""
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
if use_cosine_annealing:
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
else:
scheduler = None
# SWA setup
swa_model = None
swa_scheduler = None
if use_swa:
from torch.optim.swa_utils import AveragedModel, SWALR
swa_model = AveragedModel(model)
swa_start = swa_start_epoch or int(epochs * 0.75)
swa_scheduler = SWALR(optimizer, swa_lr=lr * 0.1)
history = {"train": [], "val": []}
for epoch in range(epochs):
# Train
train_metrics = train_epoch_balanced(
model, train_loader, optimizer, device, task_weights, class_weights
)
history["train"].append(train_metrics)
# Validate (optional)
if val_loader is not None:
val_metrics = validate_balanced(
model, val_loader, device, task_weights, class_weights
)
history["val"].append(val_metrics)
# Scheduler step
if use_swa and epoch >= swa_start:
swa_model.update_parameters(model)
swa_scheduler.step()
elif scheduler is not None:
scheduler.step()
# Finalize SWA
final_state = None
if use_swa and swa_model is not None:
# Update batch normalization statistics
torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
final_state = {k: v.cpu().clone() for k, v in swa_model.module.state_dict().items()}
else:
final_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
return {
"history": history,
"final_state": final_state,
"epochs_trained": epochs,
}
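A worked example of the inverse-frequency weighting implemented in compute_class_weights_from_loader above (weight_c = N / (n_classes * count_c) on smoothed counts, then normalised to mean 1):

import torch

counts = torch.tensor([80.0, 15.0, 5.0])   # class frequencies, N = 100
total = counts.sum().item()
counts = counts + 0.1                       # smoothing avoids extreme weights
weights = total / (3 * counts)              # inverse frequency
weights = weights / weights.mean()          # normalise to mean 1
print(weights)                              # the rare class gets the largest weight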

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,10 @@
{
"d_model": 512,
"num_heads": 8,
"n_attn_layers": 5,
"fusion_strategy": "attention",
"head_hidden_dim": 128,
"dropout": 0.14666736316838325,
"lr": 0.0001295888795454003,
"weight_decay": 7.732380983243132e-05
}
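When a best_params.json like this is consumed by the training scripts, the architecture keys feed create_model while lr and weight_decay configure the AdamW optimizer; a minimal split sketch (the path is illustrative):

import json

with open("best_params.json") as f:
    params = json.load(f)
model_kwargs = {k: v for k, v in params.items() if k not in ("lr", "weight_decay")}
# create_model(**model_kwargs, ...) builds the network;
# params["lr"] and params["weight_decay"] configure the optimizer.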

View File

@ -0,0 +1 @@
{"epoch_mean": 13}

View File

@ -0,0 +1,122 @@
{
"train": [
{
"loss": 9.688567074862393,
"loss_size": 4.525775753638961,
"loss_pdi": 1.2179414359006016,
"loss_ee": 1.091066924008456,
"loss_delivery": 1.1003740294413134,
"loss_biodist": 1.045865768736059,
"loss_toxic": 0.7075429206544702
},
{
"loss": 4.477882081812078,
"loss_size": 0.30304074490612204,
"loss_pdi": 0.9405507716265592,
"loss_ee": 0.9982529336755926,
"loss_delivery": 1.1088530285791918,
"loss_biodist": 0.6725774082270536,
"loss_toxic": 0.45460730520161713
},
{
"loss": 3.7592489285902544,
"loss_size": 0.3136644837531177,
"loss_pdi": 0.7663570263169028,
"loss_ee": 0.9424018914049322,
"loss_delivery": 0.9587065902623263,
"loss_biodist": 0.5116219466382806,
"loss_toxic": 0.2664969251914458
},
{
"loss": 3.2862942435524682,
"loss_size": 0.25651484592394397,
"loss_pdi": 0.729118591005152,
"loss_ee": 0.8862769928845492,
"loss_delivery": 0.8572801744396036,
"loss_biodist": 0.42606663974848663,
"loss_toxic": 0.13103695366192947
},
{
"loss": 3.032014153220437,
"loss_size": 0.2605769471688704,
"loss_pdi": 0.6875471039251848,
"loss_ee": 0.8450987447391857,
"loss_delivery": 0.7942117723551664,
"loss_biodist": 0.3333369114182212,
"loss_toxic": 0.11124271154403687
},
{
"loss": 2.747604175047441,
"loss_size": 0.23051185499538074,
"loss_pdi": 0.6304752420295369,
"loss_ee": 0.8027611483227123,
"loss_delivery": 0.6954045072197914,
"loss_biodist": 0.31399699503725226,
"loss_toxic": 0.07445447621020404
},
{
"loss": 2.496532602743669,
"loss_size": 0.21782836656678806,
"loss_pdi": 0.6259297322143208,
"loss_ee": 0.7789664485237815,
"loss_delivery": 0.5751941861076788,
"loss_biodist": 0.25595076788555493,
"loss_toxic": 0.04266304531219331
},
{
"loss": 2.474623289975253,
"loss_size": 0.2284239645708691,
"loss_pdi": 0.6423451900482178,
"loss_ee": 0.7340371066873724,
"loss_delivery": 0.5827173766764727,
"loss_biodist": 0.24910570003769614,
"loss_toxic": 0.03799399394880642
},
{
"loss": 2.3557864644310693,
"loss_size": 0.197183218869296,
"loss_pdi": 0.5410721220753409,
"loss_ee": 0.7328030331568285,
"loss_delivery": 0.6107787185094573,
"loss_biodist": 0.2399107339707288,
"loss_toxic": 0.03403859864920378
},
{
"loss": 2.187597805803472,
"loss_size": 0.19923549348657782,
"loss_pdi": 0.514803715727546,
"loss_ee": 0.7146885395050049,
"loss_delivery": 0.49052428555759514,
"loss_biodist": 0.22473187744617462,
"loss_toxic": 0.043613933196121994
},
{
"loss": 2.1395224874669854,
"loss_size": 0.20194762064652008,
"loss_pdi": 0.5207281329415061,
"loss_ee": 0.6976469484242526,
"loss_delivery": 0.4825741987336766,
"loss_biodist": 0.21454968913034958,
"loss_toxic": 0.022075910629196602
},
{
"loss": 2.0708962462165137,
"loss_size": 0.18447093801064926,
"loss_pdi": 0.52463944662701,
"loss_ee": 0.6666911081834273,
"loss_delivery": 0.4482487521388314,
"loss_biodist": 0.21374160457741131,
"loss_toxic": 0.033104426227509975
},
{
"loss": 2.065277782353488,
"loss_size": 0.17790686813267795,
"loss_pdi": 0.5256167894059961,
"loss_ee": 0.6798818870024248,
"loss_delivery": 0.43902007693594153,
"loss_biodist": 0.21479454501108688,
"loss_toxic": 0.02805758648636666
}
],
"val": []
}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:04f3103a630de23aeb971629aa769e1a9cdb3c5247c193eefb808c6a4e17b9cc
size 133133308
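Note: this model.pt is stored via Git LFS, so only the three-line pointer above is committed; per its size field the actual weights are roughly 127 MB, and a plain clone needs `git lfs pull` to materialise them.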

View File

@ -0,0 +1 @@
{"outer_train_idx": [0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 32, 33, 34, 35, 37, 38, 39, 40, 41, 43, 44, 45, 46, 47, 48, 49, 50, 52, 53, 56, 57, 58, 59, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 73, 74, 75, 77, 78, 79, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 102, 103, 104, 105, 106, 107, 108, 109, 110, 112, 113, 114, 115, 117, 118, 119, 121, 123, 126, 127, 128, 129, 131, 132, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 146, 147, 148, 149, 150, 152, 153, 155, 156, 158, 159, 163, 167, 169, 170, 171, 172, 173, 175, 176, 177, 178, 179, 180, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 207, 208, 211, 214, 215, 216, 217, 218, 219, 221, 222, 223, 224, 227, 228, 229, 230, 231, 233, 234, 236, 237, 238, 239, 243, 244, 245, 247, 248, 249, 251, 252, 255, 256, 257, 258, 260, 261, 262, 263, 264, 265, 266, 268, 269, 270, 271, 272, 273, 274, 275, 276, 279, 282, 283, 284, 286, 287, 288, 290, 291, 292, 293, 294, 295, 297, 298, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 319, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 335, 337, 338, 339, 340, 341, 343, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 358, 359, 360, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 379, 380, 381, 382, 383, 384, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 398, 399, 400, 402, 403, 404, 408, 409, 410, 411, 412, 413, 415, 416, 417, 418, 419, 420, 421, 422, 423, 425, 426, 428, 429, 432, 433], "outer_test_idx": [7, 12, 13, 25, 31, 36, 42, 51, 54, 55, 60, 72, 76, 80, 101, 111, 116, 120, 122, 124, 125, 130, 133, 145, 151, 154, 157, 160, 161, 162, 164, 165, 166, 168, 174, 181, 206, 209, 210, 212, 213, 220, 225, 226, 232, 235, 240, 241, 242, 246, 250, 253, 254, 259, 267, 277, 278, 280, 281, 285, 289, 296, 299, 300, 318, 320, 321, 334, 336, 342, 344, 356, 357, 361, 377, 378, 385, 397, 401, 405, 406, 407, 414, 424, 427, 430, 431]}

View File

@ -0,0 +1,42 @@
{
"size": {
"n_samples": 87,
"mse": 0.26366087871128757,
"rmse": 0.5134791901443403,
"mae": 0.25157783223294666,
"r2": 0.21208517006410577
},
"delivery": {
"n_samples": 61,
"mse": 0.40443344562739025,
"rmse": 0.63595082013265,
"mae": 0.3928790920429298,
"r2": 0.2300258531372983
},
"pdi": {
"n_samples": 87,
"accuracy": 0.7241379310344828,
"precision": 0.35141509433962265,
"recall": 0.35351966873706003,
"f1": 0.348405985686402
},
"ee": {
"n_samples": 87,
"accuracy": 0.6666666666666666,
"precision": 0.6188811188811189,
"recall": 0.6375291375291375,
"f1": 0.6217948717948718
},
"toxic": {
"n_samples": 62,
"accuracy": 0.967741935483871,
"precision": 0.8,
"recall": 0.9830508474576272,
"f1": 0.8663793103448275
},
"biodist": {
"n_samples": 61,
"kl_divergence": 0.14776465556036145,
"js_divergence": 0.03926150329301917
}
}

View File

@ -0,0 +1,10 @@
{
"d_model": 512,
"num_heads": 4,
"n_attn_layers": 2,
"fusion_strategy": "avg",
"head_hidden_dim": 64,
"dropout": 0.05188345993471756,
"lr": 4.21188892865021e-05,
"weight_decay": 4.086499445232577e-05
}

View File

@ -0,0 +1 @@
{"epoch_mean": 23}

View File

@ -0,0 +1,212 @@
{
"train": [
{
"loss": 19.564044865694914,
"loss_size": 14.219423033974387,
"loss_pdi": 1.3700515248558738,
"loss_ee": 1.0948690609498457,
"loss_delivery": 0.925118706443093,
"loss_biodist": 1.2897993542931296,
"loss_toxic": 0.6647828817367554
},
{
"loss": 9.569591695612127,
"loss_size": 4.443219217387113,
"loss_pdi": 1.3034031066027554,
"loss_ee": 1.0691871643066406,
"loss_delivery": 0.986666976050897,
"loss_biodist": 1.1807570999318904,
"loss_toxic": 0.5863581760363146
},
{
"loss": 5.09315148266879,
"loss_size": 0.4028705805540085,
"loss_pdi": 1.2596685236150569,
"loss_ee": 1.0527758598327637,
"loss_delivery": 0.8668525652451948,
"loss_biodist": 1.0210744521834634,
"loss_toxic": 0.48990951072085986
},
{
"loss": 4.664039200002497,
"loss_size": 0.26930826157331467,
"loss_pdi": 1.1607397361235186,
"loss_ee": 1.0082263296300715,
"loss_delivery": 0.9969787990505045,
"loss_biodist": 0.8467015407302163,
"loss_toxic": 0.3820846405896274
},
{
"loss": 4.224615703929555,
"loss_size": 0.28240104629234836,
"loss_pdi": 1.048582375049591,
"loss_ee": 0.9691753116520968,
"loss_delivery": 0.8995198593898253,
"loss_biodist": 0.7173566547307101,
"loss_toxic": 0.3075804940678857
},
{
"loss": 3.810542041605169,
"loss_size": 0.23119159516963092,
"loss_pdi": 0.9792777408253063,
"loss_ee": 0.9330252029679038,
"loss_delivery": 0.8217983476140283,
"loss_biodist": 0.6248252906582572,
"loss_toxic": 0.22042379054156216
},
{
"loss": 3.5365114862268623,
"loss_size": 0.21511441333727402,
"loss_pdi": 0.9201341759074818,
"loss_ee": 0.9007513360543684,
"loss_delivery": 0.8033261312679811,
"loss_biodist": 0.5451555441726338,
"loss_toxic": 0.15202991325746884
},
{
"loss": 3.405807386745106,
"loss_size": 0.20161619240587408,
"loss_pdi": 0.8445044376633384,
"loss_ee": 0.8931786905635487,
"loss_delivery": 0.831847377798774,
"loss_biodist": 0.49607704173434863,
"loss_toxic": 0.13858359239318155
},
{
"loss": 3.103480577468872,
"loss_size": 0.18636523119427942,
"loss_pdi": 0.8065886064009233,
"loss_ee": 0.8574360067194159,
"loss_delivery": 0.6954484595493837,
"loss_biodist": 0.45561750639568677,
"loss_toxic": 0.10202483257109468
},
{
"loss": 3.0312788052992388,
"loss_size": 0.19699390774423425,
"loss_pdi": 0.7626179348338734,
"loss_ee": 0.8408515670082786,
"loss_delivery": 0.7123599113388495,
"loss_biodist": 0.41937256130305206,
"loss_toxic": 0.09908294406804172
},
{
"loss": 2.8196144104003906,
"loss_size": 0.1769183948636055,
"loss_pdi": 0.717121189290827,
"loss_ee": 0.8330178802663629,
"loss_delivery": 0.6163532761010256,
"loss_biodist": 0.3913883079182018,
"loss_toxic": 0.08481534672054378
},
{
"loss": 2.7494440295479516,
"loss_size": 0.17335935140197928,
"loss_pdi": 0.7235767028548501,
"loss_ee": 0.812805939804424,
"loss_delivery": 0.5859575948931954,
"loss_biodist": 0.38201813806187024,
"loss_toxic": 0.07172633301128041
},
{
"loss": 2.6807472705841064,
"loss_size": 0.20867897028272803,
"loss_pdi": 0.6913418119603937,
"loss_ee": 0.8103034550493414,
"loss_delivery": 0.5472286993807013,
"loss_biodist": 0.34767021103338763,
"loss_toxic": 0.07552417537028139
},
{
"loss": 2.5547089793465356,
"loss_size": 0.17158158326690848,
"loss_pdi": 0.6350900205698881,
"loss_ee": 0.7878076163205233,
"loss_delivery": 0.5484779531305487,
"loss_biodist": 0.34313417022878473,
"loss_toxic": 0.06861767376011069
},
{
"loss": 2.557967185974121,
"loss_size": 0.14963022822683508,
"loss_pdi": 0.6706653941761364,
"loss_ee": 0.7864666526967828,
"loss_delivery": 0.5519421967593107,
"loss_biodist": 0.3387457403269681,
"loss_toxic": 0.06051696803082119
},
{
"loss": 2.5365889939394863,
"loss_size": 0.17580546641891653,
"loss_pdi": 0.6637366847558455,
"loss_ee": 0.7831195484508168,
"loss_delivery": 0.5293790704824708,
"loss_biodist": 0.3297005011276765,
"loss_toxic": 0.054847732186317444
},
{
"loss": 2.51393402706493,
"loss_size": 0.1777868609536778,
"loss_pdi": 0.6462894515557722,
"loss_ee": 0.7651460604234175,
"loss_delivery": 0.5493429905988954,
"loss_biodist": 0.317649554122578,
"loss_toxic": 0.05771913768892938
},
{
"loss": 2.425347761674361,
"loss_size": 0.14695054224946283,
"loss_pdi": 0.6508165923031893,
"loss_ee": 0.7596850720318881,
"loss_delivery": 0.49562479284676636,
"loss_biodist": 0.31901306862180884,
"loss_toxic": 0.053257653659040276
},
{
"loss": 2.450036742470481,
"loss_size": 0.15424400500275873,
"loss_pdi": 0.655444394458424,
"loss_ee": 0.7591585137627341,
"loss_delivery": 0.5164471390572462,
"loss_biodist": 0.3094088543545116,
"loss_toxic": 0.05533381686969237
},
{
"loss": 2.3860875476490366,
"loss_size": 0.1556895158507607,
"loss_pdi": 0.6400747786868702,
"loss_ee": 0.7600220929492604,
"loss_delivery": 0.46673478253863077,
"loss_biodist": 0.3111781979149038,
"loss_toxic": 0.05238820222968405
},
{
"loss": 2.3954180804165928,
"loss_size": 0.1462622725150802,
"loss_pdi": 0.619796259836717,
"loss_ee": 0.7609411152926359,
"loss_delivery": 0.5080444128675894,
"loss_biodist": 0.30634408511898736,
"loss_toxic": 0.05402998389168219
},
{
"loss": 2.3838103034279565,
"loss_size": 0.13035081394694067,
"loss_pdi": 0.6347457062114369,
"loss_ee": 0.7560921203006398,
"loss_delivery": 0.4973590536551042,
"loss_biodist": 0.3073451383547349,
"loss_toxic": 0.05791747265241363
},
{
"loss": 2.4303664727644487,
"loss_size": 0.1789045144211162,
"loss_pdi": 0.631924033164978,
"loss_ee": 0.7499593008648265,
"loss_delivery": 0.502868037332188,
"loss_biodist": 0.3101512383330952,
"loss_toxic": 0.05655933002179319
}
],
"val": []
}
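Note that in these logs each epoch's total loss is the plain sum of the six per-task losses: for the first epoch above, 14.2194 + 1.3701 + 1.0949 + 0.9251 + 1.2898 + 0.6648 ≈ 19.5640. A minimal aggregation sketch, assuming the per-task losses arrive as precomputed scalar tensors:

import torch
from typing import Dict

def total_loss(task_losses: Dict[str, torch.Tensor]) -> torch.Tensor:
    # Unweighted sum over the logged heads:
    # size, pdi, ee, delivery, biodist, toxic.
    return torch.stack(list(task_losses.values())).sum()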

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b1c9190b09f7ab586a7d5190b74452f4be67c5d2fb753aae202f9b723523ef00
size 55606866

View File

@ -0,0 +1 @@
{"outer_train_idx": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 30, 31, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 47, 49, 50, 51, 52, 54, 55, 56, 57, 58, 59, 60, 61, 62, 64, 65, 70, 71, 72, 73, 74, 75, 76, 77, 79, 80, 81, 82, 84, 87, 89, 90, 91, 92, 93, 96, 98, 99, 101, 102, 103, 105, 106, 108, 111, 112, 114, 115, 116, 119, 120, 121, 122, 123, 124, 125, 126, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 148, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 164, 165, 166, 167, 168, 171, 172, 173, 174, 175, 176, 177, 178, 179, 181, 185, 186, 187, 190, 191, 192, 196, 198, 203, 204, 206, 207, 208, 209, 210, 211, 212, 213, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 246, 248, 249, 250, 251, 253, 254, 255, 256, 257, 259, 260, 261, 262, 264, 265, 266, 267, 268, 269, 270, 271, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 295, 296, 298, 299, 300, 301, 302, 303, 304, 306, 307, 308, 309, 310, 312, 313, 314, 316, 317, 318, 319, 320, 321, 322, 323, 324, 326, 327, 329, 330, 331, 332, 334, 335, 336, 337, 340, 341, 342, 344, 346, 347, 348, 349, 350, 351, 352, 353, 354, 356, 357, 359, 360, 361, 362, 363, 364, 365, 366, 369, 370, 372, 374, 376, 377, 378, 380, 381, 382, 384, 385, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 421, 422, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433], "outer_test_idx": [29, 32, 35, 46, 48, 53, 63, 66, 67, 68, 69, 78, 83, 85, 86, 88, 94, 95, 97, 100, 104, 107, 109, 110, 113, 117, 118, 127, 128, 146, 147, 149, 150, 163, 169, 170, 180, 182, 183, 184, 188, 189, 193, 194, 195, 197, 199, 200, 201, 202, 205, 214, 228, 229, 245, 247, 252, 258, 263, 272, 293, 294, 297, 305, 311, 315, 325, 328, 333, 338, 339, 343, 345, 355, 358, 367, 368, 371, 373, 375, 379, 383, 386, 408, 409, 420, 423]}

View File

@ -0,0 +1,42 @@
{
"size": {
"n_samples": 87,
"mse": 0.20102046206409732,
"rmse": 0.44835305515196094,
"mae": 0.2436335881551107,
"r2": 0.1900957049146058
},
"delivery": {
"n_samples": 62,
"mse": 0.5899936276993041,
"rmse": 0.7681104267612203,
"mae": 0.46896366539576484,
"r2": 0.4017743543115936
},
"pdi": {
"n_samples": 87,
"accuracy": 0.6666666666666666,
"precision": 0.41556437389770723,
"recall": 0.4042119565217391,
"f1": 0.40777777777777774
},
"ee": {
"n_samples": 87,
"accuracy": 0.6666666666666666,
"precision": 0.6414141414141414,
"recall": 0.6782661782661782,
"f1": 0.6387485970819304
},
"toxic": {
"n_samples": 62,
"accuracy": 1.0,
"precision": 1.0,
"recall": 1.0,
"f1": 1.0
},
"biodist": {
"n_samples": 62,
"kl_divergence": 0.30758161166563297,
"js_divergence": 0.08759221465023677
}
}

View File

@ -0,0 +1,10 @@
{
"d_model": 512,
"num_heads": 8,
"n_attn_layers": 4,
"fusion_strategy": "attention",
"head_hidden_dim": 64,
"dropout": 0.11433976976282646,
"lr": 5.3015812445144804e-05,
"weight_decay": 6.704431817743382e-06
}

View File

@ -0,0 +1 @@
{"epoch_mean": 18}

View File

@ -0,0 +1,167 @@
{
"train": [
{
"loss": 16.328427488153633,
"loss_size": 10.961490392684937,
"loss_pdi": 1.3189676566557451,
"loss_ee": 1.0652603994716296,
"loss_delivery": 1.093609056012197,
"loss_biodist": 1.2020217722112483,
"loss_toxic": 0.6870784759521484
},
{
"loss": 5.557661620053378,
"loss_size": 0.6549029702490027,
"loss_pdi": 1.2172227664427324,
"loss_ee": 1.0503052473068237,
"loss_delivery": 0.9982107078487222,
"loss_biodist": 1.0446247458457947,
"loss_toxic": 0.5923952026800676
},
{
"loss": 4.66534686088562,
"loss_size": 0.40050424906340515,
"loss_pdi": 0.9867187413302335,
"loss_ee": 0.9837820909240029,
"loss_delivery": 1.079543113708496,
"loss_biodist": 0.7991420409896157,
"loss_toxic": 0.41565650972453033
},
{
"loss": 4.042153250087392,
"loss_size": 0.39615295827388763,
"loss_pdi": 0.8758815581148321,
"loss_ee": 0.9381269162351434,
"loss_delivery": 0.9179368994452737,
"loss_biodist": 0.6232685229995034,
"loss_toxic": 0.2907863679257306
},
{
"loss": 3.5052005377682773,
"loss_size": 0.3407063511284915,
"loss_pdi": 0.8075345104390924,
"loss_ee": 0.8666415322910656,
"loss_delivery": 0.821596086025238,
"loss_biodist": 0.49901550195433875,
"loss_toxic": 0.16970657895911823
},
{
"loss": 3.227808865633878,
"loss_size": 0.3172220086509531,
"loss_pdi": 0.7554353800686923,
"loss_ee": 0.8366058685562827,
"loss_delivery": 0.7592829249121926,
"loss_biodist": 0.41754513708027924,
"loss_toxic": 0.14171750030734323
},
{
"loss": 3.015385866165161,
"loss_size": 0.3014552891254425,
"loss_pdi": 0.7061856659975919,
"loss_ee": 0.8024854118173773,
"loss_delivery": 0.6936024874448776,
"loss_biodist": 0.37025675448504364,
"loss_toxic": 0.1414001943035559
},
{
"loss": 2.8741718639026987,
"loss_size": 0.30145836960185657,
"loss_pdi": 0.6485596244985407,
"loss_ee": 0.7874462875452909,
"loss_delivery": 0.7069157646460966,
"loss_biodist": 0.31511187282475556,
"loss_toxic": 0.11467998136173595
},
{
"loss": 2.7043218179182573,
"loss_size": 0.30878364227034827,
"loss_pdi": 0.6446920850060203,
"loss_ee": 0.763309196992354,
"loss_delivery": 0.6044047027826309,
"loss_biodist": 0.2966968539086255,
"loss_toxic": 0.08643539994955063
},
{
"loss": 2.592982042919506,
"loss_size": 0.301474930210547,
"loss_pdi": 0.586654855446382,
"loss_ee": 0.7498630989681591,
"loss_delivery": 0.5886639884927056,
"loss_biodist": 0.2792905040762641,
"loss_toxic": 0.08703470907428047
},
{
"loss": 2.475934158671986,
"loss_size": 0.3227222941138528,
"loss_pdi": 0.5525102777914568,
"loss_ee": 0.7191228595646945,
"loss_delivery": 0.5444187271324071,
"loss_biodist": 0.254928475076502,
"loss_toxic": 0.08223151076923717
},
{
"loss": 2.4397390105507593,
"loss_size": 0.28487288003618066,
"loss_pdi": 0.5408996912566099,
"loss_ee": 0.7313735810193148,
"loss_delivery": 0.5359987101771615,
"loss_biodist": 0.2598937614397569,
"loss_toxic": 0.08670033531432803
},
{
"loss": 2.454329328103499,
"loss_size": 0.2546617639335719,
"loss_pdi": 0.549032907594334,
"loss_ee": 0.723411272872578,
"loss_delivery": 0.6058986783027649,
"loss_biodist": 0.2590403082695874,
"loss_toxic": 0.06228439848531376
},
{
"loss": 2.459202441302213,
"loss_size": 0.30333411490375345,
"loss_pdi": 0.5448255024173043,
"loss_ee": 0.7321870706298135,
"loss_delivery": 0.5674931427294557,
"loss_biodist": 0.2446274757385254,
"loss_toxic": 0.06673513505269181
},
{
"loss": 2.4072633764960547,
"loss_size": 0.3075956946069544,
"loss_pdi": 0.5372236939993772,
"loss_ee": 0.6942113583738153,
"loss_delivery": 0.541759736158631,
"loss_biodist": 0.2502128630876541,
"loss_toxic": 0.07625993527472019
},
{
"loss": 2.3814159631729126,
"loss_size": 0.2884528948502107,
"loss_pdi": 0.5177338475530798,
"loss_ee": 0.7055766040628607,
"loss_delivery": 0.564139807766134,
"loss_biodist": 0.24335274642164056,
"loss_toxic": 0.062160092321309174
},
{
"loss": 2.288854566487399,
"loss_size": 0.2864968058737842,
"loss_pdi": 0.5151266103441065,
"loss_ee": 0.6858331777832725,
"loss_delivery": 0.492813152345744,
"loss_biodist": 0.23307398774407126,
"loss_toxic": 0.07551077617840334
},
{
"loss": 2.308527339588512,
"loss_size": 0.3206997500224547,
"loss_pdi": 0.48468891328031366,
"loss_ee": 0.6778702302412554,
"loss_delivery": 0.5191753757270899,
"loss_biodist": 0.2486932264132933,
"loss_toxic": 0.05739984508942474
}
],
"val": []
}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6285ba67132f7fff122c780ba4a63a2f52ed2b8a8775342015c1ccf00e2351e9
size 107113964

View File

@ -0,0 +1 @@
{"outer_train_idx": [0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 16, 18, 20, 21, 22, 23, 24, 25, 27, 29, 31, 32, 35, 36, 39, 40, 41, 42, 43, 45, 46, 47, 48, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 63, 65, 66, 67, 68, 69, 70, 71, 72, 74, 75, 76, 77, 78, 80, 81, 83, 84, 85, 86, 87, 88, 89, 91, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 133, 134, 135, 137, 138, 139, 140, 141, 142, 143, 145, 146, 147, 148, 149, 150, 151, 153, 154, 155, 157, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 199, 200, 201, 202, 203, 204, 205, 206, 207, 209, 210, 212, 213, 214, 215, 216, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 234, 235, 239, 240, 241, 242, 243, 244, 245, 246, 247, 249, 250, 252, 253, 254, 255, 258, 259, 260, 262, 263, 264, 265, 267, 268, 272, 274, 275, 276, 277, 278, 280, 281, 283, 284, 285, 286, 287, 289, 290, 291, 293, 294, 296, 297, 298, 299, 300, 301, 305, 308, 309, 310, 311, 313, 314, 315, 317, 318, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 342, 343, 344, 345, 348, 349, 352, 353, 355, 356, 357, 358, 359, 360, 361, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 375, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 392, 395, 396, 397, 399, 400, 401, 403, 404, 405, 406, 407, 408, 409, 410, 412, 413, 414, 415, 416, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431], "outer_test_idx": [3, 14, 15, 17, 19, 26, 28, 30, 33, 34, 37, 38, 44, 49, 50, 62, 64, 73, 79, 82, 90, 92, 93, 105, 106, 121, 132, 136, 144, 152, 156, 158, 173, 187, 198, 208, 211, 217, 218, 233, 236, 237, 238, 248, 251, 256, 257, 261, 266, 269, 270, 271, 273, 279, 282, 288, 292, 295, 302, 303, 304, 306, 307, 312, 316, 319, 340, 341, 346, 347, 350, 351, 354, 362, 374, 376, 390, 391, 393, 394, 398, 402, 411, 417, 418, 432, 433]}

View File

@ -0,0 +1,42 @@
{
"size": {
"n_samples": 85,
"mse": 0.07855156229299931,
"rmse": 0.28027051627490057,
"mae": 0.2201253890991211,
"r2": 0.009407939412061861
},
"delivery": {
"n_samples": 62,
"mse": 0.4162270771403472,
"rmse": 0.645156629928227,
"mae": 0.41306305523856635,
"r2": 0.37384819758564136
},
"pdi": {
"n_samples": 87,
"accuracy": 0.7126436781609196,
"precision": 0.3963383838383838,
"recall": 0.5856060606060606,
"f1": 0.43013891331106036
},
"ee": {
"n_samples": 87,
"accuracy": 0.6781609195402298,
"precision": 0.6215366001209922,
"recall": 0.65781362712309,
"f1": 0.6235867752721687
},
"toxic": {
"n_samples": 62,
"accuracy": 0.967741935483871,
"precision": 0.8,
"recall": 0.9830508474576272,
"f1": 0.8663793103448275
},
"biodist": {
"n_samples": 62,
"kl_divergence": 0.3189032824921648,
"js_divergence": 0.07944133611635379
}
}

View File

@ -0,0 +1,10 @@
{
"d_model": 512,
"num_heads": 4,
"n_attn_layers": 5,
"fusion_strategy": "attention",
"head_hidden_dim": 64,
"dropout": 0.11746271741188277,
"lr": 0.0001939220403760229,
"weight_decay": 2.4722550292920085e-06
}

View File

@ -0,0 +1 @@
{"epoch_mean": 10}

View File

@ -0,0 +1,95 @@
{
"train": [
{
"loss": 9.790851376273416,
"loss_size": 4.886490476402369,
"loss_pdi": 1.3631455573168667,
"loss_ee": 1.088216781616211,
"loss_delivery": 0.7040194340727546,
"loss_biodist": 1.1505104249173945,
"loss_toxic": 0.5984687263315375
},
{
"loss": 4.359956351193515,
"loss_size": 0.43619183789600025,
"loss_pdi": 1.1665386232462795,
"loss_ee": 0.9901693842627786,
"loss_delivery": 0.6073603630065918,
"loss_biodist": 0.7433811046860435,
"loss_toxic": 0.41631498661908234
},
{
"loss": 3.756383180618286,
"loss_size": 0.4402667825872248,
"loss_pdi": 1.0114682045849888,
"loss_ee": 0.9900745803659613,
"loss_delivery": 0.5448389879681848,
"loss_biodist": 0.5919120474295183,
"loss_toxic": 0.17782256481322375
},
{
"loss": 3.4846310182051226,
"loss_size": 0.46992606737396936,
"loss_pdi": 1.000607517632571,
"loss_ee": 0.8866635181687095,
"loss_delivery": 0.5218790579925884,
"loss_biodist": 0.47943955388936127,
"loss_toxic": 0.12611534412611614
},
{
"loss": 3.0496469844471323,
"loss_size": 0.36307603120803833,
"loss_pdi": 0.8360648371956565,
"loss_ee": 0.8720264759930697,
"loss_delivery": 0.45130031081763183,
"loss_biodist": 0.4275714944709431,
"loss_toxic": 0.09960775822401047
},
{
"loss": 2.8820750496604224,
"loss_size": 0.35799529267983005,
"loss_pdi": 0.7969891158017245,
"loss_ee": 0.8362645967440172,
"loss_delivery": 0.4626781473105604,
"loss_biodist": 0.358451091430404,
"loss_toxic": 0.0696967878294262
},
{
"loss": 2.645532087846236,
"loss_size": 0.3313806341453032,
"loss_pdi": 0.7585093595764854,
"loss_ee": 0.797249972820282,
"loss_delivery": 0.38693068108775397,
"loss_biodist": 0.29899653656916186,
"loss_toxic": 0.07246483176607978
},
{
"loss": 2.5514606779271904,
"loss_size": 0.33785366063768213,
"loss_pdi": 0.710435076193376,
"loss_ee": 0.7790363105860624,
"loss_delivery": 0.3592266860333356,
"loss_biodist": 0.2742794535376809,
"loss_toxic": 0.09062948763709176
},
{
"loss": 2.4680945114655928,
"loss_size": 0.3564724339680238,
"loss_pdi": 0.6821299493312836,
"loss_ee": 0.7459230910647999,
"loss_delivery": 0.36073175072669983,
"loss_biodist": 0.26153207096186554,
"loss_toxic": 0.061305238187990406
},
{
"loss": 2.3807498541745273,
"loss_size": 0.29239929941567505,
"loss_pdi": 0.7071054875850677,
"loss_ee": 0.7487604834816672,
"loss_delivery": 0.34709482369097794,
"loss_biodist": 0.2424463846466758,
"loss_toxic": 0.04294342412190004
}
],
"val": []
}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:93cd6e31c539b5604fa03b14699fc985cbee10dd8cbbb05f6cebc43f65e394da
size 132340732

View File

@ -0,0 +1 @@
{"outer_train_idx": [0, 1, 2, 3, 7, 8, 12, 13, 14, 15, 17, 19, 20, 21, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 41, 42, 43, 44, 46, 48, 49, 50, 51, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 71, 72, 73, 74, 76, 78, 79, 80, 82, 83, 85, 86, 88, 89, 90, 91, 92, 93, 94, 95, 97, 99, 100, 101, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 127, 128, 129, 130, 132, 133, 134, 136, 138, 140, 143, 144, 145, 146, 147, 149, 150, 151, 152, 153, 154, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 173, 174, 175, 177, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 197, 198, 199, 200, 201, 202, 203, 205, 206, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 220, 222, 225, 226, 227, 228, 229, 231, 232, 233, 235, 236, 237, 238, 239, 240, 241, 242, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 265, 266, 267, 269, 270, 271, 272, 273, 276, 277, 278, 279, 280, 281, 282, 284, 285, 286, 288, 289, 292, 293, 294, 295, 296, 297, 299, 300, 301, 302, 303, 304, 305, 306, 307, 309, 310, 311, 312, 314, 315, 316, 318, 319, 320, 321, 324, 325, 328, 329, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 361, 362, 363, 365, 366, 367, 368, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 383, 384, 385, 386, 387, 389, 390, 391, 393, 394, 396, 397, 398, 401, 402, 405, 406, 407, 408, 409, 410, 411, 413, 414, 415, 416, 417, 418, 420, 421, 423, 424, 425, 427, 428, 430, 431, 432, 433], "outer_test_idx": [4, 5, 6, 9, 10, 11, 16, 18, 22, 23, 24, 39, 40, 45, 47, 52, 70, 75, 77, 81, 84, 87, 96, 98, 102, 114, 126, 131, 135, 137, 139, 141, 142, 148, 155, 172, 176, 178, 179, 196, 204, 207, 219, 221, 223, 224, 230, 234, 243, 244, 264, 268, 274, 275, 283, 287, 290, 291, 298, 308, 313, 317, 322, 323, 326, 327, 330, 331, 332, 349, 360, 364, 369, 370, 382, 388, 392, 395, 399, 400, 403, 404, 412, 419, 422, 426, 429]}

View File

@ -0,0 +1,42 @@
{
"size": {
"n_samples": 87,
"mse": 0.08391054707837464,
"rmse": 0.28967317286620564,
"mae": 0.22691357272794876,
"r2": 0.26719931627457305
},
"delivery": {
"n_samples": 62,
"mse": 2.0453160934060053,
"rmse": 1.4301454798047664,
"mae": 0.5443450972558029,
"r2": 0.0777919381248442
},
"pdi": {
"n_samples": 87,
"accuracy": 0.6896551724137931,
"precision": 0.3994245524296675,
"recall": 0.5978021978021978,
"f1": 0.417895167895168
},
"ee": {
"n_samples": 87,
"accuracy": 0.5862068965517241,
"precision": 0.5469462969462969,
"recall": 0.5874125874125874,
"f1": 0.5375569894616209
},
"toxic": {
"n_samples": 63,
"accuracy": 0.9682539682539683,
"precision": 0.8,
"recall": 0.9833333333333334,
"f1": 0.8665254237288135
},
"biodist": {
"n_samples": 63,
"kl_divergence": 0.28367776789683485,
"js_divergence": 0.07318286384043993
}
}

View File

@ -0,0 +1,10 @@
{
"d_model": 128,
"num_heads": 4,
"n_attn_layers": 6,
"fusion_strategy": "max",
"head_hidden_dim": 128,
"dropout": 0.15658744776638445,
"lr": 0.00031005155898680676,
"weight_decay": 5.422040924441196e-05
}

View File

@ -0,0 +1 @@
{"epoch_mean": 14}

View File

@ -0,0 +1,131 @@
{
"train": [
{
"loss": 15.068541526794434,
"loss_size": 9.661331350153143,
"loss_pdi": 1.3961028402501887,
"loss_ee": 1.1041727607900447,
"loss_delivery": 1.131193908778104,
"loss_biodist": 1.1299291415648027,
"loss_toxic": 0.6458113518628207
},
{
"loss": 5.692212278192693,
"loss_size": 0.9680004715919495,
"loss_pdi": 1.2202669165351174,
"loss_ee": 1.0751874880357222,
"loss_delivery": 1.075642000545155,
"loss_biodist": 0.8836204951459711,
"loss_toxic": 0.46949480880390515
},
{
"loss": 4.2831949970938945,
"loss_size": 0.4278720129619945,
"loss_pdi": 1.0048133405772122,
"loss_ee": 0.9688023762269453,
"loss_delivery": 0.9665126570246436,
"loss_biodist": 0.649872666055506,
"loss_toxic": 0.2653219618580558
},
{
"loss": 3.620699340646917,
"loss_size": 0.34346921877427533,
"loss_pdi": 0.909794731573625,
"loss_ee": 0.8586091019890525,
"loss_delivery": 0.8243018510666761,
"loss_biodist": 0.5371287275444377,
"loss_toxic": 0.14739569039507347
},
{
"loss": 3.2791861187327993,
"loss_size": 0.3260936520316384,
"loss_pdi": 0.8282042199915106,
"loss_ee": 0.8434794707731768,
"loss_delivery": 0.7491147694262591,
"loss_biodist": 0.4268476908857172,
"loss_toxic": 0.10544633001766422
},
{
"loss": 3.000173742120916,
"loss_size": 0.3082492568276145,
"loss_pdi": 0.6911327242851257,
"loss_ee": 0.7926767305894331,
"loss_delivery": 0.7263242087580941,
"loss_biodist": 0.3646425713192333,
"loss_toxic": 0.11714824627746236
},
{
"loss": 2.8449384082447398,
"loss_size": 0.3491906361146407,
"loss_pdi": 0.7018052475018934,
"loss_ee": 0.7752535722472451,
"loss_delivery": 0.6207486174323342,
"loss_biodist": 0.3087055358019742,
"loss_toxic": 0.08923475528982552
},
{
"loss": 2.762920791452581,
"loss_size": 0.27234369922767987,
"loss_pdi": 0.6848637001080946,
"loss_ee": 0.7332382093776356,
"loss_delivery": 0.7119220888072794,
"loss_biodist": 0.27469146658073773,
"loss_toxic": 0.08586158154701645
},
{
"loss": 2.6228579716248945,
"loss_size": 0.28220029175281525,
"loss_pdi": 0.6514047113331881,
"loss_ee": 0.7080714106559753,
"loss_delivery": 0.6849517280405218,
"loss_biodist": 0.23796464096416126,
"loss_toxic": 0.05826510641385208
},
{
"loss": 2.53473145311529,
"loss_size": 0.24789857864379883,
"loss_pdi": 0.6621171263131228,
"loss_ee": 0.7045791528441689,
"loss_delivery": 0.6228310723196376,
"loss_biodist": 0.24999592927369205,
"loss_toxic": 0.047309636138379574
},
{
"loss": 2.439385262402621,
"loss_size": 0.2714946825395931,
"loss_pdi": 0.6569395525888964,
"loss_ee": 0.679518764669245,
"loss_delivery": 0.5314246158708226,
"loss_biodist": 0.23756521533836017,
"loss_toxic": 0.062442497265609825
},
{
"loss": 2.4029004248705776,
"loss_size": 0.25774127651344647,
"loss_pdi": 0.5974260785362937,
"loss_ee": 0.6959952928803184,
"loss_delivery": 0.5820360441099514,
"loss_biodist": 0.22288588908585635,
"loss_toxic": 0.04681587845764377
},
{
"loss": 2.390383416956121,
"loss_size": 0.24576269225640732,
"loss_pdi": 0.6070916679772463,
"loss_ee": 0.6997986544262279,
"loss_delivery": 0.5448949479243972,
"loss_biodist": 0.2274868596683849,
"loss_toxic": 0.065348598936742
},
{
"loss": 2.450988466089422,
"loss_size": 0.21830374882979828,
"loss_pdi": 0.6097230477766558,
"loss_ee": 0.6919564171270891,
"loss_delivery": 0.6387598026882518,
"loss_biodist": 0.22940932078794998,
"loss_toxic": 0.06283610728992657
}
],
"val": []
}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cd138dcfe3697ec06e8caa6ec6c35a16635b1eb5a790b07e5c49eaf638dd9379
size 11112658

View File

@ -0,0 +1 @@
{"outer_train_idx": [3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 42, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 60, 62, 63, 64, 66, 67, 68, 69, 70, 72, 73, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 90, 92, 93, 94, 95, 96, 97, 98, 100, 101, 102, 104, 105, 106, 107, 109, 110, 111, 113, 114, 116, 117, 118, 120, 121, 122, 124, 125, 126, 127, 128, 130, 131, 132, 133, 135, 136, 137, 139, 141, 142, 144, 145, 146, 147, 148, 149, 150, 151, 152, 154, 155, 156, 157, 158, 160, 161, 162, 163, 164, 165, 166, 168, 169, 170, 172, 173, 174, 176, 178, 179, 180, 181, 182, 183, 184, 187, 188, 189, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 217, 218, 219, 220, 221, 223, 224, 225, 226, 228, 229, 230, 232, 233, 234, 235, 236, 237, 238, 240, 241, 242, 243, 244, 245, 246, 247, 248, 250, 251, 252, 253, 254, 256, 257, 258, 259, 261, 263, 264, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 277, 278, 279, 280, 281, 282, 283, 285, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 302, 303, 304, 305, 306, 307, 308, 311, 312, 313, 315, 316, 317, 318, 319, 320, 321, 322, 323, 325, 326, 327, 328, 330, 331, 332, 333, 334, 336, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 349, 350, 351, 354, 355, 356, 357, 358, 360, 361, 362, 364, 367, 368, 369, 370, 371, 373, 374, 375, 376, 377, 378, 379, 382, 383, 385, 386, 388, 390, 391, 392, 393, 394, 395, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 411, 412, 414, 417, 418, 419, 420, 422, 423, 424, 426, 427, 429, 430, 431, 432, 433], "outer_test_idx": [0, 1, 2, 8, 20, 21, 27, 41, 43, 56, 57, 58, 59, 61, 65, 71, 74, 89, 91, 99, 103, 108, 112, 115, 119, 123, 129, 134, 138, 140, 143, 153, 159, 167, 171, 175, 177, 185, 186, 190, 191, 192, 203, 215, 216, 222, 227, 231, 239, 249, 255, 260, 262, 265, 276, 284, 286, 301, 309, 310, 314, 324, 329, 335, 337, 348, 352, 353, 359, 363, 365, 366, 372, 380, 381, 384, 387, 389, 396, 410, 413, 415, 416, 421, 425, 428]}

View File

@ -0,0 +1,42 @@
{
"size": {
"n_samples": 86,
"mse": 0.17175675509369362,
"rmse": 0.41443546553558086,
"mae": 0.2695355332174966,
"r2": -0.28030669239092476
},
"delivery": {
"n_samples": 63,
"mse": 0.35801504662479033,
"rmse": 0.5983435857638906,
"mae": 0.4112293718545328,
"r2": 0.3391331751101827
},
"pdi": {
"n_samples": 86,
"accuracy": 0.7209302325581395,
"precision": 0.5444444444444444,
"recall": 0.746031746031746,
"f1": 0.5894308943089431
},
"ee": {
"n_samples": 86,
"accuracy": 0.5930232558139535,
"precision": 0.5120650953984287,
"recall": 0.5383076043453402,
"f1": 0.5184148203294006
},
"toxic": {
"n_samples": 64,
"accuracy": 0.9375,
"precision": 0.6666666666666666,
"recall": 0.967741935483871,
"f1": 0.7333333333333333
},
"biodist": {
"n_samples": 63,
"kl_divergence": 0.2828013605689281,
"js_divergence": 0.07105864428768947
}
}

View File

@ -0,0 +1,60 @@
{
"original_strata_counts": {
"T0|P0|E0": "5",
"T0|P0|E1": "63",
"T0|P0|E2": "175",
"T0|P1|E0": "1",
"T0|P1|E1": "19",
"T0|P1|E2": "33",
"T0|P2|E2": "3",
"T1|P0|E2": "9",
"T1|P1|E2": "5",
"TNA|P0|E0": "28",
"TNA|P0|E1": "21",
"TNA|P0|E2": "20",
"TNA|P1|E0": "29",
"TNA|P1|E1": "7",
"TNA|P1|E2": "14",
"TNA|P2|E2": "1",
"TNA|P3|E0": "1"
},
"rare_strata": [
"T0|P1|E0",
"T0|P2|E2",
"TNA|P2|E2",
"TNA|P3|E0"
],
"final_strata": [
"RARE",
"T0|P0|E0",
"T0|P0|E1",
"T0|P0|E2",
"T0|P1|E1",
"T0|P1|E2",
"T1|P0|E2",
"T1|P1|E2",
"TNA|P0|E0",
"TNA|P0|E1",
"TNA|P0|E2",
"TNA|P1|E0",
"TNA|P1|E1",
"TNA|P1|E2"
],
"final_strata_counts": {
"RARE": "6",
"T0|P0|E0": "5",
"T0|P0|E1": "63",
"T0|P0|E2": "175",
"T0|P1|E1": "19",
"T0|P1|E2": "33",
"T1|P0|E2": "9",
"T1|P1|E2": "5",
"TNA|P0|E0": "28",
"TNA|P0|E1": "21",
"TNA|P0|E2": "20",
"TNA|P1|E0": "29",
"TNA|P1|E1": "7",
"TNA|P1|E2": "14"
},
"n_rare_merged": "6"
}
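This file records the rare-stratum merge for the composite T|P|E stratification labels: every stratum with fewer than 5 samples is relabeled RARE (strata with exactly 5, such as T0|P0|E0 and T1|P1|E2, are kept), and the merged counts 1 + 3 + 1 + 1 = 6 match n_rare_merged. A minimal sketch of that relabeling; the function name is illustrative:

from collections import Counter

def merge_rare_strata(labels, min_count=5):
    # Relabel strata below the cutoff so stratified splitting
    # never sees a class with fewer members than folds.
    counts = Counter(labels)
    return ["RARE" if counts[lab] < min_count else lab for lab in labels]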

View File

@ -0,0 +1,342 @@
{
"fold_results": [
{
"fold": 0,
"best_params": {
"d_model": 512,
"num_heads": 8,
"n_attn_layers": 5,
"fusion_strategy": "attention",
"head_hidden_dim": 128,
"dropout": 0.14666736316838325,
"lr": 0.0001295888795454003,
"weight_decay": 7.732380983243132e-05
},
"epoch_mean": 13,
"test_metrics": {
"size": {
"n_samples": 87,
"mse": 0.26366087871128757,
"rmse": 0.5134791901443403,
"mae": 0.25157783223294666,
"r2": 0.21208517006410577
},
"delivery": {
"n_samples": 61,
"mse": 0.40443344562739025,
"rmse": 0.63595082013265,
"mae": 0.3928790920429298,
"r2": 0.2300258531372983
},
"pdi": {
"n_samples": 87,
"accuracy": 0.7241379310344828,
"precision": 0.35141509433962265,
"recall": 0.35351966873706003,
"f1": 0.348405985686402
},
"ee": {
"n_samples": 87,
"accuracy": 0.6666666666666666,
"precision": 0.6188811188811189,
"recall": 0.6375291375291375,
"f1": 0.6217948717948718
},
"toxic": {
"n_samples": 62,
"accuracy": 0.967741935483871,
"precision": 0.8,
"recall": 0.9830508474576272,
"f1": 0.8663793103448275
},
"biodist": {
"n_samples": 61,
"kl_divergence": 0.14776465556036145,
"js_divergence": 0.03926150329301917
}
}
},
{
"fold": 1,
"best_params": {
"d_model": 512,
"num_heads": 4,
"n_attn_layers": 2,
"fusion_strategy": "avg",
"head_hidden_dim": 64,
"dropout": 0.05188345993471756,
"lr": 4.21188892865021e-05,
"weight_decay": 4.086499445232577e-05
},
"epoch_mean": 23,
"test_metrics": {
"size": {
"n_samples": 87,
"mse": 0.20102046206409732,
"rmse": 0.44835305515196094,
"mae": 0.2436335881551107,
"r2": 0.1900957049146058
},
"delivery": {
"n_samples": 62,
"mse": 0.5899936276993041,
"rmse": 0.7681104267612203,
"mae": 0.46896366539576484,
"r2": 0.4017743543115936
},
"pdi": {
"n_samples": 87,
"accuracy": 0.6666666666666666,
"precision": 0.41556437389770723,
"recall": 0.4042119565217391,
"f1": 0.40777777777777774
},
"ee": {
"n_samples": 87,
"accuracy": 0.6666666666666666,
"precision": 0.6414141414141414,
"recall": 0.6782661782661782,
"f1": 0.6387485970819304
},
"toxic": {
"n_samples": 62,
"accuracy": 1.0,
"precision": 1.0,
"recall": 1.0,
"f1": 1.0
},
"biodist": {
"n_samples": 62,
"kl_divergence": 0.30758161166563297,
"js_divergence": 0.08759221465023677
}
}
},
{
"fold": 2,
"best_params": {
"d_model": 512,
"num_heads": 8,
"n_attn_layers": 4,
"fusion_strategy": "attention",
"head_hidden_dim": 64,
"dropout": 0.11433976976282646,
"lr": 5.3015812445144804e-05,
"weight_decay": 6.704431817743382e-06
},
"epoch_mean": 18,
"test_metrics": {
"size": {
"n_samples": 85,
"mse": 0.07855156229299931,
"rmse": 0.28027051627490057,
"mae": 0.2201253890991211,
"r2": 0.009407939412061861
},
"delivery": {
"n_samples": 62,
"mse": 0.4162270771403472,
"rmse": 0.645156629928227,
"mae": 0.41306305523856635,
"r2": 0.37384819758564136
},
"pdi": {
"n_samples": 87,
"accuracy": 0.7126436781609196,
"precision": 0.3963383838383838,
"recall": 0.5856060606060606,
"f1": 0.43013891331106036
},
"ee": {
"n_samples": 87,
"accuracy": 0.6781609195402298,
"precision": 0.6215366001209922,
"recall": 0.65781362712309,
"f1": 0.6235867752721687
},
"toxic": {
"n_samples": 62,
"accuracy": 0.967741935483871,
"precision": 0.8,
"recall": 0.9830508474576272,
"f1": 0.8663793103448275
},
"biodist": {
"n_samples": 62,
"kl_divergence": 0.3189032824921648,
"js_divergence": 0.07944133611635379
}
}
},
{
"fold": 3,
"best_params": {
"d_model": 512,
"num_heads": 4,
"n_attn_layers": 5,
"fusion_strategy": "attention",
"head_hidden_dim": 64,
"dropout": 0.11746271741188277,
"lr": 0.0001939220403760229,
"weight_decay": 2.4722550292920085e-06
},
"epoch_mean": 10,
"test_metrics": {
"size": {
"n_samples": 87,
"mse": 0.08391054707837464,
"rmse": 0.28967317286620564,
"mae": 0.22691357272794876,
"r2": 0.26719931627457305
},
"delivery": {
"n_samples": 62,
"mse": 2.0453160934060053,
"rmse": 1.4301454798047664,
"mae": 0.5443450972558029,
"r2": 0.0777919381248442
},
"pdi": {
"n_samples": 87,
"accuracy": 0.6896551724137931,
"precision": 0.3994245524296675,
"recall": 0.5978021978021978,
"f1": 0.417895167895168
},
"ee": {
"n_samples": 87,
"accuracy": 0.5862068965517241,
"precision": 0.5469462969462969,
"recall": 0.5874125874125874,
"f1": 0.5375569894616209
},
"toxic": {
"n_samples": 63,
"accuracy": 0.9682539682539683,
"precision": 0.8,
"recall": 0.9833333333333334,
"f1": 0.8665254237288135
},
"biodist": {
"n_samples": 63,
"kl_divergence": 0.28367776789683485,
"js_divergence": 0.07318286384043993
}
}
},
{
"fold": 4,
"best_params": {
"d_model": 128,
"num_heads": 4,
"n_attn_layers": 6,
"fusion_strategy": "max",
"head_hidden_dim": 128,
"dropout": 0.15658744776638445,
"lr": 0.00031005155898680676,
"weight_decay": 5.422040924441196e-05
},
"epoch_mean": 14,
"test_metrics": {
"size": {
"n_samples": 86,
"mse": 0.17175675509369362,
"rmse": 0.41443546553558086,
"mae": 0.2695355332174966,
"r2": -0.28030669239092476
},
"delivery": {
"n_samples": 63,
"mse": 0.35801504662479033,
"rmse": 0.5983435857638906,
"mae": 0.4112293718545328,
"r2": 0.3391331751101827
},
"pdi": {
"n_samples": 86,
"accuracy": 0.7209302325581395,
"precision": 0.5444444444444444,
"recall": 0.746031746031746,
"f1": 0.5894308943089431
},
"ee": {
"n_samples": 86,
"accuracy": 0.5930232558139535,
"precision": 0.5120650953984287,
"recall": 0.5383076043453402,
"f1": 0.5184148203294006
},
"toxic": {
"n_samples": 64,
"accuracy": 0.9375,
"precision": 0.6666666666666666,
"recall": 0.967741935483871,
"f1": 0.7333333333333333
},
"biodist": {
"n_samples": 63,
"kl_divergence": 0.2828013605689281,
"js_divergence": 0.07105864428768947
}
}
}
],
"summary_stats": {
"size": {
"mse_mean": 0.1597800410480905,
"mse_std": 0.07069609368925868,
"rmse_mean": 0.3892422799945977,
"rmse_std": 0.09094222623565874,
"mae_mean": 0.24235718308652476,
"mae_std": 0.017652592224532318,
"r2_mean": 0.07969628765488435,
"r2_std": 0.19970720107145462
},
"delivery": {
"mse_mean": 0.7627970580995674,
"mse_std": 0.6460804607386303,
"rmse_mean": 0.8155413884781508,
"rmse_std": 0.31255287837212004,
"mae_mean": 0.44609605635751937,
"mae_std": 0.055343855526290356,
"r2_mean": 0.28451470365391207,
"r2_std": 0.11867334397685028
},
"pdi": {
"accuracy_mean": 0.7028067361668003,
"accuracy_std": 0.021722406295571848,
"precision_mean": 0.4214373697899651,
"precision_std": 0.06508897722192637,
"recall_mean": 0.5374343259397607,
"recall_std": 0.1421622175476529,
"f1_mean": 0.4387297477958702,
"f1_std": 0.0804178141542625
},
"ee": {
"accuracy_mean": 0.6381448810478482,
"accuracy_std": 0.039904343482039924,
"precision_mean": 0.5881686505521957,
"precision_std": 0.049765031116764336,
"recall_mean": 0.6198658269352666,
"recall_std": 0.05072984435881663,
"f1_mean": 0.5880204107879985,
"f1_std": 0.049740374945613494
},
"toxic": {
"accuracy_mean": 0.9682475678443421,
"accuracy_std": 0.01976937654694873,
"precision_mean": 0.8133333333333335,
"precision_std": 0.10666666666666666,
"recall_mean": 0.9834353927464917,
"recall_std": 0.010207614614579376,
"f1_mean": 0.8665234755503602,
"f1_std": 0.08432750219703594
},
"biodist": {
"kl_divergence_mean": 0.26814573563678445,
"kl_divergence_std": 0.06177240919631341,
"js_divergence_mean": 0.07010731243754784,
"js_divergence_std": 0.016460095953094674
}
}
}
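The summary_stats block is reproducible from fold_results as the per-metric mean and population standard deviation over the five outer folds; e.g. size mse gives mean 0.15978 and std 0.07070, matching numpy's default ddof=0. A minimal sketch, with illustrative names:

import numpy as np

def summarize(fold_results, task, metric):
    vals = [fr["test_metrics"][task][metric] for fr in fold_results]
    return {
        f"{metric}_mean": float(np.mean(vals)),
        f"{metric}_std": float(np.std(vals)),  # population std (ddof=0)
    }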

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

211
pixi.lock
View File

@ -53,6 +53,7 @@ environments:
- conda: https://conda.anaconda.org/conda-forge/linux-64/xz-tools-5.8.1-hb9d3cd8_2.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda
- pypi: https://files.pythonhosted.org/packages/64/88/c7083fc61120ab661c5d0b82cb77079fc1429d3f913a456c1c82cf4658f7/alabaster-0.7.13-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/54/7e/ac0991d1745f7d755fc1cd381b3990a45b404b4d008fc75e2a983516fbfe/alembic-1.14.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/9b/52/4a86a4fa1cc2aae79137cc9510b7080c3e5aede2310d14fae5486feec7f7/altair-5.4.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl
@ -65,6 +66,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/bf/fa/cf5bb2409a385f78750e78c8d2e24780964976acdaaed65dbd6083ae5b40/charset_normalizer-3.4.4-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/a0/ad/72361203906e2dbe9baa776c64e9246d555b516808dd0cce385f07f4cf71/chemprop-1.7.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/6d/c1/e419ef3723a074172b68aaa89c9f3de486ed4c2399e2dbd8113a4fdcaf9e/colorlog-6.10.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/8e/71/7f20855592cc929bc206810432b991ec4c702dc26b0567b132e52c85536f/contourpy-1.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/15/1a/c6eae628480aa1fc5f6f85437c7d8ec0d1172597acd1c61182202a902c0f/cramjam-2.11.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl
@ -82,6 +84,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/da/71/ae30dadffc90b9006d77af76b393cb9dfbfc9629f339fc1574a1c52e6806/future-1.0.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/6a/09/e21df6aef1e1ffc0c816f0522ddc3f6dcded766c3261813131c78a704470/gitpython-3.1.46-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/d8/88/0ce16c0afb2d71d85562a7bcd9b092fec80a7767ab5b5f7e1bbbca8200f8/greenlet-3.1.1-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl
@ -96,6 +99,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/69/4a/4f9dbeb84e8850557c02365a0eee0649abe5eb1d84af92a25731c6c0f922/jsonschema-4.23.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/ee/07/44bd408781594c4d0a027666ef27fab1e441b109dc3b76b4f836f8fd04fe/jsonschema_specifications-2023.12.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/76/36/ae40d7a3171e06f55ac77fe5536079e7be1d8be2a8210e08975c7f9b4d54/kiwisolver-1.4.7-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c7/bd/50319665ce81bb10e90d1cf76f9e1aa269ea6f7fa30ab4521f14d122a3df/MarkupSafe-2.1.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/30/33/cc27211d2ffeee4fd7402dca137b6e8a83f6dcae3d4be8d0ad5068555561/matplotlib-3.7.5-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl
@ -116,6 +120,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/46/0c/c75bbfb967457a0b7670b8ad267bfc4fffdf341c074e0a80db06c24ccfd4/nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/da/d3/8057f0587683ed2fcd4dbfbdfdfa807b9160b809976099d36b8f60d08f03/nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/c0/da/977ded879c29cbd04de313843e76868e6e13408a94ed6b987245dc7c8506/openpyxl-3.1.5-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/7f/12/cba81286cbaf0f0c3f0473846cfd992cb240bdcea816bf2ef7de8ed0f744/optuna-4.5.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/f8/7f/5b047effafbdd34e52c9e2d7e44f729a0655efafb22198c45cf692cdc157/pandas-2.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/5d/e6/71ed4d95676098159b533c4a4c424cf453fec9614edaff1a0633fe228eef/pandas_flavor-0.7.0-py3-none-any.whl
@ -131,6 +136,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/25/a2/b725b61ac76a75583ae7104b3209f75ea44b13cfd026aa535ece22b7f22e/PyYAML-6.0.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/3d/84/63b2e66f5c7cb97ce994769afbbef85a1ac364fedbcb7d4a3c0f15d318a5/rdkit-2024.3.5-cp38-cp38-manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/b7/59/2056f61236782a2c86b33906c025d4f4a0b17be0161b63b70fd9e8775d36/referencing-0.35.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl
@ -150,6 +156,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/c2/42/4c8646762ee83602e3fb3fbe774c2fac12f317deb0b5dbeeedd2d3ba4b77/sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/2b/14/05f9206cf4e9cfca1afb5fd224c7cd434dcc3a433d6d9e4e0264d29c6cdb/sphinxcontrib_qthelp-1.0.3-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c6/77/5464ec50dd0f1c1037e3c93249b040c8fc8078fdda97530eeb02424b6eea/sphinxcontrib_serializinghtml-1.1.5-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/8a/7b/b9e0eb7f9a15f2e82856603c728edf14c54a07c6738ab228e4f2de049338/sqlalchemy-2.0.46-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/b6/c5/7ae467eeddb57260c8ce17a3a09f9f5edba35820fc022d7c55b7decd5d3a/starlette-0.44.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/9a/14/857d0734989f3d26f2f965b2e3f67568ea7a6e8a60cb9c1ed7f774b6d606/streamlit-1.40.1-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/99/ff/c87e0622b1dadea79d2fb0b25ade9ed98954c9033722eb707053d310d4f3/sympy-1.13.3-py3-none-any.whl
@ -206,6 +213,7 @@ environments:
- conda: https://conda.anaconda.org/conda-forge/osx-arm64/xz-gpl-tools-5.8.1-h9a6d368_2.conda
- conda: https://conda.anaconda.org/conda-forge/osx-arm64/xz-tools-5.8.1-h39f12f2_2.conda
- pypi: https://files.pythonhosted.org/packages/64/88/c7083fc61120ab661c5d0b82cb77079fc1429d3f913a456c1c82cf4658f7/alabaster-0.7.13-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/54/7e/ac0991d1745f7d755fc1cd381b3990a45b404b4d008fc75e2a983516fbfe/alembic-1.14.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/9b/52/4a86a4fa1cc2aae79137cc9510b7080c3e5aede2310d14fae5486feec7f7/altair-5.4.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl
@ -218,6 +226,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/0a/4e/3926a1c11f0433791985727965263f788af00db3482d89a7545ca5ecc921/charset_normalizer-3.4.4-cp38-cp38-macosx_10_9_universal2.whl
- pypi: https://files.pythonhosted.org/packages/a0/ad/72361203906e2dbe9baa776c64e9246d555b516808dd0cce385f07f4cf71/chemprop-1.7.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/6d/c1/e419ef3723a074172b68aaa89c9f3de486ed4c2399e2dbd8113a4fdcaf9e/colorlog-6.10.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/a6/82/29f5ff4ae074c3230e266bc9efef449ebde43721a727b989dd8ef8f97d73/contourpy-1.1.1-cp38-cp38-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/65/64/e34ee535519fd14cde3a7f3f8cd3b4ef54483b9df655e4180437eb884aab/cramjam-2.11.0-cp38-cp38-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl
- pypi: https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl
@ -249,6 +258,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/69/4a/4f9dbeb84e8850557c02365a0eee0649abe5eb1d84af92a25731c6c0f922/jsonschema-4.23.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/ee/07/44bd408781594c4d0a027666ef27fab1e441b109dc3b76b4f836f8fd04fe/jsonschema_specifications-2023.12.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/14/a7/bb8ab10e12cc8764f4da0245d72dee4731cc720bdec0f085d5e9c6005b98/kiwisolver-1.4.7-cp38-cp38-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/f8/ff/2c942a82c35a49df5de3a630ce0a8456ac2969691b230e530ac12314364c/MarkupSafe-2.1.5-cp38-cp38-macosx_10_9_universal2.whl
- pypi: https://files.pythonhosted.org/packages/aa/59/4d13e5b6298b1ca5525eea8c68d3806ae93ab6d0bb17ca9846aa3156b92b/matplotlib-3.7.5-cp38-cp38-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl
@ -257,6 +267,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/a8/05/9d4f9b78ead6b2661d6e8ea772e111fc4a9fbd866ad0c81906c11206b55e/networkx-3.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/a7/ae/f53b7b265fdc701e663fbb322a8e9d4b14d9cb7b2385f45ddfabfc4327e4/numpy-1.24.4-cp38-cp38-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/c0/da/977ded879c29cbd04de313843e76868e6e13408a94ed6b987245dc7c8506/openpyxl-3.1.5-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/7f/12/cba81286cbaf0f0c3f0473846cfd992cb240bdcea816bf2ef7de8ed0f744/optuna-4.5.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/53/c3/f8e87361f7fdf42012def602bfa2a593423c729f5cb7c97aed7f51be66ac/pandas-2.0.3-cp38-cp38-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/5d/e6/71ed4d95676098159b533c4a4c424cf453fec9614edaff1a0633fe228eef/pandas_flavor-0.7.0-py3-none-any.whl
@ -272,6 +283,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/05/8e/961c0007c59b8dd7729d542c61a4d537767a59645b82a0b521206e1e25c2/pyyaml-6.0.3.tar.gz
- pypi: https://files.pythonhosted.org/packages/bf/cb/c709b60f4815e18c00e1e8639204bdba04cb158e6278791d82f94f51a988/rdkit-2024.3.5-cp38-cp38-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/b7/59/2056f61236782a2c86b33906c025d4f4a0b17be0161b63b70fd9e8775d36/referencing-0.35.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl
@ -291,6 +303,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/c2/42/4c8646762ee83602e3fb3fbe774c2fac12f317deb0b5dbeeedd2d3ba4b77/sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/2b/14/05f9206cf4e9cfca1afb5fd224c7cd434dcc3a433d6d9e4e0264d29c6cdb/sphinxcontrib_qthelp-1.0.3-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/c6/77/5464ec50dd0f1c1037e3c93249b040c8fc8078fdda97530eeb02424b6eea/sphinxcontrib_serializinghtml-1.1.5-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/07/d4/76b9618d7eb1e6a3c26734e1186f8ad7869e4426b1ea7dc425bc4c832e67/sqlalchemy-2.0.46-cp38-cp38-macosx_11_0_arm64.whl
- pypi: https://files.pythonhosted.org/packages/b6/c5/7ae467eeddb57260c8ce17a3a09f9f5edba35820fc022d7c55b7decd5d3a/starlette-0.44.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/9a/14/857d0734989f3d26f2f965b2e3f67568ea7a6e8a60cb9c1ed7f774b6d606/streamlit-1.40.1-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/99/ff/c87e0622b1dadea79d2fb0b25ade9ed98954c9033722eb707053d310d4f3/sympy-1.13.3-py3-none-any.whl
@ -336,6 +349,19 @@ packages:
version: 0.7.13
sha256: 1ee19aca801bbabb5ba3f5f258e4422dfa86f82f3e9cefb0859b283cdd7f62a3
requires_python: '>=3.6'
- pypi: https://files.pythonhosted.org/packages/54/7e/ac0991d1745f7d755fc1cd381b3990a45b404b4d008fc75e2a983516fbfe/alembic-1.14.1-py3-none-any.whl
name: alembic
version: 1.14.1
sha256: 1acdd7a3a478e208b0503cd73614d5e4c6efafa4e73518bb60e4f2846a37b1c5
requires_dist:
- sqlalchemy>=1.3.0
- mako
- importlib-metadata ; python_full_version < '3.9'
- importlib-resources ; python_full_version < '3.9'
- typing-extensions>=4
- backports-zoneinfo ; python_full_version < '3.9' and extra == 'tz'
- tzdata ; extra == 'tz'
requires_python: '>=3.8'
- pypi: https://files.pythonhosted.org/packages/9b/52/4a86a4fa1cc2aae79137cc9510b7080c3e5aede2310d14fae5486feec7f7/altair-5.4.1-py3-none-any.whl
name: altair
version: 5.4.1
@ -589,6 +615,18 @@ packages:
- pkg:pypi/colorama?source=hash-mapping
size: 25170
timestamp: 1666700778190
- pypi: https://files.pythonhosted.org/packages/6d/c1/e419ef3723a074172b68aaa89c9f3de486ed4c2399e2dbd8113a4fdcaf9e/colorlog-6.10.1-py3-none-any.whl
name: colorlog
version: 6.10.1
sha256: 2d7e8348291948af66122cff006c9f8da6255d224e7cf8e37d8de2df3bad8c9c
requires_dist:
- colorama ; sys_platform == 'win32'
- black ; extra == 'development'
- flake8 ; extra == 'development'
- mypy ; extra == 'development'
- pytest ; extra == 'development'
- types-colorama ; extra == 'development'
requires_python: '>=3.6'
- pypi: https://files.pythonhosted.org/packages/8e/71/7f20855592cc929bc206810432b991ec4c702dc26b0567b132e52c85536f/contourpy-1.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
name: contourpy
version: 1.1.1
@ -1013,6 +1051,16 @@ packages:
- sphinx-rtd-theme ; extra == 'doc'
- sphinx-autodoc-typehints ; extra == 'doc'
requires_python: '>=3.7'
- pypi: https://files.pythonhosted.org/packages/d8/88/0ce16c0afb2d71d85562a7bcd9b092fec80a7767ab5b5f7e1bbbca8200f8/greenlet-3.1.1-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
name: greenlet
version: 3.1.1
sha256: 85f3ff71e2e60bd4b4932a043fbbe0f499e263c628390b285cb599154a3b03b1
requires_dist:
- sphinx ; extra == 'docs'
- furo ; extra == 'docs'
- objgraph ; extra == 'test'
- psutil ; extra == 'test'
requires_python: '>=3.7'
- pypi: https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl
name: h11
version: 0.16.0
@ -1453,6 +1501,16 @@ packages:
- pkg:pypi/loguru?source=hash-mapping
size: 97617
timestamp: 1695547715271
- pypi: https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl
name: mako
version: 1.3.10
sha256: baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59
requires_dist:
- markupsafe>=0.9.2
- pytest ; extra == 'testing'
- babel ; extra == 'babel'
- lingua ; extra == 'lingua'
requires_python: '>=3.8'
- conda: https://conda.anaconda.org/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_0.conda
sha256: c041b0eaf7a6af3344d5dd452815cdc148d6284fec25a4fa3f4263b3a021e962
md5: 93a8e71256479c62074356ef6ebf501b
@ -1709,6 +1767,73 @@ packages:
purls: []
size: 3108371
timestamp: 1762839712322
- pypi: https://files.pythonhosted.org/packages/7f/12/cba81286cbaf0f0c3f0473846cfd992cb240bdcea816bf2ef7de8ed0f744/optuna-4.5.0-py3-none-any.whl
name: optuna
version: 4.5.0
sha256: 5b8a783e84e448b0742501bc27195344a28d2c77bd2feef5b558544d954851b0
requires_dist:
- alembic>=1.5.0
- colorlog
- numpy
- packaging>=20.0
- sqlalchemy>=1.4.2
- tqdm
- pyyaml
- asv>=0.5.0 ; extra == 'benchmark'
- cma ; extra == 'benchmark'
- virtualenv ; extra == 'benchmark'
- black ; extra == 'checking'
- blackdoc ; extra == 'checking'
- flake8 ; extra == 'checking'
- isort ; extra == 'checking'
- mypy ; extra == 'checking'
- mypy-boto3-s3 ; extra == 'checking'
- scipy-stubs ; python_full_version >= '3.10' and extra == 'checking'
- types-pyyaml ; extra == 'checking'
- types-redis ; extra == 'checking'
- types-setuptools ; extra == 'checking'
- types-tqdm ; extra == 'checking'
- typing-extensions>=3.10.0.0 ; extra == 'checking'
- ase ; extra == 'document'
- cmaes>=0.12.0 ; extra == 'document'
- fvcore ; extra == 'document'
- kaleido<0.4 ; extra == 'document'
- lightgbm ; extra == 'document'
- matplotlib!=3.6.0 ; extra == 'document'
- pandas ; extra == 'document'
- pillow ; extra == 'document'
- plotly>=4.9.0 ; extra == 'document'
- scikit-learn ; extra == 'document'
- sphinx ; extra == 'document'
- sphinx-copybutton ; extra == 'document'
- sphinx-gallery ; extra == 'document'
- sphinx-notfound-page ; extra == 'document'
- sphinx-rtd-theme>=1.2.0 ; extra == 'document'
- torch ; extra == 'document'
- torchvision ; extra == 'document'
- boto3 ; extra == 'optional'
- cmaes>=0.12.0 ; extra == 'optional'
- google-cloud-storage ; extra == 'optional'
- matplotlib!=3.6.0 ; extra == 'optional'
- pandas ; extra == 'optional'
- plotly>=4.9.0 ; extra == 'optional'
- redis ; extra == 'optional'
- scikit-learn>=0.24.2 ; extra == 'optional'
- scipy ; extra == 'optional'
- torch ; extra == 'optional'
- grpcio ; extra == 'optional'
- protobuf>=5.28.1 ; extra == 'optional'
- coverage ; extra == 'test'
- fakeredis[lua] ; extra == 'test'
- kaleido<0.4 ; extra == 'test'
- moto ; extra == 'test'
- pytest ; extra == 'test'
- pytest-xdist ; extra == 'test'
- scipy>=1.9.2 ; extra == 'test'
- torch ; extra == 'test'
- grpcio ; extra == 'test'
- protobuf>=5.28.1 ; extra == 'test'
requires_python: '>=3.8'
- pypi: https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl
name: packaging
version: '24.2'
@ -2136,6 +2261,16 @@ packages:
name: pytz
version: '2025.2'
sha256: 5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00
- pypi: https://files.pythonhosted.org/packages/05/8e/961c0007c59b8dd7729d542c61a4d537767a59645b82a0b521206e1e25c2/pyyaml-6.0.3.tar.gz
name: pyyaml
version: 6.0.3
sha256: d76623373421df22fb4cf8817020cbb7ef15c725b9d5e45f17e189bfc384190f
requires_python: '>=3.8'
- pypi: https://files.pythonhosted.org/packages/25/a2/b725b61ac76a75583ae7104b3209f75ea44b13cfd026aa535ece22b7f22e/PyYAML-6.0.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
name: pyyaml
version: 6.0.3
sha256: 22ba7cfcad58ef3ecddc7ed1db3409af68d023b7f940da23c6c2a1890976eda6
requires_python: '>=3.8'
- pypi: https://files.pythonhosted.org/packages/3d/84/63b2e66f5c7cb97ce994769afbbef85a1ac364fedbcb7d4a3c0f15d318a5/rdkit-2024.3.5-cp38-cp38-manylinux_2_28_x86_64.whl
name: rdkit
version: 2024.3.5
@ -2554,6 +2689,82 @@ packages:
- docutils-stubs ; extra == 'lint'
- pytest ; extra == 'test'
requires_python: '>=3.5'
- pypi: https://files.pythonhosted.org/packages/07/d4/76b9618d7eb1e6a3c26734e1186f8ad7869e4426b1ea7dc425bc4c832e67/sqlalchemy-2.0.46-cp38-cp38-macosx_11_0_arm64.whl
name: sqlalchemy
version: 2.0.46
sha256: 6ac245604295b521de49b465bab845e3afe6916bcb2147e5929c8041b4ec0545
requires_dist:
- typing-extensions>=4.6.0
- greenlet>=1 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'
- importlib-metadata ; python_full_version < '3.8'
- greenlet>=1 ; extra == 'aiomysql'
- aiomysql>=0.2.0 ; extra == 'aiomysql'
- greenlet>=1 ; extra == 'aioodbc'
- aioodbc ; extra == 'aioodbc'
- greenlet>=1 ; extra == 'aiosqlite'
- aiosqlite ; extra == 'aiosqlite'
- typing-extensions!=3.10.0.1 ; extra == 'aiosqlite'
- greenlet>=1 ; extra == 'asyncio'
- greenlet>=1 ; extra == 'asyncmy'
- asyncmy>=0.2.3,!=0.2.4,!=0.2.6 ; extra == 'asyncmy'
- mariadb>=1.0.1,!=1.1.2,!=1.1.5,!=1.1.10 ; extra == 'mariadb-connector'
- pyodbc ; extra == 'mssql'
- pymssql ; extra == 'mssql-pymssql'
- pyodbc ; extra == 'mssql-pyodbc'
- mypy>=0.910 ; extra == 'mypy'
- mysqlclient>=1.4.0 ; extra == 'mysql'
- mysql-connector-python ; extra == 'mysql-connector'
- cx-oracle>=8 ; extra == 'oracle'
- oracledb>=1.0.1 ; extra == 'oracle-oracledb'
- psycopg2>=2.7 ; extra == 'postgresql'
- greenlet>=1 ; extra == 'postgresql-asyncpg'
- asyncpg ; extra == 'postgresql-asyncpg'
- pg8000>=1.29.1 ; extra == 'postgresql-pg8000'
- psycopg>=3.0.7 ; extra == 'postgresql-psycopg'
- psycopg2-binary ; extra == 'postgresql-psycopg2binary'
- psycopg2cffi ; extra == 'postgresql-psycopg2cffi'
- psycopg[binary]>=3.0.7 ; extra == 'postgresql-psycopgbinary'
- pymysql ; extra == 'pymysql'
- sqlcipher3-binary ; extra == 'sqlcipher'
requires_python: '>=3.7'
- pypi: https://files.pythonhosted.org/packages/8a/7b/b9e0eb7f9a15f2e82856603c728edf14c54a07c6738ab228e4f2de049338/sqlalchemy-2.0.46-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
name: sqlalchemy
version: 2.0.46
sha256: 716be5bcabf327b6d5d265dbdc6213a01199be587224eb991ad0d37e83d728fd
requires_dist:
- typing-extensions>=4.6.0
- greenlet>=1 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'
- importlib-metadata ; python_full_version < '3.8'
- greenlet>=1 ; extra == 'aiomysql'
- aiomysql>=0.2.0 ; extra == 'aiomysql'
- greenlet>=1 ; extra == 'aioodbc'
- aioodbc ; extra == 'aioodbc'
- greenlet>=1 ; extra == 'aiosqlite'
- aiosqlite ; extra == 'aiosqlite'
- typing-extensions!=3.10.0.1 ; extra == 'aiosqlite'
- greenlet>=1 ; extra == 'asyncio'
- greenlet>=1 ; extra == 'asyncmy'
- asyncmy>=0.2.3,!=0.2.4,!=0.2.6 ; extra == 'asyncmy'
- mariadb>=1.0.1,!=1.1.2,!=1.1.5,!=1.1.10 ; extra == 'mariadb-connector'
- pyodbc ; extra == 'mssql'
- pymssql ; extra == 'mssql-pymssql'
- pyodbc ; extra == 'mssql-pyodbc'
- mypy>=0.910 ; extra == 'mypy'
- mysqlclient>=1.4.0 ; extra == 'mysql'
- mysql-connector-python ; extra == 'mysql-connector'
- cx-oracle>=8 ; extra == 'oracle'
- oracledb>=1.0.1 ; extra == 'oracle-oracledb'
- psycopg2>=2.7 ; extra == 'postgresql'
- greenlet>=1 ; extra == 'postgresql-asyncpg'
- asyncpg ; extra == 'postgresql-asyncpg'
- pg8000>=1.29.1 ; extra == 'postgresql-pg8000'
- psycopg>=3.0.7 ; extra == 'postgresql-psycopg'
- psycopg2-binary ; extra == 'postgresql-psycopg2binary'
- psycopg2cffi ; extra == 'postgresql-psycopg2cffi'
- psycopg[binary]>=3.0.7 ; extra == 'postgresql-psycopgbinary'
- pymysql ; extra == 'pymysql'
- sqlcipher3-binary ; extra == 'sqlcipher'
requires_python: '>=3.7'
- pypi: https://files.pythonhosted.org/packages/b6/c5/7ae467eeddb57260c8ce17a3a09f9f5edba35820fc022d7c55b7decd5d3a/starlette-0.44.0-py3-none-any.whl
name: starlette
version: 0.44.0

View File

@ -28,3 +28,4 @@ fastapi = ">=0.124.4, <0.125"
streamlit = ">=1.40.1, <2"
httpx = ">=0.28.1, <0.29"
uvicorn = ">=0.33.0, <0.34"
optuna = ">=4.5.0, <5"

21
requirements.txt Normal file
View File

@ -0,0 +1,21 @@
# Strictly follows the dependency versions pinned in pixi.toml
# Note: the local lnp_ml package is installed separately in the Dockerfile
# conda dependencies (in pixi.toml [dependencies])
loguru
tqdm
typer
# pypi dependencies (in pixi.toml [pypi-dependencies])
chemprop==1.7.0
setuptools
pandas>=2.0.3,<3
openpyxl>=3.1.5,<4
python-dotenv>=1.0.1,<2
pyarrow>=17.0.0,<18
fastparquet>=2024.2.0,<2025
fastapi>=0.124.4,<0.125
streamlit>=1.40.1,<2
httpx>=0.28.1,<0.29
uvicorn>=0.33.0,<0.34
optuna>=4.5.0,<5
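
Since requirements.txt is declared to mirror the versions pinned in pixi.toml, drift between the two files is easy to introduce and worth checking mechanically. The sketch below is a minimal example, not part of this commit; it assumes both files sit in the repo root, that pixi.toml keeps its [pypi-dependencies] entries in the simple `name = ">=x, <y"` string form shown above, and the script/function names (check_sync.py, pixi_pins, req_pins) are hypothetical.

# check_sync.py -- minimal sketch: report version-spec drift between
# pixi.toml [pypi-dependencies] and requirements.txt (assumptions above).
import re
from pathlib import Path

# Matches simple string pins like: optuna = ">=4.5.0, <5"
PIN_RE = re.compile(r'^(?P<name>[\w.-]+)\s*=\s*"(?P<spec>[^"]+)"')

def pixi_pins(path="pixi.toml"):
    """Collect name -> spec from the [pypi-dependencies] table (string pins only)."""
    pins, in_section = {}, False
    for line in Path(path).read_text().splitlines():
        stripped = line.strip()
        if stripped.startswith("["):           # entering a new TOML table
            in_section = stripped == "[pypi-dependencies]"
            continue
        m = PIN_RE.match(stripped) if in_section else None
        if m:
            pins[m["name"].lower()] = m["spec"].replace(" ", "")
    return pins

def req_pins(path="requirements.txt"):
    """Collect name -> spec from requirements.txt, skipping blanks and comments."""
    pins = {}
    for line in Path(path).read_text().splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        m = re.match(r"^([A-Za-z0-9_.-]+)(.*)$", line)
        pins[m.group(1).lower()] = m.group(2).replace(" ", "")
    return pins

if __name__ == "__main__":
    pixi, reqs = pixi_pins(), req_pins()
    for name, spec in sorted(pixi.items()):
        if name in reqs and reqs[name] != spec:
            print(f"drift: {name}: pixi={spec!r} requirements={reqs[name]!r}")

Run from the repo root, it prints one line per package whose spec differs (e.g. if pixi.toml moved to optuna ">=4.6.0, <5" while requirements.txt still said optuna>=4.5.0,<5); conda-only dependencies such as loguru, tqdm, and typer are intentionally ignored since they carry no version pin in requirements.txt.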