This commit is contained in:
RYDE-WORK 2026-01-22 01:01:29 +08:00
parent e6a5e5495a
commit 039be54c5a
33 changed files with 3654 additions and 6 deletions

View File

@ -78,6 +78,11 @@ data_pretrain: requirements
data_pretrain_cv: requirements
$(PYTHON_INTERPRETER) scripts/process_external_cv.py
## Process internal data with amine-based CV splitting (interim -> processed/cv)
.PHONY: data_cv
data_cv: requirements
$(PYTHON_INTERPRETER) scripts/process_data_cv.py
# MPNN 支持:使用 USE_MPNN=1 启用 MPNN encoder
# 例如make pretrain USE_MPNN=1
MPNN_FLAG = $(if $(USE_MPNN),--use-mpnn,)
@ -120,6 +125,16 @@ train: requirements
finetune: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --init-from-pretrain models/pretrain_delivery.pt $(FREEZE_FLAG) $(MPNN_FLAG) $(DEVICE_FLAG)
## Finetune with cross-validation on internal data (5-fold, amine-based split) with pretrained weights
.PHONY: finetune_cv
finetune_cv: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train_cv main --init-from-pretrain models/pretrain_delivery.pt $(FREEZE_FLAG) $(MPNN_FLAG) $(DEVICE_FLAG)
## Evaluate CV finetuned models on test sets (auto-detects MPNN from checkpoint)
.PHONY: test_cv
test_cv: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train_cv test $(DEVICE_FLAG)
## Train with hyperparameter tuning
.PHONY: tune
tune: requirements

View File

@ -0,0 +1,89 @@
# Feature columns configuration
# SMILES
smiles
# comp token [5]
Cationic_Lipid_to_mRNA_weight_ratio
Cationic_Lipid_Mol_Ratio
Phospholipid_Mol_Ratio
Cholesterol_Mol_Ratio
PEG_Lipid_Mol_Ratio
# phys token [12]
Purity_Pure
Purity_Crude
Mix_type_Microfluidic
Mix_type_Pipetting
Cargo_type_mRNA
Cargo_type_pDNA
Cargo_type_siRNA
Target_or_delivered_gene_FFL
Target_or_delivered_gene_Peptide_barcode
Target_or_delivered_gene_hEPO
Target_or_delivered_gene_FVII
Target_or_delivered_gene_GFP
# help token [4]
Helper_lipid_ID_DOPE
Helper_lipid_ID_DOTAP
Helper_lipid_ID_DSPC
Helper_lipid_ID_MDOA
# exp token [32]
Model_type_A549
Model_type_BDMC
Model_type_BMDM
Model_type_HBEC_ALI
Model_type_HEK293T
Model_type_HeLa
Model_type_IGROV1
Model_type_Mouse
Model_type_RAW264p7
Delivery_target_body
Delivery_target_dendritic_cell
Delivery_target_generic_cell
Delivery_target_liver
Delivery_target_lung
Delivery_target_lung_epithelium
Delivery_target_macrophage
Delivery_target_muscle
Delivery_target_spleen
Route_of_administration_in_vitro
Route_of_administration_intramuscular
Route_of_administration_intratracheal
Route_of_administration_intravenous
Batch_or_individual_or_barcoded_Barcoded
Batch_or_individual_or_barcoded_Individual
Value_name_log_luminescence
Value_name_luminescence
Value_name_FFL_silencing
Value_name_Peptide_abundance
Value_name_hEPO
Value_name_FVII_silencing
Value_name_GFP_delivery
Value_name_Discretized_luminescence
# Targets
## Regression
size
quantified_delivery
## PDI classification
PDI_0_0to0_2
PDI_0_2to0_3
PDI_0_3to0_4
PDI_0_4to0_5
## EE classification
Encapsulation_Efficiency_EE<50
Encapsulation_Efficiency_50<=EE<80
Encapsulation_Efficiency_80<EE<=100
## Toxic
toxic
## Biodistribution
Biodistribution_lymph_nodes
Biodistribution_heart
Biodistribution_liver
Biodistribution_spleen
Biodistribution_lung
Biodistribution_kidney
Biodistribution_muscle

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

647
lnp_ml/modeling/train_cv.py Normal file
View File

@ -0,0 +1,647 @@
"""Cross-Validation 训练脚本:在 5-fold 内部数据上进行多任务训练"""
import json
from pathlib import Path
from typing import Dict, List, Optional, Union
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from loguru import logger
from tqdm import tqdm
import typer
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import LNPDataset, collate_fn
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
from lnp_ml.modeling.trainer import (
train_epoch,
validate,
EarlyStopping,
LossWeights,
)
# MPNN ensemble 默认路径
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
"""自动查找 MPNN ensemble 的 model.pt 文件。"""
model_paths = sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt"))
if not model_paths:
raise FileNotFoundError(f"No model.pt files found in {base_dir}")
return [str(p) for p in model_paths]
app = typer.Typer()
def create_model(
d_model: int = 256,
num_heads: int = 8,
n_attn_layers: int = 4,
fusion_strategy: str = "attention",
head_hidden_dim: int = 128,
dropout: float = 0.1,
mpnn_checkpoint: Optional[str] = None,
mpnn_ensemble_paths: Optional[List[str]] = None,
mpnn_device: str = "cpu",
) -> Union[LNPModel, LNPModelWithoutMPNN]:
"""创建模型(支持可选的 MPNN encoder"""
use_mpnn = mpnn_checkpoint is not None or mpnn_ensemble_paths is not None
if use_mpnn:
return LNPModel(
d_model=d_model,
num_heads=num_heads,
n_attn_layers=n_attn_layers,
fusion_strategy=fusion_strategy,
head_hidden_dim=head_hidden_dim,
dropout=dropout,
mpnn_checkpoint=mpnn_checkpoint,
mpnn_ensemble_paths=mpnn_ensemble_paths,
mpnn_device=mpnn_device,
)
else:
return LNPModelWithoutMPNN(
d_model=d_model,
num_heads=num_heads,
n_attn_layers=n_attn_layers,
fusion_strategy=fusion_strategy,
head_hidden_dim=head_hidden_dim,
dropout=dropout,
)
def train_fold(
fold_idx: int,
train_loader: DataLoader,
val_loader: DataLoader,
model: nn.Module,
device: torch.device,
output_dir: Path,
lr: float = 1e-4,
weight_decay: float = 1e-5,
epochs: int = 100,
patience: int = 15,
loss_weights: Optional[LossWeights] = None,
config: Optional[Dict] = None,
) -> Dict:
"""训练单个 fold"""
logger.info(f"\n{'='*60}")
logger.info(f"Training Fold {fold_idx}")
logger.info(f"{'='*60}")
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", factor=0.5, patience=5
)
early_stopping = EarlyStopping(patience=patience)
history = {"train": [], "val": []}
best_val_loss = float("inf")
best_state = None
for epoch in range(epochs):
# Train
train_metrics = train_epoch(model, train_loader, optimizer, device, loss_weights)
# Validate
val_metrics = validate(model, val_loader, device, loss_weights)
current_lr = optimizer.param_groups[0]["lr"]
# Log
logger.info(
f"Fold {fold_idx} Epoch {epoch+1}/{epochs} | "
f"Train Loss: {train_metrics['loss']:.4f} | "
f"Val Loss: {val_metrics['loss']:.4f} | "
f"LR: {current_lr:.2e}"
)
history["train"].append(train_metrics)
history["val"].append(val_metrics)
# Learning rate scheduling
scheduler.step(val_metrics["loss"])
# Save best model
if val_metrics["loss"] < best_val_loss:
best_val_loss = val_metrics["loss"]
best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
logger.info(f" -> New best model (val_loss={best_val_loss:.4f})")
# Early stopping
if early_stopping(val_metrics["loss"]):
logger.info(f"Early stopping at epoch {epoch+1}")
break
# 保存最佳模型
fold_output_dir = output_dir / f"fold_{fold_idx}"
fold_output_dir.mkdir(parents=True, exist_ok=True)
checkpoint_path = fold_output_dir / "model.pt"
torch.save({
"model_state_dict": best_state,
"config": config,
"best_val_loss": best_val_loss,
"fold_idx": fold_idx,
}, checkpoint_path)
logger.success(f"Saved fold {fold_idx} model to {checkpoint_path}")
# 保存训练历史
history_path = fold_output_dir / "history.json"
with open(history_path, "w") as f:
json.dump(history, f, indent=2)
return {
"fold_idx": fold_idx,
"best_val_loss": best_val_loss,
"epochs_trained": len(history["train"]),
"final_train_loss": history["train"][-1]["loss"] if history["train"] else 0,
}
@app.command()
def main(
data_dir: Path = PROCESSED_DATA_DIR / "cv",
output_dir: Path = MODELS_DIR / "finetune_cv",
# 模型参数
d_model: int = 256,
num_heads: int = 8,
n_attn_layers: int = 4,
fusion_strategy: str = "attention",
head_hidden_dim: int = 128,
dropout: float = 0.1,
# MPNN 参数(可选)
use_mpnn: bool = False,
mpnn_checkpoint: Optional[str] = None,
mpnn_ensemble_paths: Optional[str] = None,
mpnn_device: str = "cpu",
# 训练参数
batch_size: int = 32,
lr: float = 1e-4,
weight_decay: float = 1e-5,
epochs: int = 100,
patience: int = 15,
# 预训练权重加载
init_from_pretrain: Optional[Path] = None,
load_delivery_head: bool = True,
freeze_backbone: bool = False,
# 设备
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
"""
基于 Cross-Validation 训练 LNP 模型多任务
5-fold 内部数据上训练 5 个模型
使用 --use-mpnn 启用 MPNN encoder
使用 --init-from-pretrain 从预训练 checkpoint 初始化
使用 --freeze-backbone 冻结 backbone只训练 heads
"""
logger.info(f"Using device: {device}")
device = torch.device(device)
# 查找所有 fold 目录
fold_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("fold_")])
if not fold_dirs:
logger.error(f"No fold_* directories found in {data_dir}")
logger.info("Please run 'make data_cv' first to process CV data.")
raise typer.Exit(1)
logger.info(f"Found {len(fold_dirs)} folds: {[d.name for d in fold_dirs]}")
output_dir.mkdir(parents=True, exist_ok=True)
# 解析 MPNN 配置
ensemble_paths_list = None
if mpnn_ensemble_paths:
ensemble_paths_list = mpnn_ensemble_paths.split(",")
elif use_mpnn and mpnn_checkpoint is None:
logger.info(f"Auto-detecting MPNN ensemble from {DEFAULT_MPNN_ENSEMBLE_DIR}")
ensemble_paths_list = find_mpnn_ensemble_paths()
logger.info(f"Found {len(ensemble_paths_list)} MPNN models")
enable_mpnn = mpnn_checkpoint is not None or ensemble_paths_list is not None
# 模型配置
config = {
"d_model": d_model,
"num_heads": num_heads,
"n_attn_layers": n_attn_layers,
"fusion_strategy": fusion_strategy,
"head_hidden_dim": head_hidden_dim,
"dropout": dropout,
"use_mpnn": enable_mpnn,
"lr": lr,
"weight_decay": weight_decay,
"batch_size": batch_size,
"epochs": epochs,
"patience": patience,
"init_from_pretrain": str(init_from_pretrain) if init_from_pretrain else None,
"freeze_backbone": freeze_backbone,
}
# 保存配置
config_path = output_dir / "config.json"
with open(config_path, "w") as f:
json.dump(config, f, indent=2)
logger.info(f"Saved config to {config_path}")
# 加载预训练权重(如果指定)
pretrain_state = None
if init_from_pretrain is not None:
logger.info(f"Loading pretrain weights from {init_from_pretrain}")
checkpoint = torch.load(init_from_pretrain, map_location="cpu")
pretrain_config = checkpoint.get("config", {})
if pretrain_config.get("d_model") != d_model:
logger.warning(
f"d_model mismatch: pretrain={pretrain_config.get('d_model')}, "
f"current={d_model}. Skipping pretrain loading."
)
else:
pretrain_state = checkpoint["model_state_dict"]
# 训练每个 fold
fold_results = []
for fold_dir in tqdm(fold_dirs, desc="Training folds"):
fold_idx = int(fold_dir.name.split("_")[1])
# 加载数据
train_df = pd.read_parquet(fold_dir / "train.parquet")
val_df = pd.read_parquet(fold_dir / "val.parquet")
logger.info(f"\nFold {fold_idx}: train={len(train_df)}, val={len(val_df)}")
# 创建 Dataset 和 DataLoader
train_dataset = LNPDataset(train_df)
val_dataset = LNPDataset(val_df)
train_loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
)
val_loader = DataLoader(
val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
)
# 创建新模型(每个 fold 独立初始化)
model = create_model(
d_model=d_model,
num_heads=num_heads,
n_attn_layers=n_attn_layers,
fusion_strategy=fusion_strategy,
head_hidden_dim=head_hidden_dim,
dropout=dropout,
mpnn_checkpoint=mpnn_checkpoint,
mpnn_ensemble_paths=ensemble_paths_list,
mpnn_device=device.type,
)
# 加载预训练权重
if pretrain_state is not None:
model.load_pretrain_weights(
pretrain_state_dict=pretrain_state,
load_delivery_head=load_delivery_head,
strict=False,
)
logger.info(f"Loaded pretrain weights (backbone + delivery_head={load_delivery_head})")
# 冻结 backbone如果指定
if freeze_backbone:
frozen_count = 0
for name, param in model.named_parameters():
if name.startswith(("token_projector.", "cross_attention.", "fusion.")):
param.requires_grad = False
frozen_count += 1
logger.info(f"Frozen {frozen_count} parameter tensors")
# 打印模型信息(仅第一个 fold
if fold_idx == 0:
n_params_total = sum(p.numel() for p in model.parameters())
n_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"Model parameters: {n_params_total:,} total, {n_params_trainable:,} trainable")
# 训练
result = train_fold(
fold_idx=fold_idx,
train_loader=train_loader,
val_loader=val_loader,
model=model,
device=device,
output_dir=output_dir,
lr=lr,
weight_decay=weight_decay,
epochs=epochs,
patience=patience,
config=config,
)
fold_results.append(result)
# 汇总结果
logger.info("\n" + "=" * 60)
logger.info("CROSS-VALIDATION TRAINING COMPLETE")
logger.info("=" * 60)
val_losses = [r["best_val_loss"] for r in fold_results]
logger.info(f"\n[Per-Fold Results]")
for r in fold_results:
logger.info(
f" Fold {r['fold_idx']}: "
f"Val Loss={r['best_val_loss']:.4f}, "
f"Epochs={r['epochs_trained']}"
)
logger.info(f"\n[Summary Statistics]")
logger.info(f" Val Loss: {np.mean(val_losses):.4f} ± {np.std(val_losses):.4f}")
# 保存 CV 结果
cv_results = {
"fold_results": fold_results,
"summary": {
"val_loss_mean": float(np.mean(val_losses)),
"val_loss_std": float(np.std(val_losses)),
},
"config": config,
}
results_path = output_dir / "cv_results.json"
with open(results_path, "w") as f:
json.dump(cv_results, f, indent=2)
logger.success(f"Saved CV results to {results_path}")
@app.command()
def test(
data_dir: Path = PROCESSED_DATA_DIR / "cv",
model_dir: Path = MODELS_DIR / "finetune_cv",
output_path: Path = MODELS_DIR / "finetune_cv" / "test_results.json",
batch_size: int = 64,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
"""
在测试集上评估 CV 训练的模型
使用每个 fold 的模型在对应的测试集上评估然后汇总结果
"""
from sklearn.metrics import (
mean_squared_error,
mean_absolute_error,
r2_score,
accuracy_score,
)
logger.info(f"Using device: {device}")
device = torch.device(device)
# 查找所有 fold 目录
fold_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("fold_")])
if not fold_dirs:
logger.error(f"No fold_* directories found in {data_dir}")
raise typer.Exit(1)
logger.info(f"Found {len(fold_dirs)} folds")
fold_results = []
# 用于汇总所有 fold 的预测
all_preds = {
"size": [], "delivery": [], "pdi": [], "ee": [], "toxic": []
}
all_targets = {
"size": [], "delivery": [], "pdi": [], "ee": [], "toxic": []
}
for fold_dir in tqdm(fold_dirs, desc="Evaluating folds"):
fold_idx = int(fold_dir.name.split("_")[1])
model_path = model_dir / f"fold_{fold_idx}" / "model.pt"
test_path = fold_dir / "test.parquet"
if not model_path.exists():
logger.warning(f"Fold {fold_idx}: model not found at {model_path}, skipping")
continue
if not test_path.exists():
logger.warning(f"Fold {fold_idx}: test data not found at {test_path}, skipping")
continue
# 加载模型
checkpoint = torch.load(model_path, map_location=device)
config = checkpoint["config"]
use_mpnn = config.get("use_mpnn", False)
# 总是重新查找 MPNN 路径
if use_mpnn:
mpnn_paths = find_mpnn_ensemble_paths()
else:
mpnn_paths = None
model = create_model(
d_model=config["d_model"],
num_heads=config["num_heads"],
n_attn_layers=config["n_attn_layers"],
fusion_strategy=config["fusion_strategy"],
head_hidden_dim=config["head_hidden_dim"],
dropout=config["dropout"],
mpnn_ensemble_paths=mpnn_paths,
mpnn_device=device.type,
)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(device)
model.eval()
# 加载测试数据
test_df = pd.read_parquet(test_path)
test_dataset = LNPDataset(test_df)
test_loader = DataLoader(
test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
)
# 收集当前 fold 的预测
fold_preds = {k: [] for k in all_preds.keys()}
fold_targets = {k: [] for k in all_targets.keys()}
with torch.no_grad():
pbar = tqdm(test_loader, desc=f"Fold {fold_idx} [Test]", leave=False)
for batch in pbar:
smiles = batch["smiles"]
tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
targets = batch["targets"]
masks = batch["mask"]
outputs = model(smiles, tabular)
# Size
if "size" in masks and masks["size"].any():
mask = masks["size"]
fold_preds["size"].extend(
outputs["size"].squeeze(-1)[mask].cpu().numpy().tolist()
)
fold_targets["size"].extend(
targets["size"][mask].cpu().numpy().tolist()
)
# Delivery
if "delivery" in masks and masks["delivery"].any():
mask = masks["delivery"]
fold_preds["delivery"].extend(
outputs["delivery"].squeeze(-1)[mask].cpu().numpy().tolist()
)
fold_targets["delivery"].extend(
targets["delivery"][mask].cpu().numpy().tolist()
)
# PDI (classification)
if "pdi" in masks and masks["pdi"].any():
mask = masks["pdi"]
pdi_preds = outputs["pdi"][mask].argmax(dim=-1).cpu().numpy()
pdi_targets = targets["pdi"][mask].cpu().numpy()
fold_preds["pdi"].extend(pdi_preds.tolist())
fold_targets["pdi"].extend(pdi_targets.tolist())
# EE (classification)
if "ee" in masks and masks["ee"].any():
mask = masks["ee"]
ee_preds = outputs["ee"][mask].argmax(dim=-1).cpu().numpy()
ee_targets = targets["ee"][mask].cpu().numpy()
fold_preds["ee"].extend(ee_preds.tolist())
fold_targets["ee"].extend(ee_targets.tolist())
# Toxic (classification)
if "toxic" in masks and masks["toxic"].any():
mask = masks["toxic"]
toxic_preds = outputs["toxic"][mask].argmax(dim=-1).cpu().numpy()
toxic_targets = targets["toxic"][mask].cpu().numpy().astype(int)
fold_preds["toxic"].extend(toxic_preds.tolist())
fold_targets["toxic"].extend(toxic_targets.tolist())
# 计算当前 fold 的指标
fold_metrics = {"fold_idx": fold_idx, "n_samples": len(test_df)}
# 回归任务指标
for task in ["size", "delivery"]:
if fold_preds[task]:
p = np.array(fold_preds[task])
t = np.array(fold_targets[task])
fold_metrics[task] = {
"n": len(p),
"rmse": float(np.sqrt(mean_squared_error(t, p))),
"mae": float(mean_absolute_error(t, p)),
"r2": float(r2_score(t, p)),
}
# 分类任务指标
for task in ["pdi", "ee", "toxic"]:
if fold_preds[task]:
p = np.array(fold_preds[task])
t = np.array(fold_targets[task])
fold_metrics[task] = {
"n": len(p),
"accuracy": float(accuracy_score(t, p)),
}
fold_results.append(fold_metrics)
# 汇总到全局
for task in all_preds.keys():
all_preds[task].extend(fold_preds[task])
all_targets[task].extend(fold_targets[task])
# 打印当前 fold 结果
log_parts = [f"Fold {fold_idx}: n={len(test_df)}"]
for task in ["delivery", "size"]:
if task in fold_metrics and isinstance(fold_metrics[task], dict):
log_parts.append(f"{task}_RMSE={fold_metrics[task]['rmse']:.4f}")
log_parts.append(f"{task}_R²={fold_metrics[task]['r2']:.4f}")
for task in ["pdi", "ee", "toxic"]:
if task in fold_metrics and isinstance(fold_metrics[task], dict):
log_parts.append(f"{task}_acc={fold_metrics[task]['accuracy']:.4f}")
logger.info(", ".join(log_parts))
# 计算跨 fold 汇总统计
summary_stats = {}
for task in ["size", "delivery"]:
rmses = [r[task]["rmse"] for r in fold_results if task in r and isinstance(r[task], dict)]
r2s = [r[task]["r2"] for r in fold_results if task in r and isinstance(r[task], dict)]
if rmses:
summary_stats[task] = {
"rmse_mean": float(np.mean(rmses)),
"rmse_std": float(np.std(rmses)),
"r2_mean": float(np.mean(r2s)),
"r2_std": float(np.std(r2s)),
}
for task in ["pdi", "ee", "toxic"]:
accs = [r[task]["accuracy"] for r in fold_results if task in r and isinstance(r[task], dict)]
if accs:
summary_stats[task] = {
"accuracy_mean": float(np.mean(accs)),
"accuracy_std": float(np.std(accs)),
}
# 计算整体 pooled 指标
overall = {}
for task in ["size", "delivery"]:
if all_preds[task]:
p = np.array(all_preds[task])
t = np.array(all_targets[task])
overall[task] = {
"n_samples": len(p),
"mse": float(mean_squared_error(t, p)),
"rmse": float(np.sqrt(mean_squared_error(t, p))),
"mae": float(mean_absolute_error(t, p)),
"r2": float(r2_score(t, p)),
}
for task in ["pdi", "ee", "toxic"]:
if all_preds[task]:
p = np.array(all_preds[task])
t = np.array(all_targets[task])
overall[task] = {
"n_samples": len(p),
"accuracy": float(accuracy_score(t, p)),
}
# 打印汇总结果
logger.info("\n" + "=" * 60)
logger.info("CV TEST EVALUATION RESULTS")
logger.info("=" * 60)
logger.info(f"\n[Summary Statistics (across {len(fold_results)} folds)]")
for task, stats in summary_stats.items():
if "rmse_mean" in stats:
logger.info(f" {task}: RMSE={stats['rmse_mean']:.4f}±{stats['rmse_std']:.4f}, R²={stats['r2_mean']:.4f}±{stats['r2_std']:.4f}")
else:
logger.info(f" {task}: Accuracy={stats['accuracy_mean']:.4f}±{stats['accuracy_std']:.4f}")
logger.info(f"\n[Overall (all samples pooled)]")
for task, metrics in overall.items():
if "rmse" in metrics:
logger.info(f" {task} (n={metrics['n_samples']}): RMSE={metrics['rmse']:.4f}, MAE={metrics['mae']:.4f}, R²={metrics['r2']:.4f}")
else:
logger.info(f" {task} (n={metrics['n_samples']}): Accuracy={metrics['accuracy']:.4f}")
# 保存结果
results = {
"fold_results": fold_results,
"summary_stats": summary_stats,
"overall": overall,
}
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
json.dump(results, f, indent=2)
logger.success(f"\nSaved test results to {output_path}")
if __name__ == "__main__":
app()

View File

@ -0,0 +1,16 @@
{
"d_model": 256,
"num_heads": 8,
"n_attn_layers": 4,
"fusion_strategy": "attention",
"head_hidden_dim": 128,
"dropout": 0.1,
"use_mpnn": true,
"lr": 0.0001,
"weight_decay": 1e-05,
"batch_size": 32,
"epochs": 100,
"patience": 15,
"init_from_pretrain": "models/pretrain_delivery.pt",
"freeze_backbone": false
}

View File

@ -0,0 +1,54 @@
{
"fold_results": [
{
"fold_idx": 0,
"best_val_loss": 6.144314289093018,
"epochs_trained": 25,
"final_train_loss": 1.4692220211029052
},
{
"fold_idx": 1,
"best_val_loss": 8.569346030553183,
"epochs_trained": 20,
"final_train_loss": 1.5929443359375
},
{
"fold_idx": 2,
"best_val_loss": 3.7409281730651855,
"epochs_trained": 22,
"final_train_loss": 1.9401288827260335
},
{
"fold_idx": 3,
"best_val_loss": 3.47284197807312,
"epochs_trained": 27,
"final_train_loss": 1.8295514345169068
},
{
"fold_idx": 4,
"best_val_loss": 2.756531000137329,
"epochs_trained": 19,
"final_train_loss": 1.9399811571294612
}
],
"summary": {
"val_loss_mean": 4.936792294184367,
"val_loss_std": 2.1438440638412697
},
"config": {
"d_model": 256,
"num_heads": 8,
"n_attn_layers": 4,
"fusion_strategy": "attention",
"head_hidden_dim": 128,
"dropout": 0.1,
"use_mpnn": true,
"lr": 0.0001,
"weight_decay": 1e-05,
"batch_size": 32,
"epochs": 100,
"patience": 15,
"init_from_pretrain": "models/pretrain_delivery.pt",
"freeze_backbone": false
}
}

View File

@ -0,0 +1,531 @@
{
"train": [
{
"loss": 19.238220310211183,
"loss_size": 14.334759521484376,
"loss_pdi": 1.275341796875,
"loss_ee": 1.078886091709137,
"loss_delivery": 0.639056247472763,
"loss_biodist": 1.314099133014679,
"loss_toxic": 0.5960778951644897
},
{
"loss": 7.8008105754852295,
"loss_size": 3.835961139202118,
"loss_pdi": 1.0304630517959594,
"loss_ee": 1.002296370267868,
"loss_delivery": 0.527982234954834,
"loss_biodist": 1.0791441202163696,
"loss_toxic": 0.3249638095498085
},
{
"loss": 3.952784705162048,
"loss_size": 0.5930101618170738,
"loss_pdi": 0.7961886763572693,
"loss_ee": 0.9416749358177186,
"loss_delivery": 0.5073600560426712,
"loss_biodist": 0.9171513140201568,
"loss_toxic": 0.19739954844117164
},
{
"loss": 3.218132185935974,
"loss_size": 0.20842453986406326,
"loss_pdi": 0.688220864534378,
"loss_ee": 0.904784232378006,
"loss_delivery": 0.4910900041460991,
"loss_biodist": 0.792213362455368,
"loss_toxic": 0.1333992186933756
},
{
"loss": 2.930907416343689,
"loss_size": 0.21291286423802375,
"loss_pdi": 0.6122969090938568,
"loss_ee": 0.8774014472961426,
"loss_delivery": 0.4451231583952904,
"loss_biodist": 0.6634868443012237,
"loss_toxic": 0.11968618221580982
},
{
"loss": 2.7193881273269653,
"loss_size": 0.213371854275465,
"loss_pdi": 0.5864351749420166,
"loss_ee": 0.8563571333885193,
"loss_delivery": 0.3963193610310555,
"loss_biodist": 0.5644777715206146,
"loss_toxic": 0.10242685079574584
},
{
"loss": 2.5172106266021728,
"loss_size": 0.23241330087184905,
"loss_pdi": 0.5564007371664047,
"loss_ee": 0.8135433554649353,
"loss_delivery": 0.39191135168075564,
"loss_biodist": 0.4503189116716385,
"loss_toxic": 0.07262293715029955
},
{
"loss": 2.3014568209648134,
"loss_size": 0.1924597330391407,
"loss_pdi": 0.5315678030252456,
"loss_ee": 0.8107137799263,
"loss_delivery": 0.33130097687244414,
"loss_biodist": 0.3714979439973831,
"loss_toxic": 0.06391655802726745
},
{
"loss": 2.1527106881141664,
"loss_size": 0.19257416054606438,
"loss_pdi": 0.5191590428352356,
"loss_ee": 0.783897054195404,
"loss_delivery": 0.29573799669742584,
"loss_biodist": 0.3145490542054176,
"loss_toxic": 0.046793402079492806
},
{
"loss": 2.0622685074806215,
"loss_size": 0.2051038146018982,
"loss_pdi": 0.49145313203334806,
"loss_ee": 0.7647641122341156,
"loss_delivery": 0.2876307189464569,
"loss_biodist": 0.27712231278419497,
"loss_toxic": 0.03619444826617837
},
{
"loss": 1.9519578456878661,
"loss_size": 0.17994399815797807,
"loss_pdi": 0.4814375311136246,
"loss_ee": 0.733842009305954,
"loss_delivery": 0.28253656476736067,
"loss_biodist": 0.24782671630382538,
"loss_toxic": 0.026371066551655532
},
{
"loss": 1.935675847530365,
"loss_size": 0.1704096481204033,
"loss_pdi": 0.47338791787624357,
"loss_ee": 0.7182988226413727,
"loss_delivery": 0.3093330509960651,
"loss_biodist": 0.2340244770050049,
"loss_toxic": 0.030221952823922038
},
{
"loss": 1.888454306125641,
"loss_size": 0.17727438509464263,
"loss_pdi": 0.46344051957130433,
"loss_ee": 0.7103636503219605,
"loss_delivery": 0.30027762055397034,
"loss_biodist": 0.2190815806388855,
"loss_toxic": 0.018016549991443753
},
{
"loss": 1.8231052160263062,
"loss_size": 0.1548917345702648,
"loss_pdi": 0.4576862633228302,
"loss_ee": 0.7034903109073639,
"loss_delivery": 0.29063438922166823,
"loss_biodist": 0.19972888082265855,
"loss_toxic": 0.016673638485372066
},
{
"loss": 1.756770372390747,
"loss_size": 0.15216425359249114,
"loss_pdi": 0.429460334777832,
"loss_ee": 0.6776757568120957,
"loss_delivery": 0.2867794781923294,
"loss_biodist": 0.19385820478200913,
"loss_toxic": 0.01683232020586729
},
{
"loss": 1.7031532883644105,
"loss_size": 0.15495768785476685,
"loss_pdi": 0.43094243109226227,
"loss_ee": 0.6677232623100281,
"loss_delivery": 0.26152765452861787,
"loss_biodist": 0.17694738060235976,
"loss_toxic": 0.011054831324145198
},
{
"loss": 1.679247748851776,
"loss_size": 0.15252191424369813,
"loss_pdi": 0.40398688316345216,
"loss_ee": 0.6563315153121948,
"loss_delivery": 0.2827941685914993,
"loss_biodist": 0.17421896755695343,
"loss_toxic": 0.009394321008585393
},
{
"loss": 1.6231786251068114,
"loss_size": 0.14685654938220977,
"loss_pdi": 0.3987069964408875,
"loss_ee": 0.6459777146577835,
"loss_delivery": 0.2517095260322094,
"loss_biodist": 0.1695146732032299,
"loss_toxic": 0.010413196869194508
},
{
"loss": 1.5647669196128846,
"loss_size": 0.12480136081576347,
"loss_pdi": 0.40768158435821533,
"loss_ee": 0.6228045016527176,
"loss_delivery": 0.23313914462924004,
"loss_biodist": 0.1658004455268383,
"loss_toxic": 0.0105398821644485
},
{
"loss": 1.543732750415802,
"loss_size": 0.12116749435663224,
"loss_pdi": 0.3942162901163101,
"loss_ee": 0.6175289869308471,
"loss_delivery": 0.2506958607584238,
"loss_biodist": 0.15138662829995156,
"loss_toxic": 0.008737476798705757
},
{
"loss": 1.534558892250061,
"loss_size": 0.11713396161794662,
"loss_pdi": 0.39247085303068163,
"loss_ee": 0.608239871263504,
"loss_delivery": 0.24818528145551683,
"loss_biodist": 0.15680191665887833,
"loss_toxic": 0.011727035511285067
},
{
"loss": 1.508529245853424,
"loss_size": 0.12370343580842018,
"loss_pdi": 0.37536335587501524,
"loss_ee": 0.5983373761177063,
"loss_delivery": 0.2428302437067032,
"loss_biodist": 0.16056786775588988,
"loss_toxic": 0.0077269634697586295
},
{
"loss": 1.495661163330078,
"loss_size": 0.12558167055249214,
"loss_pdi": 0.3853023111820221,
"loss_ee": 0.5909300297498703,
"loss_delivery": 0.22781415805220603,
"loss_biodist": 0.15893371179699897,
"loss_toxic": 0.007099285908043385
},
{
"loss": 1.4543131768703461,
"loss_size": 0.11276061460375786,
"loss_pdi": 0.3703819438815117,
"loss_ee": 0.5739975512027741,
"loss_delivery": 0.24190589040517807,
"loss_biodist": 0.14822293519973756,
"loss_toxic": 0.007044258969835937
},
{
"loss": 1.4692220211029052,
"loss_size": 0.12677552737295628,
"loss_pdi": 0.37003436386585237,
"loss_ee": 0.578819090127945,
"loss_delivery": 0.2372341684997082,
"loss_biodist": 0.1498504839837551,
"loss_toxic": 0.0065083671128377315
}
],
"val": [
{
"loss": 16.045034408569336,
"loss_size": 7.810531139373779,
"loss_pdi": 1.0906392335891724,
"loss_ee": 1.0121623277664185,
"loss_delivery": 4.599453926086426,
"loss_biodist": 1.121813416481018,
"loss_toxic": 0.4104346036911011,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.6296296296296297,
"acc_toxic": 1.0
},
{
"loss": 8.368906021118164,
"loss_size": 1.3932932615280151,
"loss_pdi": 0.7484031915664673,
"loss_ee": 0.9992449283599854,
"loss_delivery": 3.9427361488342285,
"loss_biodist": 1.0493215322494507,
"loss_toxic": 0.2359071671962738,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.77295446395874,
"loss_size": 0.07529407739639282,
"loss_pdi": 0.5075266361236572,
"loss_ee": 0.9870877265930176,
"loss_delivery": 4.065003395080566,
"loss_biodist": 1.0452064275741577,
"loss_toxic": 0.09283570945262909,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 7.080923080444336,
"loss_size": 0.06525272130966187,
"loss_pdi": 0.4343451261520386,
"loss_ee": 1.0161277055740356,
"loss_delivery": 4.559023380279541,
"loss_biodist": 0.9523851871490479,
"loss_toxic": 0.05378875508904457,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.916051864624023,
"loss_size": 0.07201710343360901,
"loss_pdi": 0.4837420880794525,
"loss_ee": 1.0075831413269043,
"loss_delivery": 4.492982864379883,
"loss_biodist": 0.7963473796844482,
"loss_toxic": 0.0633791983127594,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.480402946472168,
"loss_size": 0.056911345571279526,
"loss_pdi": 0.4535364508628845,
"loss_ee": 1.00966477394104,
"loss_delivery": 4.184553146362305,
"loss_biodist": 0.6998977661132812,
"loss_toxic": 0.07583901286125183,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.857071399688721,
"loss_size": 0.04192754626274109,
"loss_pdi": 0.44999659061431885,
"loss_ee": 1.0004281997680664,
"loss_delivery": 4.647970676422119,
"loss_biodist": 0.6309637427330017,
"loss_toxic": 0.0857851505279541,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.3087968826293945,
"loss_size": 0.09944896399974823,
"loss_pdi": 0.40266019105911255,
"loss_ee": 0.9705824255943298,
"loss_delivery": 4.246408939361572,
"loss_biodist": 0.552162230014801,
"loss_toxic": 0.03753397986292839,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.9212870597839355,
"loss_size": 0.05665857717394829,
"loss_pdi": 0.4618251323699951,
"loss_ee": 0.9701217412948608,
"loss_delivery": 4.835049152374268,
"loss_biodist": 0.5396986603736877,
"loss_toxic": 0.057933975011110306,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.144314289093018,
"loss_size": 0.16736657917499542,
"loss_pdi": 0.43004411458969116,
"loss_ee": 0.9552963972091675,
"loss_delivery": 4.0627570152282715,
"loss_biodist": 0.5059604048728943,
"loss_toxic": 0.02288975566625595,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 7.03376579284668,
"loss_size": 0.11119232326745987,
"loss_pdi": 0.4463132619857788,
"loss_ee": 0.9212661385536194,
"loss_delivery": 5.010645866394043,
"loss_biodist": 0.5135870575904846,
"loss_toxic": 0.03076130710542202,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.251324653625488,
"loss_size": 0.16327497363090515,
"loss_pdi": 0.4076344966888428,
"loss_ee": 0.9357188940048218,
"loss_delivery": 4.216032981872559,
"loss_biodist": 0.5158465504646301,
"loss_toxic": 0.012816602364182472,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.6296296296296297,
"acc_toxic": 1.0
},
{
"loss": 6.909743309020996,
"loss_size": 0.10885120183229446,
"loss_pdi": 0.40938591957092285,
"loss_ee": 0.8893271684646606,
"loss_delivery": 4.9792799949646,
"loss_biodist": 0.49506261944770813,
"loss_toxic": 0.027835864573717117,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.6296296296296297,
"acc_toxic": 1.0
},
{
"loss": 6.724283695220947,
"loss_size": 0.10787779092788696,
"loss_pdi": 0.45569828152656555,
"loss_ee": 0.979951798915863,
"loss_delivery": 4.641026020050049,
"loss_biodist": 0.5102843642234802,
"loss_toxic": 0.029445137828588486,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5555555555555556,
"acc_toxic": 1.0
},
{
"loss": 6.498813629150391,
"loss_size": 0.14617349207401276,
"loss_pdi": 0.3836137652397156,
"loss_ee": 0.8910271525382996,
"loss_delivery": 4.627362251281738,
"loss_biodist": 0.43996545672416687,
"loss_toxic": 0.010672268457710743,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.863435745239258,
"loss_size": 0.1379111111164093,
"loss_pdi": 0.5173991322517395,
"loss_ee": 1.013482689857483,
"loss_delivery": 4.7425923347473145,
"loss_biodist": 0.4353649318218231,
"loss_toxic": 0.01668536849319935,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.48148148148148145,
"acc_toxic": 1.0
},
{
"loss": 6.909954071044922,
"loss_size": 0.35534632205963135,
"loss_pdi": 0.3974299728870392,
"loss_ee": 0.8724083304405212,
"loss_delivery": 4.841543674468994,
"loss_biodist": 0.42888545989990234,
"loss_toxic": 0.014340322464704514,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.6296296296296297,
"acc_toxic": 1.0
},
{
"loss": 6.718777179718018,
"loss_size": 0.10662573575973511,
"loss_pdi": 0.4000368118286133,
"loss_ee": 0.925907552242279,
"loss_delivery": 4.855783939361572,
"loss_biodist": 0.4174629747867584,
"loss_toxic": 0.012960433959960938,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.800571918487549,
"loss_size": 0.20526628196239471,
"loss_pdi": 0.42215225100517273,
"loss_ee": 0.9293471574783325,
"loss_delivery": 4.7998528480529785,
"loss_biodist": 0.42879122495651245,
"loss_toxic": 0.015162378549575806,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.873208522796631,
"loss_size": 0.1934625804424286,
"loss_pdi": 0.45854493975639343,
"loss_ee": 0.9394232630729675,
"loss_delivery": 4.830544471740723,
"loss_biodist": 0.43078526854515076,
"loss_toxic": 0.020448585972189903,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.6296296296296297,
"acc_toxic": 1.0
},
{
"loss": 6.790691375732422,
"loss_size": 0.18806225061416626,
"loss_pdi": 0.43740272521972656,
"loss_ee": 0.9507966637611389,
"loss_delivery": 4.783614158630371,
"loss_biodist": 0.4198465943336487,
"loss_toxic": 0.010969163849949837,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.847814083099365,
"loss_size": 0.22717446088790894,
"loss_pdi": 0.42288023233413696,
"loss_ee": 0.9286114573478699,
"loss_delivery": 4.81552791595459,
"loss_biodist": 0.43471890687942505,
"loss_toxic": 0.01890140399336815,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 7.032270431518555,
"loss_size": 0.24087895452976227,
"loss_pdi": 0.46614596247673035,
"loss_ee": 0.9626513123512268,
"loss_delivery": 4.897539138793945,
"loss_biodist": 0.44117382168769836,
"loss_toxic": 0.023880867287516594,
"acc_pdi": 0.6666666666666666,
"acc_ee": 0.6296296296296297,
"acc_toxic": 1.0
},
{
"loss": 6.840210914611816,
"loss_size": 0.18546245992183685,
"loss_pdi": 0.41406312584877014,
"loss_ee": 0.9308509230613708,
"loss_delivery": 4.877482891082764,
"loss_biodist": 0.41710007190704346,
"loss_toxic": 0.015251458622515202,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.6296296296296297,
"acc_toxic": 1.0
},
{
"loss": 6.980130672454834,
"loss_size": 0.21011953055858612,
"loss_pdi": 0.4370857775211334,
"loss_ee": 0.9217703938484192,
"loss_delivery": 4.956975936889648,
"loss_biodist": 0.4334105849266052,
"loss_toxic": 0.020767945796251297,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.6296296296296297,
"acc_toxic": 1.0
}
]
}

Binary file not shown.

View File

@ -0,0 +1,426 @@
{
"train": [
{
"loss": 22.032494735717773,
"loss_size": 16.743953514099122,
"loss_pdi": 1.2239572763442994,
"loss_ee": 1.0474397182464599,
"loss_delivery": 1.2612280070781707,
"loss_biodist": 1.1997641563415526,
"loss_toxic": 0.5561524748802185
},
{
"loss": 12.498378562927247,
"loss_size": 8.05031099319458,
"loss_pdi": 1.0179082751274109,
"loss_ee": 0.9407736420631408,
"loss_delivery": 0.8306634247303009,
"loss_biodist": 1.160559320449829,
"loss_toxic": 0.4981633126735687
},
{
"loss": 6.852249622344971,
"loss_size": 2.8376791954040526,
"loss_pdi": 0.8515164494514466,
"loss_ee": 0.8604451298713685,
"loss_delivery": 0.7656278729438781,
"loss_biodist": 1.1427257061004639,
"loss_toxic": 0.39425510764122007
},
{
"loss": 4.40604076385498,
"loss_size": 0.6938439428806304,
"loss_pdi": 0.6752870678901672,
"loss_ee": 0.8080484986305236,
"loss_delivery": 0.8313614726066589,
"loss_biodist": 1.0828368663787842,
"loss_toxic": 0.31466284990310667
},
{
"loss": 3.4253986835479737,
"loss_size": 0.16804716736078262,
"loss_pdi": 0.5639804005622864,
"loss_ee": 0.6901269078254699,
"loss_delivery": 0.6912416338920593,
"loss_biodist": 1.0242016911506653,
"loss_toxic": 0.28780081272125246
},
{
"loss": 3.2362718105316164,
"loss_size": 0.1483217939734459,
"loss_pdi": 0.6221840143203735,
"loss_ee": 0.6781063556671143,
"loss_delivery": 0.6013748198747635,
"loss_biodist": 0.9614210963249207,
"loss_toxic": 0.22486359924077987
},
{
"loss": 3.1696151733398437,
"loss_size": 0.1722302332520485,
"loss_pdi": 0.48496800661087036,
"loss_ee": 0.6616590738296508,
"loss_delivery": 0.7679106175899506,
"loss_biodist": 0.9077507495880127,
"loss_toxic": 0.17509644329547883
},
{
"loss": 2.6617531299591066,
"loss_size": 0.16581893265247344,
"loss_pdi": 0.4746619284152985,
"loss_ee": 0.630395919084549,
"loss_delivery": 0.3917909190058708,
"loss_biodist": 0.852727723121643,
"loss_toxic": 0.1463577665388584
},
{
"loss": 2.4909090995788574,
"loss_size": 0.11310702562332153,
"loss_pdi": 0.43855146765708924,
"loss_ee": 0.5929172158241272,
"loss_delivery": 0.42325166761875155,
"loss_biodist": 0.7738024115562439,
"loss_toxic": 0.1492793083190918
},
{
"loss": 2.3516653537750245,
"loss_size": 0.21800988018512726,
"loss_pdi": 0.43570560216903687,
"loss_ee": 0.5719826459884644,
"loss_delivery": 0.3004884377121925,
"loss_biodist": 0.6908805131912231,
"loss_toxic": 0.1345983102917671
},
{
"loss": 2.1030978202819823,
"loss_size": 0.10880238711833953,
"loss_pdi": 0.40215051770210264,
"loss_ee": 0.5412053823471069,
"loss_delivery": 0.2829424023628235,
"loss_biodist": 0.651984566450119,
"loss_toxic": 0.11601254418492317
},
{
"loss": 1.994719886779785,
"loss_size": 0.08978431597352028,
"loss_pdi": 0.4142456531524658,
"loss_ee": 0.532235836982727,
"loss_delivery": 0.2747663021087646,
"loss_biodist": 0.5863585829734802,
"loss_toxic": 0.09732922576367856
},
{
"loss": 1.9875550031661988,
"loss_size": 0.13663371056318283,
"loss_pdi": 0.3811588704586029,
"loss_ee": 0.5030429780483245,
"loss_delivery": 0.2990179345011711,
"loss_biodist": 0.5580029547214508,
"loss_toxic": 0.10969849154353142
},
{
"loss": 1.9142925500869752,
"loss_size": 0.12668041437864302,
"loss_pdi": 0.4271237254142761,
"loss_ee": 0.5021511077880859,
"loss_delivery": 0.2249750167131424,
"loss_biodist": 0.5295105099678039,
"loss_toxic": 0.1038517564535141
},
{
"loss": 1.81255943775177,
"loss_size": 0.1435435712337494,
"loss_pdi": 0.3587616056203842,
"loss_ee": 0.5099679112434388,
"loss_delivery": 0.24120242595672609,
"loss_biodist": 0.48327294588088987,
"loss_toxic": 0.07581093870103359
},
{
"loss": 1.7264988660812377,
"loss_size": 0.1290398582816124,
"loss_pdi": 0.3462284058332443,
"loss_ee": 0.48116588592529297,
"loss_delivery": 0.22381204068660737,
"loss_biodist": 0.47818992137908933,
"loss_toxic": 0.06806278452277184
},
{
"loss": 1.6520193338394165,
"loss_size": 0.11937985867261887,
"loss_pdi": 0.34835702180862427,
"loss_ee": 0.4374507278203964,
"loss_delivery": 0.24701516777276994,
"loss_biodist": 0.43337869048118594,
"loss_toxic": 0.06643788442015648
},
{
"loss": 1.5883747339248657,
"loss_size": 0.09780682176351548,
"loss_pdi": 0.36214256286621094,
"loss_ee": 0.43477209806442263,
"loss_delivery": 0.23339744359254838,
"loss_biodist": 0.4091412305831909,
"loss_toxic": 0.051114612445235255
},
{
"loss": 1.6899895429611207,
"loss_size": 0.10551446527242661,
"loss_pdi": 0.3761424541473389,
"loss_ee": 0.46849397420883176,
"loss_delivery": 0.27737232744693757,
"loss_biodist": 0.4076103329658508,
"loss_toxic": 0.05485602542757988
},
{
"loss": 1.5929443359375,
"loss_size": 0.11571613550186158,
"loss_pdi": 0.352296257019043,
"loss_ee": 0.4432097375392914,
"loss_delivery": 0.2575030043721199,
"loss_biodist": 0.37044936418533325,
"loss_toxic": 0.05376977995038033
}
],
"val": [
{
"loss": 22.81237284342448,
"loss_size": 13.806465148925781,
"loss_pdi": 1.229181448618571,
"loss_ee": 1.027739703655243,
"loss_delivery": 5.068911023437977,
"loss_biodist": 1.1279457012812297,
"loss_toxic": 0.5521289308865865,
"acc_pdi": 0.6526315789473685,
"acc_ee": 0.6947368421052632,
"acc_toxic": 0.8939393939393939
},
{
"loss": 15.498573303222656,
"loss_size": 6.912943522135417,
"loss_pdi": 1.151400883992513,
"loss_ee": 0.9848727782567342,
"loss_delivery": 4.852191311617692,
"loss_biodist": 1.106157898902893,
"loss_toxic": 0.4910069803396861,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 10.799206574757894,
"loss_size": 2.5314422051111856,
"loss_pdi": 1.0411948959032695,
"loss_ee": 0.9279763499895731,
"loss_delivery": 4.807806923985481,
"loss_biodist": 1.0742538372675579,
"loss_toxic": 0.4165322283903758,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.901862303415934,
"loss_size": 0.7716620067755381,
"loss_pdi": 1.01361749569575,
"loss_ee": 0.8802629311879476,
"loss_delivery": 4.857094804445903,
"loss_biodist": 1.0347116986910503,
"loss_toxic": 0.34451337655385333,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.569346030553183,
"loss_size": 0.41188229247927666,
"loss_pdi": 1.0502839088439941,
"loss_ee": 0.8726372222105662,
"loss_delivery": 4.940298028290272,
"loss_biodist": 1.0011842250823975,
"loss_toxic": 0.29306065539518994,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 9.03537893295288,
"loss_size": 0.5666823834180832,
"loss_pdi": 1.1037296056747437,
"loss_ee": 0.9048542082309723,
"loss_delivery": 5.218828019996484,
"loss_biodist": 0.9743464191754659,
"loss_toxic": 0.2669379909833272,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 9.051881631215414,
"loss_size": 0.5067511014640331,
"loss_pdi": 1.0797988673051198,
"loss_ee": 0.8918277323246002,
"loss_delivery": 5.376374647021294,
"loss_biodist": 0.9581284721692404,
"loss_toxic": 0.23900071531534195,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.965935707092285,
"loss_size": 0.4464708169301351,
"loss_pdi": 1.03824187318484,
"loss_ee": 0.8565650085608164,
"loss_delivery": 5.463352290292581,
"loss_biodist": 0.9425446192423502,
"loss_toxic": 0.21876097718874613,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 9.351192077000936,
"loss_size": 0.578731312106053,
"loss_pdi": 1.0086682240168254,
"loss_ee": 0.8167769958575567,
"loss_delivery": 5.8357385993003845,
"loss_biodist": 0.8966569105784098,
"loss_toxic": 0.21462025741736093,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 9.685971339543661,
"loss_size": 0.5851275982956091,
"loss_pdi": 0.9895619451999664,
"loss_ee": 0.8124870856602987,
"loss_delivery": 6.198981747031212,
"loss_biodist": 0.8602876861890157,
"loss_toxic": 0.2395249493420124,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 9.098868290583292,
"loss_size": 0.4912531226873398,
"loss_pdi": 0.9610133767127991,
"loss_ee": 0.8015020589033762,
"loss_delivery": 5.819371705253919,
"loss_biodist": 0.8127925594647726,
"loss_toxic": 0.21293580221633115,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 9.217247327168783,
"loss_size": 0.5547616928815842,
"loss_pdi": 0.9466870129108429,
"loss_ee": 0.80243648091952,
"loss_delivery": 5.907406737407048,
"loss_biodist": 0.7983624935150146,
"loss_toxic": 0.20759239544471106,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 9.307894190152487,
"loss_size": 0.6155262216925621,
"loss_pdi": 0.9505135516325632,
"loss_ee": 0.8020251393318176,
"loss_delivery": 5.944891105095546,
"loss_biodist": 0.7798070808251699,
"loss_toxic": 0.21513095125555992,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 9.434777816136679,
"loss_size": 0.5672734752297401,
"loss_pdi": 0.9817352195580801,
"loss_ee": 0.822968602180481,
"loss_delivery": 6.080470234155655,
"loss_biodist": 0.745076318581899,
"loss_toxic": 0.23725438863039017,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 9.355290253957113,
"loss_size": 0.5713960801561674,
"loss_pdi": 0.9858902891476949,
"loss_ee": 0.8192337850729624,
"loss_delivery": 6.011797075470288,
"loss_biodist": 0.724964420000712,
"loss_toxic": 0.24200843647122383,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 9.306842883427938,
"loss_size": 0.6075545425216357,
"loss_pdi": 0.9776454170544943,
"loss_ee": 0.7862655818462372,
"loss_delivery": 5.985511064529419,
"loss_biodist": 0.7076093653837839,
"loss_toxic": 0.24225737899541855,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 9.252086718877157,
"loss_size": 0.6051755348841349,
"loss_pdi": 0.9714606404304504,
"loss_ee": 0.7621750583251318,
"loss_delivery": 5.975689520438512,
"loss_biodist": 0.6939358512560526,
"loss_toxic": 0.24364992106954256,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 9.275835235913595,
"loss_size": 0.5751028036077818,
"loss_pdi": 0.9899951120217642,
"loss_ee": 0.768888125816981,
"loss_delivery": 5.988193516929944,
"loss_biodist": 0.6993262469768524,
"loss_toxic": 0.25432955101132393,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6842105263157895,
"acc_toxic": 0.8939393939393939
},
{
"loss": 9.215376615524292,
"loss_size": 0.5498512660463651,
"loss_pdi": 0.9988046189149221,
"loss_ee": 0.7566283419728279,
"loss_delivery": 5.962520445386569,
"loss_biodist": 0.6932495137055715,
"loss_toxic": 0.2543224884817998,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6947368421052632,
"acc_toxic": 0.8939393939393939
},
{
"loss": 9.17027481396993,
"loss_size": 0.5183901910980543,
"loss_pdi": 1.0107523500919342,
"loss_ee": 0.7627487083276113,
"loss_delivery": 5.940715471903483,
"loss_biodist": 0.6853441596031189,
"loss_toxic": 0.25232408692439395,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6947368421052632,
"acc_toxic": 0.8939393939393939
}
]
}

Binary file not shown.

View File

@ -0,0 +1,468 @@
{
"train": [
{
"loss": 23.856885592142742,
"loss_size": 17.954859733581543,
"loss_pdi": 1.3681265513102214,
"loss_ee": 1.068708558877309,
"loss_delivery": 1.6815399676561356,
"loss_biodist": 1.1385973294576008,
"loss_toxic": 0.6450536052385966
},
{
"loss": 15.18941100438436,
"loss_size": 10.012138843536377,
"loss_pdi": 1.2400256196657817,
"loss_ee": 0.9122016330560049,
"loss_delivery": 1.4215861509243648,
"loss_biodist": 1.0704677999019623,
"loss_toxic": 0.5329908380905787
},
{
"loss": 8.939830541610718,
"loss_size": 4.306200265884399,
"loss_pdi": 1.0988461375236511,
"loss_ee": 0.7973864078521729,
"loss_delivery": 1.312243824203809,
"loss_biodist": 0.9968815743923187,
"loss_toxic": 0.4282720486323039
},
{
"loss": 5.871116876602173,
"loss_size": 1.4981385171413422,
"loss_pdi": 0.9578391710917155,
"loss_ee": 0.7877674202124277,
"loss_delivery": 1.332635521888733,
"loss_biodist": 0.9970230559508005,
"loss_toxic": 0.29771314313014346
},
{
"loss": 4.818921804428101,
"loss_size": 0.6080186615387598,
"loss_pdi": 0.8167077898979187,
"loss_ee": 0.7845198512077332,
"loss_delivery": 1.421961595614751,
"loss_biodist": 0.9678046603997549,
"loss_toxic": 0.21990922341744104
},
{
"loss": 4.394984285036723,
"loss_size": 0.3149821311235428,
"loss_pdi": 0.7380032440026602,
"loss_ee": 0.7671190500259399,
"loss_delivery": 1.447837049762408,
"loss_biodist": 0.9290279944737753,
"loss_toxic": 0.19801472003261247
},
{
"loss": 3.884276866912842,
"loss_size": 0.26142628739277524,
"loss_pdi": 0.6927057504653931,
"loss_ee": 0.7489989002545675,
"loss_delivery": 1.1156622817118962,
"loss_biodist": 0.8776024182637533,
"loss_toxic": 0.1878812573850155
},
{
"loss": 3.709937810897827,
"loss_size": 0.37482741723457974,
"loss_pdi": 0.6356837352116903,
"loss_ee": 0.7201081812381744,
"loss_delivery": 1.0413622185587883,
"loss_biodist": 0.769838293393453,
"loss_toxic": 0.16811783549686274
},
{
"loss": 3.3223368724187217,
"loss_size": 0.2915526789923509,
"loss_pdi": 0.5975265900293986,
"loss_ee": 0.657735288143158,
"loss_delivery": 0.915939765671889,
"loss_biodist": 0.6836796899636587,
"loss_toxic": 0.17590284595886865
},
{
"loss": 3.22677751382192,
"loss_size": 0.30531836052735645,
"loss_pdi": 0.5498206416765848,
"loss_ee": 0.6144644021987915,
"loss_delivery": 1.0515401139855385,
"loss_biodist": 0.5715167572100958,
"loss_toxic": 0.13411726988852024
},
{
"loss": 2.898585855960846,
"loss_size": 0.26108303914467496,
"loss_pdi": 0.5470793843269348,
"loss_ee": 0.5787178675333658,
"loss_delivery": 0.9020876735448837,
"loss_biodist": 0.47204887370268506,
"loss_toxic": 0.13756892209251723
},
{
"loss": 2.6754438877105713,
"loss_size": 0.27526700248320896,
"loss_pdi": 0.5229905943075815,
"loss_ee": 0.5479903370141983,
"loss_delivery": 0.7761634774506092,
"loss_biodist": 0.44562476873397827,
"loss_toxic": 0.10740765929222107
},
{
"loss": 2.4977574348449707,
"loss_size": 0.2336499529580275,
"loss_pdi": 0.4902036637067795,
"loss_ee": 0.5169308086236318,
"loss_delivery": 0.7469637009004751,
"loss_biodist": 0.40839699904123944,
"loss_toxic": 0.10161229533453782
},
{
"loss": 2.384280482927958,
"loss_size": 0.2612900485595067,
"loss_pdi": 0.48895320296287537,
"loss_ee": 0.49676452577114105,
"loss_delivery": 0.6831860815485319,
"loss_biodist": 0.36347728967666626,
"loss_toxic": 0.090609318887194
},
{
"loss": 2.3147188226381936,
"loss_size": 0.25488365814089775,
"loss_pdi": 0.4683869779109955,
"loss_ee": 0.4875288059314092,
"loss_delivery": 0.6712647005915642,
"loss_biodist": 0.35311167935530346,
"loss_toxic": 0.0795429985349377
},
{
"loss": 2.2636735240618386,
"loss_size": 0.25234917054573697,
"loss_pdi": 0.48236118257045746,
"loss_ee": 0.4705241521199544,
"loss_delivery": 0.6308227330446243,
"loss_biodist": 0.3382392128308614,
"loss_toxic": 0.0893770344555378
},
{
"loss": 2.116434315840403,
"loss_size": 0.24041289339462915,
"loss_pdi": 0.45973017315069836,
"loss_ee": 0.4558122158050537,
"loss_delivery": 0.5561455090840658,
"loss_biodist": 0.3245606869459152,
"loss_toxic": 0.07977284273753564
},
{
"loss": 2.116472323735555,
"loss_size": 0.22036718018352985,
"loss_pdi": 0.4676543176174164,
"loss_ee": 0.43863776326179504,
"loss_delivery": 0.6074383656183878,
"loss_biodist": 0.3068434993426005,
"loss_toxic": 0.07553133244315784
},
{
"loss": 2.0553239782651267,
"loss_size": 0.2092201883594195,
"loss_pdi": 0.45897047221660614,
"loss_ee": 0.44677118460337323,
"loss_delivery": 0.559788204729557,
"loss_biodist": 0.31095271309216815,
"loss_toxic": 0.06962121867885192
},
{
"loss": 1.961161474386851,
"loss_size": 0.20881337051590285,
"loss_pdi": 0.45118602613608044,
"loss_ee": 0.42851509153842926,
"loss_delivery": 0.49867390592892963,
"loss_biodist": 0.3025246188044548,
"loss_toxic": 0.07144851268579562
},
{
"loss": 1.959239919980367,
"loss_size": 0.2156102918088436,
"loss_pdi": 0.4368931899468104,
"loss_ee": 0.42987839380900067,
"loss_delivery": 0.5220988343159357,
"loss_biodist": 0.2822929248213768,
"loss_toxic": 0.07246625237166882
},
{
"loss": 1.9401288827260335,
"loss_size": 0.21518264586726824,
"loss_pdi": 0.45129939913749695,
"loss_ee": 0.4190533608198166,
"loss_delivery": 0.5097754697004954,
"loss_biodist": 0.27901960412661236,
"loss_toxic": 0.06579839282979567
}
],
"val": [
{
"loss": 17.15616934640067,
"loss_size": 12.466235705784388,
"loss_pdi": 1.2282596485955375,
"loss_ee": 1.1713778802326746,
"loss_delivery": 0.45601305684873034,
"loss_biodist": 1.3372918111937386,
"loss_toxic": 0.49699102129255024,
"acc_pdi": 0.6410256410256411,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 9.976628576006208,
"loss_size": 5.5366509301321845,
"loss_pdi": 1.1110572814941406,
"loss_ee": 1.2255418130329676,
"loss_delivery": 0.4425822538988931,
"loss_biodist": 1.2637801681246077,
"loss_toxic": 0.3970160186290741,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 6.142949717385428,
"loss_size": 2.012831313269479,
"loss_pdi": 0.9936460341726031,
"loss_ee": 1.2630535023553031,
"loss_delivery": 0.45085248563970837,
"loss_biodist": 1.1773428661482674,
"loss_toxic": 0.24522356688976288,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 4.523837975093296,
"loss_size": 0.6790421732834407,
"loss_pdi": 0.8449332543781826,
"loss_ee": 1.2828129359654017,
"loss_delivery": 0.47022694775036405,
"loss_biodist": 1.0964353680610657,
"loss_toxic": 0.15038717218807765,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 3.914861168180193,
"loss_size": 0.23712052990283286,
"loss_pdi": 0.739460038287299,
"loss_ee": 1.307591336114066,
"loss_delivery": 0.48760053728307995,
"loss_biodist": 1.0441895723342896,
"loss_toxic": 0.09889917501381465,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 3.827185903276716,
"loss_size": 0.19252057054213115,
"loss_pdi": 0.6754980896200452,
"loss_ee": 1.3178976603916712,
"loss_delivery": 0.52304607629776,
"loss_biodist": 1.0516272272382463,
"loss_toxic": 0.06659629621676036,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 3.7409281730651855,
"loss_size": 0.19949970713683537,
"loss_pdi": 0.63478513274874,
"loss_ee": 1.3695382390703474,
"loss_delivery": 0.5511368151221957,
"loss_biodist": 0.9442958320890155,
"loss_toxic": 0.04167233567152705,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 4.04153094972883,
"loss_size": 0.18925889155694417,
"loss_pdi": 0.6154808104038239,
"loss_ee": 1.4497678790773665,
"loss_delivery": 0.6580333965165275,
"loss_biodist": 1.0802616902760096,
"loss_toxic": 0.04872834203498704,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 4.0117951801845,
"loss_size": 0.261259101331234,
"loss_pdi": 0.6064834722450801,
"loss_ee": 1.4382908344268799,
"loss_delivery": 0.6485261533941541,
"loss_biodist": 1.013489259140832,
"loss_toxic": 0.04374648204871586,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 4.021411214556013,
"loss_size": 0.25430602048124584,
"loss_pdi": 0.6047579922846386,
"loss_ee": 1.4375121252877372,
"loss_delivery": 0.794183360678809,
"loss_biodist": 0.8880592542035239,
"loss_toxic": 0.04259242117404938,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.41025641025641024,
"acc_toxic": 1.0
},
{
"loss": 4.389863218579974,
"loss_size": 0.47541433785642895,
"loss_pdi": 0.6053804946797234,
"loss_ee": 1.5225199971880232,
"loss_delivery": 0.864760024206979,
"loss_biodist": 0.8813395031860897,
"loss_toxic": 0.04044891867254462,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 4.2736298356737406,
"loss_size": 0.4230478884918349,
"loss_pdi": 0.6034957447222301,
"loss_ee": 1.4814096518925257,
"loss_delivery": 0.9466098589556557,
"loss_biodist": 0.7808983538831983,
"loss_toxic": 0.03816834863807474,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4,
"acc_toxic": 1.0
},
{
"loss": 4.585977554321289,
"loss_size": 0.4787849709391594,
"loss_pdi": 0.6088685957448823,
"loss_ee": 1.5371792827333723,
"loss_delivery": 1.1909210256167821,
"loss_biodist": 0.7510551980563572,
"loss_toxic": 0.01916840286659343,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 4.317904029573713,
"loss_size": 0.46495475407157627,
"loss_pdi": 0.6024806403688022,
"loss_ee": 1.5152796847479684,
"loss_delivery": 1.0274277882916587,
"loss_biodist": 0.6894632875919342,
"loss_toxic": 0.018297837914100716,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 4.287515469959804,
"loss_size": 0.5376717469521931,
"loss_pdi": 0.5981708924685206,
"loss_ee": 1.4651154450007848,
"loss_delivery": 1.0260179340839386,
"loss_biodist": 0.6434041815144675,
"loss_toxic": 0.01713526714593172,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.38974358974358975,
"acc_toxic": 1.0
},
{
"loss": 4.683320045471191,
"loss_size": 0.5271316319704056,
"loss_pdi": 0.6275609007903508,
"loss_ee": 1.5593642677579607,
"loss_delivery": 1.2529506555625372,
"loss_biodist": 0.7023919905935015,
"loss_toxic": 0.013920623875622238,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 4.397426230566842,
"loss_size": 0.5106346564633506,
"loss_pdi": 0.6181366369128227,
"loss_ee": 1.5676180635179793,
"loss_delivery": 1.0354454517364502,
"loss_biodist": 0.6522124835423061,
"loss_toxic": 0.013378978440804141,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.41025641025641024,
"acc_toxic": 1.0
},
{
"loss": 4.412310838699341,
"loss_size": 0.5362906115395683,
"loss_pdi": 0.6130111419728824,
"loss_ee": 1.5929286650248937,
"loss_delivery": 1.001524874142238,
"loss_biodist": 0.6563895855631147,
"loss_toxic": 0.012165952806494065,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 4.78386572429112,
"loss_size": 0.5134438020842416,
"loss_pdi": 0.6321895952735629,
"loss_ee": 1.5953963484082903,
"loss_delivery": 1.3276494145393372,
"loss_biodist": 0.7046042127268655,
"loss_toxic": 0.010582369419613056,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 4.835996287209647,
"loss_size": 0.4968718098742621,
"loss_pdi": 0.6234481834939548,
"loss_ee": 1.5689009257725306,
"loss_delivery": 1.4453607542174203,
"loss_biodist": 0.6902376328195844,
"loss_toxic": 0.011176949766065394,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.41025641025641024,
"acc_toxic": 1.0
},
{
"loss": 4.646918603352138,
"loss_size": 0.49231292733124327,
"loss_pdi": 0.6135136048708644,
"loss_ee": 1.5504994562694006,
"loss_delivery": 1.3212553603308541,
"loss_biodist": 0.6578856153147561,
"loss_toxic": 0.01145164788301502,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.40512820512820513,
"acc_toxic": 1.0
},
{
"loss": 4.580311230250767,
"loss_size": 0.529256243790899,
"loss_pdi": 0.6140108300106866,
"loss_ee": 1.5573444025857108,
"loss_delivery": 1.2449579962662287,
"loss_biodist": 0.6246597383703504,
"loss_toxic": 0.01008199481293559,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.41025641025641024,
"acc_toxic": 1.0
}
]
}

Binary file not shown.

View File

@ -0,0 +1,573 @@
{
"train": [
{
"loss": 21.41742353439331,
"loss_size": 15.701433181762695,
"loss_pdi": 1.3632826924324035,
"loss_ee": 1.0266466200351716,
"loss_delivery": 1.418760311603546,
"loss_biodist": 1.2302423000335694,
"loss_toxic": 0.6770588457584381
},
{
"loss": 9.57167477607727,
"loss_size": 4.903112936019897,
"loss_pdi": 1.140204393863678,
"loss_ee": 1.0074746072292329,
"loss_delivery": 1.1362140402197838,
"loss_biodist": 1.0133467674255372,
"loss_toxic": 0.3713219165802002
},
{
"loss": 5.485435938835144,
"loss_size": 1.4459647357463836,
"loss_pdi": 0.9086057424545289,
"loss_ee": 0.9852841377258301,
"loss_delivery": 1.0618216753005982,
"loss_biodist": 0.8613051772117615,
"loss_toxic": 0.22245449870824813
},
{
"loss": 4.065439391136169,
"loss_size": 0.45657781660556795,
"loss_pdi": 0.7768359065055848,
"loss_ee": 0.953633052110672,
"loss_delivery": 0.9996832102537155,
"loss_biodist": 0.7116856783628464,
"loss_toxic": 0.16702365390956403
},
{
"loss": 3.687756299972534,
"loss_size": 0.30417054146528244,
"loss_pdi": 0.7158299326896668,
"loss_ee": 0.9035742998123169,
"loss_delivery": 1.060212180018425,
"loss_biodist": 0.5556510210037231,
"loss_toxic": 0.1483182568103075
},
{
"loss": 3.536502242088318,
"loss_size": 0.3545849896967411,
"loss_pdi": 0.6672206580638885,
"loss_ee": 0.8711013793945312,
"loss_delivery": 1.0780200093984604,
"loss_biodist": 0.4400379478931427,
"loss_toxic": 0.12553719095885754
},
{
"loss": 3.289654517173767,
"loss_size": 0.3215350516140461,
"loss_pdi": 0.6445666253566742,
"loss_ee": 0.8839016199111939,
"loss_delivery": 0.9652104169130326,
"loss_biodist": 0.35911705493927004,
"loss_toxic": 0.11532373651862145
},
{
"loss": 3.199695384502411,
"loss_size": 0.29793725311756136,
"loss_pdi": 0.6349938422441482,
"loss_ee": 0.8508742034435273,
"loss_delivery": 0.9878654226660728,
"loss_biodist": 0.3253925606608391,
"loss_toxic": 0.10263208523392678
},
{
"loss": 2.981464409828186,
"loss_size": 0.2926298946142197,
"loss_pdi": 0.6041542202234268,
"loss_ee": 0.8202637135982513,
"loss_delivery": 0.9215006068348884,
"loss_biodist": 0.2569382354617119,
"loss_toxic": 0.08597766645252705
},
{
"loss": 2.7703840017318724,
"loss_size": 0.30560431331396104,
"loss_pdi": 0.5837061107158661,
"loss_ee": 0.7886367738246918,
"loss_delivery": 0.763620425760746,
"loss_biodist": 0.23922762870788575,
"loss_toxic": 0.08958881739526987
},
{
"loss": 2.6885447978973387,
"loss_size": 0.28905282765626905,
"loss_pdi": 0.5664507508277893,
"loss_ee": 0.7639070689678192,
"loss_delivery": 0.7704478114843368,
"loss_biodist": 0.21979134678840637,
"loss_toxic": 0.07889490202069283
},
{
"loss": 2.591668117046356,
"loss_size": 0.2613472960889339,
"loss_pdi": 0.5540884166955948,
"loss_ee": 0.7465697586536407,
"loss_delivery": 0.7534924671053886,
"loss_biodist": 0.20271824076771736,
"loss_toxic": 0.07345191687345505
},
{
"loss": 2.482152557373047,
"loss_size": 0.2550207316875458,
"loss_pdi": 0.5377364099025727,
"loss_ee": 0.7093640863895416,
"loss_delivery": 0.715790644288063,
"loss_biodist": 0.1968037411570549,
"loss_toxic": 0.0674369728192687
},
{
"loss": 2.4853516697883604,
"loss_size": 0.25445577800273894,
"loss_pdi": 0.5285622417926789,
"loss_ee": 0.7045736134052276,
"loss_delivery": 0.765293450653553,
"loss_biodist": 0.16704678535461426,
"loss_toxic": 0.06541981641203165
},
{
"loss": 2.325911021232605,
"loss_size": 0.22311154529452323,
"loss_pdi": 0.5131498754024506,
"loss_ee": 0.6778043508529663,
"loss_delivery": 0.6851547978818416,
"loss_biodist": 0.16465196907520294,
"loss_toxic": 0.06203848272562027
},
{
"loss": 2.213776695728302,
"loss_size": 0.24956582188606263,
"loss_pdi": 0.48536362051963805,
"loss_ee": 0.6791646689176559,
"loss_delivery": 0.5861033886671067,
"loss_biodist": 0.15450086519122125,
"loss_toxic": 0.059078306704759595
},
{
"loss": 2.3086095094680785,
"loss_size": 0.20203103460371494,
"loss_pdi": 0.4988987535238266,
"loss_ee": 0.6785910665988922,
"loss_delivery": 0.7291694968938828,
"loss_biodist": 0.1462075024843216,
"loss_toxic": 0.05371163971722126
},
{
"loss": 2.0882336378097532,
"loss_size": 0.2287739872932434,
"loss_pdi": 0.5072675496339798,
"loss_ee": 0.6701794564723969,
"loss_delivery": 0.4832353606820107,
"loss_biodist": 0.1421804867684841,
"loss_toxic": 0.05659680655226111
},
{
"loss": 2.0332674741744996,
"loss_size": 0.22676872164011003,
"loss_pdi": 0.4669553279876709,
"loss_ee": 0.6482236534357071,
"loss_delivery": 0.5079896807670593,
"loss_biodist": 0.13781072497367858,
"loss_toxic": 0.04551935400813818
},
{
"loss": 1.9842296838760376,
"loss_size": 0.2007827118039131,
"loss_pdi": 0.4655294865369797,
"loss_ee": 0.6295293152332306,
"loss_delivery": 0.5032630048692226,
"loss_biodist": 0.12692373394966125,
"loss_toxic": 0.05820149295032025
},
{
"loss": 1.9574703454971314,
"loss_size": 0.20575413331389428,
"loss_pdi": 0.4662397414445877,
"loss_ee": 0.6207350552082062,
"loss_delivery": 0.4868774816393852,
"loss_biodist": 0.1334032252430916,
"loss_toxic": 0.04446075968444348
},
{
"loss": 1.8693695425987245,
"loss_size": 0.2003278151154518,
"loss_pdi": 0.4481669098138809,
"loss_ee": 0.6228156566619873,
"loss_delivery": 0.4377666234970093,
"loss_biodist": 0.1222815040498972,
"loss_toxic": 0.03801098903641105
},
{
"loss": 1.9101393103599549,
"loss_size": 0.22527543231844901,
"loss_pdi": 0.4501156389713287,
"loss_ee": 0.5992490768432617,
"loss_delivery": 0.4722588837146759,
"loss_biodist": 0.12016024515032768,
"loss_toxic": 0.04308005180209875
},
{
"loss": 1.8186616897583008,
"loss_size": 0.20228504091501237,
"loss_pdi": 0.4353774756193161,
"loss_ee": 0.595091101527214,
"loss_delivery": 0.42390944324433805,
"loss_biodist": 0.12379681393504142,
"loss_toxic": 0.03820184739306569
},
{
"loss": 1.753528320789337,
"loss_size": 0.168972497433424,
"loss_pdi": 0.4384701639413834,
"loss_ee": 0.5804536670446396,
"loss_delivery": 0.4066555552184582,
"loss_biodist": 0.12041406258940697,
"loss_toxic": 0.038562366552650926
},
{
"loss": 1.752294898033142,
"loss_size": 0.16247735619544984,
"loss_pdi": 0.4280344098806381,
"loss_ee": 0.5781133621931076,
"loss_delivery": 0.4378485083580017,
"loss_biodist": 0.1110315527766943,
"loss_toxic": 0.03478973265737295
},
{
"loss": 1.8295514345169068,
"loss_size": 0.1906499370932579,
"loss_pdi": 0.42695977091789244,
"loss_ee": 0.6036634147167206,
"loss_delivery": 0.4585780970752239,
"loss_biodist": 0.10937272906303405,
"loss_toxic": 0.04032747393939644
}
],
"val": [
{
"loss": 12.90866756439209,
"loss_size": 8.614875793457031,
"loss_pdi": 1.3393473625183105,
"loss_ee": 0.7915782928466797,
"loss_delivery": 0.35056664049625397,
"loss_biodist": 1.2413995265960693,
"loss_toxic": 0.5708996057510376,
"acc_pdi": 0.09803921568627451,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 6.046682596206665,
"loss_size": 2.181392192840576,
"loss_pdi": 1.010448306798935,
"loss_ee": 0.7943653464317322,
"loss_delivery": 0.3778253495693207,
"loss_biodist": 1.1818514466285706,
"loss_toxic": 0.5007997751235962,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.130645751953125,
"loss_size": 0.4065478593111038,
"loss_pdi": 0.7734719514846802,
"loss_ee": 0.7856209874153137,
"loss_delivery": 0.4249718487262726,
"loss_biodist": 1.1932182908058167,
"loss_toxic": 0.5468149781227112,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.74097740650177,
"loss_size": 0.05154523253440857,
"loss_pdi": 0.6415763646364212,
"loss_ee": 0.7410732209682465,
"loss_delivery": 0.38766030967235565,
"loss_biodist": 1.2818186283111572,
"loss_toxic": 0.6373037025332451,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.6484163999557495,
"loss_size": 0.05830332264304161,
"loss_pdi": 0.5795184522867203,
"loss_ee": 0.6541054546833038,
"loss_delivery": 0.40608713030815125,
"loss_biodist": 1.2502482235431671,
"loss_toxic": 0.7001538649201393,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.8285857439041138,
"loss_size": 0.07835924997925758,
"loss_pdi": 0.5752626657485962,
"loss_ee": 0.7419937252998352,
"loss_delivery": 0.382533997297287,
"loss_biodist": 1.3310475945472717,
"loss_toxic": 0.7193883210420609,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.6414456367492676,
"loss_size": 0.07554847374558449,
"loss_pdi": 0.6094752848148346,
"loss_ee": 0.7212981283664703,
"loss_delivery": 0.4330318123102188,
"loss_biodist": 1.1887046694755554,
"loss_toxic": 0.6133871823549271,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.537736415863037,
"loss_size": 0.10301502048969269,
"loss_pdi": 0.5892050117254257,
"loss_ee": 0.6979120671749115,
"loss_delivery": 0.5159651935100555,
"loss_biodist": 1.064111590385437,
"loss_toxic": 0.5675273537635803,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.6027194261550903,
"loss_size": 0.16463171318173409,
"loss_pdi": 0.5593049973249435,
"loss_ee": 0.6619580686092377,
"loss_delivery": 0.538649171590805,
"loss_biodist": 1.0742176473140717,
"loss_toxic": 0.6039578504860401,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.762178421020508,
"loss_size": 0.22026687115430832,
"loss_pdi": 0.5431422591209412,
"loss_ee": 0.7102161943912506,
"loss_delivery": 0.44816476106643677,
"loss_biodist": 1.172889918088913,
"loss_toxic": 0.6674983687698841,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.7515294551849365,
"loss_size": 0.2188272960484028,
"loss_pdi": 0.5569919049739838,
"loss_ee": 0.6500461399555206,
"loss_delivery": 0.4802835136651993,
"loss_biodist": 1.140358328819275,
"loss_toxic": 0.7050221972167492,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.47284197807312,
"loss_size": 0.19455180317163467,
"loss_pdi": 0.5266519337892532,
"loss_ee": 0.6175068914890289,
"loss_delivery": 0.5526535362005234,
"loss_biodist": 0.9517507255077362,
"loss_toxic": 0.6297270879149437,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.9374399185180664,
"loss_size": 0.32401855289936066,
"loss_pdi": 0.5553185790777206,
"loss_ee": 0.6632668077945709,
"loss_delivery": 0.48028646409511566,
"loss_biodist": 1.1518925726413727,
"loss_toxic": 0.7626568526029587,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.7636899948120117,
"loss_size": 0.2452084794640541,
"loss_pdi": 0.5189685672521591,
"loss_ee": 0.6509725153446198,
"loss_delivery": 0.38138364255428314,
"loss_biodist": 1.1877531707286835,
"loss_toxic": 0.7794036716222763,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8235294117647058,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.837575674057007,
"loss_size": 0.3500683009624481,
"loss_pdi": 0.5655115246772766,
"loss_ee": 0.6301239728927612,
"loss_delivery": 0.47451916337013245,
"loss_biodist": 1.069214403629303,
"loss_toxic": 0.7481381297111511,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.834155559539795,
"loss_size": 0.27750423550605774,
"loss_pdi": 0.513676255941391,
"loss_ee": 0.6291456520557404,
"loss_delivery": 0.40621335804462433,
"loss_biodist": 1.1680629253387451,
"loss_toxic": 0.8395530804991722,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8235294117647058,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.6893343925476074,
"loss_size": 0.34515636414289474,
"loss_pdi": 0.5478167533874512,
"loss_ee": 0.6345725357532501,
"loss_delivery": 0.42669548094272614,
"loss_biodist": 1.0210922062397003,
"loss_toxic": 0.7140011191368103,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.803921568627451,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.00758171081543,
"loss_size": 0.3066108226776123,
"loss_pdi": 0.5529182702302933,
"loss_ee": 0.6956813335418701,
"loss_delivery": 0.4755028784275055,
"loss_biodist": 1.152823954820633,
"loss_toxic": 0.8240445479750633,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.803921568627451,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.974515676498413,
"loss_size": 0.3346070274710655,
"loss_pdi": 0.5447590947151184,
"loss_ee": 0.6450685858726501,
"loss_delivery": 0.4816073626279831,
"loss_biodist": 1.1390976309776306,
"loss_toxic": 0.8293759152293205,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.803921568627451,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.880821943283081,
"loss_size": 0.282451294362545,
"loss_pdi": 0.5469185262918472,
"loss_ee": 0.6417804062366486,
"loss_delivery": 0.46756358444690704,
"loss_biodist": 1.127053290605545,
"loss_toxic": 0.8150547966361046,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.803921568627451,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.8784565925598145,
"loss_size": 0.2880236804485321,
"loss_pdi": 0.5311962813138962,
"loss_ee": 0.6260144710540771,
"loss_delivery": 0.4471806138753891,
"loss_biodist": 1.1587725579738617,
"loss_toxic": 0.8272688835859299,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.803921568627451,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.9609978199005127,
"loss_size": 0.3277027904987335,
"loss_pdi": 0.5816705077886581,
"loss_ee": 0.6382596492767334,
"loss_delivery": 0.42079347372055054,
"loss_biodist": 1.1735960245132446,
"loss_toxic": 0.8189755231142044,
"acc_pdi": 0.9019607843137255,
"acc_ee": 0.803921568627451,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.8955435752868652,
"loss_size": 0.28625721484422684,
"loss_pdi": 0.5423711687326431,
"loss_ee": 0.6477845013141632,
"loss_delivery": 0.43728773295879364,
"loss_biodist": 1.1806198358535767,
"loss_toxic": 0.801223024725914,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.803921568627451,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.007414698600769,
"loss_size": 0.3294255882501602,
"loss_pdi": 0.5471038520336151,
"loss_ee": 0.6502591967582703,
"loss_delivery": 0.4436507970094681,
"loss_biodist": 1.202763706445694,
"loss_toxic": 0.8342117220163345,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.803921568627451,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.050978183746338,
"loss_size": 0.39942242205142975,
"loss_pdi": 0.5590354651212692,
"loss_ee": 0.6409733295440674,
"loss_delivery": 0.45262810587882996,
"loss_biodist": 1.1648951172828674,
"loss_toxic": 0.8340234756469727,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.803921568627451,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.86995267868042,
"loss_size": 0.2950260043144226,
"loss_pdi": 0.5328812152147293,
"loss_ee": 0.6263198852539062,
"loss_delivery": 0.45980168879032135,
"loss_biodist": 1.1475371420383453,
"loss_toxic": 0.808386467397213,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.803921568627451,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.8100762367248535,
"loss_size": 0.2826509475708008,
"loss_pdi": 0.5344351083040237,
"loss_ee": 0.6322177648544312,
"loss_delivery": 0.4711516797542572,
"loss_biodist": 1.1184114813804626,
"loss_toxic": 0.7712092772126198,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.803921568627451,
"acc_toxic": 0.851063829787234
}
]
}

Binary file not shown.

View File

@ -0,0 +1,405 @@
{
"train": [
{
"loss": 17.82338151064786,
"loss_size": 12.698930350216953,
"loss_pdi": 1.2952325452457776,
"loss_ee": 1.044789281758395,
"loss_delivery": 0.894258591261777,
"loss_biodist": 1.3406665433536877,
"loss_toxic": 0.5495036569508639
},
{
"loss": 6.6719701940363105,
"loss_size": 2.361687348647551,
"loss_pdi": 0.9808245301246643,
"loss_ee": 0.9910268241708929,
"loss_delivery": 0.8210993910377676,
"loss_biodist": 1.184872735630382,
"loss_toxic": 0.3324593657797033
},
{
"loss": 4.197543209249323,
"loss_size": 0.3305926668373021,
"loss_pdi": 0.8021349581805143,
"loss_ee": 0.9533808014609597,
"loss_delivery": 0.85791384361007,
"loss_biodist": 1.0029216788031838,
"loss_toxic": 0.2505992312322963
},
{
"loss": 3.817467537793246,
"loss_size": 0.3109576465053992,
"loss_pdi": 0.7077692205255682,
"loss_ee": 0.9259536266326904,
"loss_delivery": 0.8590928573500026,
"loss_biodist": 0.7878239534117959,
"loss_toxic": 0.22587016597390175
},
{
"loss": 3.426318342035467,
"loss_size": 0.30332070047205145,
"loss_pdi": 0.6738776402039961,
"loss_ee": 0.8633853034539656,
"loss_delivery": 0.810848053206097,
"loss_biodist": 0.6030096492984078,
"loss_toxic": 0.17187697711316022
},
{
"loss": 3.243085037578236,
"loss_size": 0.2643032656474547,
"loss_pdi": 0.680456202138554,
"loss_ee": 0.8441235260529951,
"loss_delivery": 0.790125925432552,
"loss_biodist": 0.4977728453549472,
"loss_toxic": 0.16630326211452484
},
{
"loss": 3.0106391689994116,
"loss_size": 0.2923726256598126,
"loss_pdi": 0.6462291912599043,
"loss_ee": 0.8259676423939791,
"loss_delivery": 0.7123563428494063,
"loss_biodist": 0.4204860031604767,
"loss_toxic": 0.11322732371362773
},
{
"loss": 2.6575805924155493,
"loss_size": 0.23369696668603204,
"loss_pdi": 0.5998414809053595,
"loss_ee": 0.7827209342609752,
"loss_delivery": 0.6251701191067696,
"loss_biodist": 0.324639000675895,
"loss_toxic": 0.09151209010319276
},
{
"loss": 2.6034729263999243,
"loss_size": 0.20161977207118814,
"loss_pdi": 0.5999934835867449,
"loss_ee": 0.7699762636964972,
"loss_delivery": 0.6784080802039667,
"loss_biodist": 0.2691292708570307,
"loss_toxic": 0.08434606292708353
},
{
"loss": 2.481573982672258,
"loss_size": 0.2506237829273397,
"loss_pdi": 0.5553325793959878,
"loss_ee": 0.702888163653287,
"loss_delivery": 0.6528384658423337,
"loss_biodist": 0.25332851166074927,
"loss_toxic": 0.06656241179867224
},
{
"loss": 2.328354239463806,
"loss_size": 0.20882186158136887,
"loss_pdi": 0.5448922569101508,
"loss_ee": 0.7051796858960931,
"loss_delivery": 0.5663092088970271,
"loss_biodist": 0.23630927367643875,
"loss_toxic": 0.06684201947328719
},
{
"loss": 2.1841621182181616,
"loss_size": 0.20124886184930801,
"loss_pdi": 0.5195452462543141,
"loss_ee": 0.6754790571602908,
"loss_delivery": 0.5004782859574665,
"loss_biodist": 0.2269973118196834,
"loss_toxic": 0.06041339209133929
},
{
"loss": 2.218748081814159,
"loss_size": 0.1988528309897943,
"loss_pdi": 0.5231576589020815,
"loss_ee": 0.6777869246222756,
"loss_delivery": 0.5498508228497072,
"loss_biodist": 0.20798503336581317,
"loss_toxic": 0.06111483716151931
},
{
"loss": 2.1877676248550415,
"loss_size": 0.20004721256819638,
"loss_pdi": 0.5174875286492434,
"loss_ee": 0.6804409514773976,
"loss_delivery": 0.5348355831070379,
"loss_biodist": 0.2000916667959907,
"loss_toxic": 0.054864687167785385
},
{
"loss": 2.0138474811207163,
"loss_size": 0.19752372259443457,
"loss_pdi": 0.4868703836744482,
"loss_ee": 0.6450781768018549,
"loss_delivery": 0.44826274839314545,
"loss_biodist": 0.18033211339603772,
"loss_toxic": 0.05578036225316199
},
{
"loss": 2.011049357327548,
"loss_size": 0.18426605652679096,
"loss_pdi": 0.4944283068180084,
"loss_ee": 0.6483827070756392,
"loss_delivery": 0.4382695244117217,
"loss_biodist": 0.19008470394394614,
"loss_toxic": 0.05561802129853855
},
{
"loss": 2.0496722134676846,
"loss_size": 0.18194164742122998,
"loss_pdi": 0.4917480132796548,
"loss_ee": 0.6405326967889612,
"loss_delivery": 0.5158521959727461,
"loss_biodist": 0.1772592243823138,
"loss_toxic": 0.042338407937098636
},
{
"loss": 1.9639968113465742,
"loss_size": 0.1595074331218546,
"loss_pdi": 0.48567668145353143,
"loss_ee": 0.630230272358114,
"loss_delivery": 0.4618720757690343,
"loss_biodist": 0.17742999304424634,
"loss_toxic": 0.0492804107171568
},
{
"loss": 1.9399811571294612,
"loss_size": 0.18978103656660428,
"loss_pdi": 0.48150675947015936,
"loss_ee": 0.6243780959736217,
"loss_delivery": 0.4384633018211885,
"loss_biodist": 0.16038654270497235,
"loss_toxic": 0.045465379022061825
}
],
"val": [
{
"loss": 8.991046905517578,
"loss_size": 4.490478515625,
"loss_pdi": 0.9734575947125753,
"loss_ee": 0.9441032012303671,
"loss_delivery": 1.2459152142206829,
"loss_biodist": 0.9957353870073954,
"loss_toxic": 0.34135735034942627,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.8333333333333334,
"acc_toxic": 1.0
},
{
"loss": 3.4952890078226724,
"loss_size": 0.27914075056711835,
"loss_pdi": 0.6866898337999979,
"loss_ee": 0.7263152599334717,
"loss_delivery": 0.7405761281649271,
"loss_biodist": 0.9172844588756561,
"loss_toxic": 0.14528251190980276,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.8333333333333334,
"acc_toxic": 1.0
},
{
"loss": 2.8096923430760703,
"loss_size": 0.04479753350218137,
"loss_pdi": 0.48064420620600384,
"loss_ee": 0.6685001452763876,
"loss_delivery": 0.7262534300486246,
"loss_biodist": 0.827881266673406,
"loss_toxic": 0.06161577875415484,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.8333333333333334,
"acc_toxic": 1.0
},
{
"loss": 2.756531000137329,
"loss_size": 0.07622841248909633,
"loss_pdi": 0.463245431582133,
"loss_ee": 0.669317622979482,
"loss_delivery": 0.6688057581583658,
"loss_biodist": 0.8275523781776428,
"loss_toxic": 0.05138145387172699,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.8333333333333334,
"acc_toxic": 1.0
},
{
"loss": 2.9385202328364053,
"loss_size": 0.1038126324613889,
"loss_pdi": 0.46456684668858844,
"loss_ee": 0.705346941947937,
"loss_delivery": 0.7891722718874613,
"loss_biodist": 0.82027334968249,
"loss_toxic": 0.055348185201485954,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.8333333333333334,
"acc_toxic": 1.0
},
{
"loss": 2.8632373809814453,
"loss_size": 0.056236049781243004,
"loss_pdi": 0.44324543078740436,
"loss_ee": 0.6018350621064504,
"loss_delivery": 0.9420756896336874,
"loss_biodist": 0.755872001250585,
"loss_toxic": 0.06397318094968796,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.8333333333333334,
"acc_toxic": 1.0
},
{
"loss": 2.993934949239095,
"loss_size": 0.08202650646368663,
"loss_pdi": 0.508717954158783,
"loss_ee": 0.7680216828982035,
"loss_delivery": 0.8612345655759176,
"loss_biodist": 0.7118468681971232,
"loss_toxic": 0.0620873523876071,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.48484848484848486,
"acc_toxic": 1.0
},
{
"loss": 2.8506224155426025,
"loss_size": 0.03146359619374076,
"loss_pdi": 0.5297882954279581,
"loss_ee": 0.7625004649162292,
"loss_delivery": 0.802927960952123,
"loss_biodist": 0.6721820036570231,
"loss_toxic": 0.05176017774889866,
"acc_pdi": 0.8333333333333334,
"acc_ee": 0.4696969696969697,
"acc_toxic": 1.0
},
{
"loss": 2.8424178759256997,
"loss_size": 0.08800932268301646,
"loss_pdi": 0.540622721115748,
"loss_ee": 0.839806874593099,
"loss_delivery": 0.6362804472446442,
"loss_biodist": 0.6938393215338389,
"loss_toxic": 0.043859192014982305,
"acc_pdi": 0.8181818181818182,
"acc_ee": 0.4696969696969697,
"acc_toxic": 1.0
},
{
"loss": 2.9927186171213784,
"loss_size": 0.04866213048808277,
"loss_pdi": 0.5090581774711609,
"loss_ee": 0.8101900815963745,
"loss_delivery": 0.8959137797355652,
"loss_biodist": 0.697561984260877,
"loss_toxic": 0.031332316963622965,
"acc_pdi": 0.8333333333333334,
"acc_ee": 0.3787878787878788,
"acc_toxic": 1.0
},
{
"loss": 2.9314780632654824,
"loss_size": 0.06920161470770836,
"loss_pdi": 0.5283850431442261,
"loss_ee": 0.8220112522443136,
"loss_delivery": 0.7353424032529196,
"loss_biodist": 0.7471547623475393,
"loss_toxic": 0.029382963587219518,
"acc_pdi": 0.8181818181818182,
"acc_ee": 0.3787878787878788,
"acc_toxic": 1.0
},
{
"loss": 2.9559961954752603,
"loss_size": 0.05500282160937786,
"loss_pdi": 0.516402949889501,
"loss_ee": 0.8015920122464498,
"loss_delivery": 0.7959124445915222,
"loss_biodist": 0.7596821735302607,
"loss_toxic": 0.0274037744384259,
"acc_pdi": 0.8181818181818182,
"acc_ee": 0.3787878787878788,
"acc_toxic": 1.0
},
{
"loss": 2.9268057346343994,
"loss_size": 0.07737081746260326,
"loss_pdi": 0.5241717199484507,
"loss_ee": 0.835136612256368,
"loss_delivery": 0.7045041720072428,
"loss_biodist": 0.7396295169989268,
"loss_toxic": 0.0459929151305308,
"acc_pdi": 0.8181818181818182,
"acc_ee": 0.3787878787878788,
"acc_toxic": 1.0
},
{
"loss": 2.9870657920837402,
"loss_size": 0.08099805439511935,
"loss_pdi": 0.5164442459742228,
"loss_ee": 0.8418577512105306,
"loss_delivery": 0.7609673937161764,
"loss_biodist": 0.7428288658459982,
"loss_toxic": 0.04396952743021151,
"acc_pdi": 0.8181818181818182,
"acc_ee": 0.3787878787878788,
"acc_toxic": 1.0
},
{
"loss": 2.979916493097941,
"loss_size": 0.09317472080389659,
"loss_pdi": 0.514099915822347,
"loss_ee": 0.8726487557093302,
"loss_delivery": 0.6872047583262125,
"loss_biodist": 0.7711265037457148,
"loss_toxic": 0.04166170318300525,
"acc_pdi": 0.8181818181818182,
"acc_ee": 0.36363636363636365,
"acc_toxic": 1.0
},
{
"loss": 2.937371532122294,
"loss_size": 0.06330848174790542,
"loss_pdi": 0.5085045297940572,
"loss_ee": 0.8129515051841736,
"loss_delivery": 0.7455949584643046,
"loss_biodist": 0.7728735754887263,
"loss_toxic": 0.03413854034927984,
"acc_pdi": 0.8181818181818182,
"acc_ee": 0.3787878787878788,
"acc_toxic": 1.0
},
{
"loss": 3.0136941274007163,
"loss_size": 0.0849883034825325,
"loss_pdi": 0.51878755291303,
"loss_ee": 0.785295327504476,
"loss_delivery": 0.8036739627520243,
"loss_biodist": 0.7789742400248846,
"loss_toxic": 0.041974871419370174,
"acc_pdi": 0.8181818181818182,
"acc_ee": 0.3787878787878788,
"acc_toxic": 1.0
},
{
"loss": 2.9988585313161216,
"loss_size": 0.053379556785027184,
"loss_pdi": 0.5192934771378835,
"loss_ee": 0.838801383972168,
"loss_delivery": 0.7629345854123434,
"loss_biodist": 0.7728350758552551,
"loss_toxic": 0.05161439681736132,
"acc_pdi": 0.8181818181818182,
"acc_ee": 0.36363636363636365,
"acc_toxic": 1.0
},
{
"loss": 2.9807325998942056,
"loss_size": 0.055528260146578155,
"loss_pdi": 0.5009754200776418,
"loss_ee": 0.8306655287742615,
"loss_delivery": 0.7751240134239197,
"loss_biodist": 0.7733635157346725,
"loss_toxic": 0.045075801705631115,
"acc_pdi": 0.8181818181818182,
"acc_ee": 0.36363636363636365,
"acc_toxic": 1.0
}
]
}

Binary file not shown.

View File

@ -0,0 +1,198 @@
{
"fold_results": [
{
"fold_idx": 0,
"n_samples": 95,
"size": {
"n": 95,
"rmse": 0.632240295394253,
"mae": 0.4674149764211554,
"r2": -0.13819694158244777
},
"delivery": {
"n": 66,
"rmse": 1.3553486689823926,
"mae": 0.48114709207562334,
"r2": -0.008234200686471516
},
"pdi": {
"n": 95,
"accuracy": 0.6105263157894737
},
"ee": {
"n": 95,
"accuracy": 0.6631578947368421
},
"toxic": {
"n": 66,
"accuracy": 0.8939393939393939
}
},
{
"fold_idx": 1,
"n_samples": 195,
"size": {
"n": 193,
"rmse": 0.42622752144538556,
"mae": 0.24566329575573226,
"r2": 0.0482086242002292
},
"delivery": {
"n": 123,
"rmse": 0.742899240869997,
"mae": 0.5315999669170507,
"r2": -0.03140039086191759
},
"pdi": {
"n": 195,
"accuracy": 0.7076923076923077
},
"ee": {
"n": 195,
"accuracy": 0.4205128205128205
},
"toxic": {
"n": 123,
"accuracy": 1.0
}
},
{
"fold_idx": 2,
"n_samples": 51,
"size": {
"n": 51,
"rmse": 0.241909571406037,
"mae": 0.20043573192521638,
"r2": -0.43487628292073
},
"delivery": {
"n": 44,
"rmse": 0.7564153649581582,
"mae": 0.6047130756302398,
"r2": -0.4226486727361405
},
"pdi": {
"n": 51,
"accuracy": 0.8823529411764706
},
"ee": {
"n": 51,
"accuracy": 0.8431372549019608
},
"toxic": {
"n": 47,
"accuracy": 0.851063829787234
}
},
{
"fold_idx": 3,
"n_samples": 66,
"size": {
"n": 66,
"rmse": 0.2857872773679936,
"mae": 0.22075237649859805,
"r2": -0.5674047032859011
},
"delivery": {
"n": 62,
"rmse": 1.0291312965402932,
"mae": 0.7422042032328224,
"r2": -0.7148264932933832
},
"pdi": {
"n": 66,
"accuracy": 0.8484848484848485
},
"ee": {
"n": 66,
"accuracy": 0.18181818181818182
},
"toxic": {
"n": 62,
"accuracy": 1.0
}
},
{
"fold_idx": 4,
"n_samples": 27,
"size": {
"n": 27,
"rmse": 0.2271495001169846,
"mae": 0.18753767013549805,
"r2": -0.19441156195074893
},
"delivery": {
"n": 15,
"rmse": 1.993006453768918,
"mae": 1.3779302000999452,
"r2": -0.3411461507368889
},
"pdi": {
"n": 27,
"accuracy": 0.8888888888888888
},
"ee": {
"n": 27,
"accuracy": 0.5925925925925926
},
"toxic": {
"n": 15,
"accuracy": 1.0
}
}
],
"summary_stats": {
"size": {
"rmse_mean": 0.36266283314613074,
"rmse_std": 0.15203127472757474,
"r2_mean": -0.2573361731079197,
"r2_std": 0.2187118059634264
},
"delivery": {
"rmse_mean": 1.1753602050239518,
"rmse_std": 0.46580283242073095,
"r2_mean": -0.30365118166296035,
"r2_std": 0.2630677092396549
},
"pdi": {
"accuracy_mean": 0.7875890604063979,
"accuracy_std": 0.11016791908756088
},
"ee": {
"accuracy_mean": 0.5402437489124795,
"accuracy_std": 0.22467627690136344
},
"toxic": {
"accuracy_mean": 0.9490006447453256,
"accuracy_std": 0.06391582554207781
}
},
"overall": {
"size": {
"n_samples": 432,
"mse": 0.19167728610863985,
"rmse": 0.43780964597486866,
"mae": 0.2816500812768936,
"r2": -0.04410163027802061
},
"delivery": {
"n_samples": 310,
"mse": 1.095306046771274,
"rmse": 1.0465687014101244,
"mae": 0.6143080417337197,
"r2": -0.1024184074306409
},
"pdi": {
"n_samples": 434,
"accuracy": 0.7396313364055299
},
"ee": {
"n_samples": 434,
"accuracy": 0.4976958525345622
},
"toxic": {
"n_samples": 313,
"accuracy": 0.9552715654952076
}
}
}

226
scripts/process_data_cv.py Normal file
View File

@ -0,0 +1,226 @@
"""内部数据 Cross-Validation 划分脚本:基于 Amine 的分组划分"""
from pathlib import Path
from typing import List
import numpy as np
import pandas as pd
import typer
from loguru import logger
from lnp_ml.config import INTERIM_DATA_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import (
process_dataframe,
SMILES_COL,
COMP_COLS,
HELP_COLS,
TARGET_REGRESSION,
TARGET_CLASSIFICATION_PDI,
TARGET_CLASSIFICATION_EE,
TARGET_TOXIC,
TARGET_BIODIST,
get_phys_cols,
get_exp_cols,
)
app = typer.Typer()
def amine_based_cv_split(
df: pd.DataFrame,
n_folds: int = 5,
seed: int = 42,
amine_col: str = "Amine",
) -> List[dict]:
"""
基于 Amine 列进行 Cross-Validation 划分
步骤
1. amine_col 分组
2. 打乱分组顺序
3. 将分组 round-robin 分配到 n_folds 个容器
4. 对于每个 fold i
- validation = container[i]
- test = container[(i+1) % n_folds]
- train = 其余所有
Args:
df: 输入 DataFrame
n_folds: 折数
seed: 随机种子
amine_col: 用于分组的列名
Returns:
List of dicts每个 dict 包含 train_df, val_df, test_df
"""
# 获取唯一的 amine 并打乱
unique_amines = df[amine_col].unique()
rng = np.random.RandomState(seed)
rng.shuffle(unique_amines)
logger.info(f"Found {len(unique_amines)} unique amines")
# Round-robin 分配到 n_folds 个容器
containers = [[] for _ in range(n_folds)]
for i, amine in enumerate(unique_amines):
containers[i % n_folds].append(amine)
# 打印每个容器的大小
for i, container in enumerate(containers):
container_samples = df[df[amine_col].isin(container)]
logger.info(f" Container {i}: {len(container)} amines, {len(container_samples)} samples")
# 生成每个 fold 的数据
fold_splits = []
for i in range(n_folds):
val_amines = set(containers[i])
test_amines = set(containers[(i + 1) % n_folds])
train_amines = set()
for j in range(n_folds):
if j != i and j != (i + 1) % n_folds:
train_amines.update(containers[j])
train_df = df[df[amine_col].isin(train_amines)].reset_index(drop=True)
val_df = df[df[amine_col].isin(val_amines)].reset_index(drop=True)
test_df = df[df[amine_col].isin(test_amines)].reset_index(drop=True)
fold_splits.append({
"train": train_df,
"val": val_df,
"test": test_df,
})
logger.info(
f"Fold {i}: train={len(train_df)} ({len(train_amines)} amines), "
f"val={len(val_df)} ({len(val_amines)} amines), "
f"test={len(test_df)} ({len(test_amines)} amines)"
)
return fold_splits
@app.command()
def main(
input_path: Path = INTERIM_DATA_DIR / "internal_corrected.csv",
output_dir: Path = PROCESSED_DATA_DIR / "cv",
n_folds: int = 5,
seed: int = 42,
amine_col: str = "Amine",
):
"""
基于 Amine 分组进行 Cross-Validation 数据划分
采用类似 scaffold splitting 的思路将相同 Amine 的数据放在同一组
确保训练集和测试集之间没有 Amine 泄露
划分比例约为 train:val:test 3:1:1
输出结构:
- processed/cv/fold_0/train.parquet
- processed/cv/fold_0/val.parquet
- processed/cv/fold_0/test.parquet
- processed/cv/fold_1/...
- processed/cv/feature_columns.txt
"""
logger.info(f"Loading data from {input_path}")
df = pd.read_csv(input_path)
logger.info(f"Loaded {len(df)} samples")
# 检查 amine 列是否存在
if amine_col not in df.columns:
logger.error(f"Column '{amine_col}' not found in data. Available columns: {list(df.columns)}")
raise typer.Exit(1)
# 处理数据列对齐、one-hot 生成等)
logger.info("Processing dataframe...")
df = process_dataframe(df)
# 确保 Amine 列仍然存在process_dataframe 可能不会保留它)
# 重新加载原始数据获取 Amine 列
original_df = pd.read_csv(input_path)
if amine_col in original_df.columns and amine_col not in df.columns:
df[amine_col] = original_df[amine_col].values
# 定义要保留的列
phys_cols = get_phys_cols()
exp_cols = get_exp_cols()
keep_cols = (
[SMILES_COL]
+ COMP_COLS
+ phys_cols
+ HELP_COLS
+ exp_cols
+ TARGET_REGRESSION
+ TARGET_CLASSIFICATION_PDI
+ TARGET_CLASSIFICATION_EE
+ [TARGET_TOXIC]
+ TARGET_BIODIST
)
# 只保留存在的列
keep_cols = [c for c in keep_cols if c in df.columns]
# 进行 CV 划分
logger.info(f"\nPerforming {n_folds}-fold amine-based CV split (seed={seed})...")
fold_splits = amine_based_cv_split(df, n_folds=n_folds, seed=seed, amine_col=amine_col)
# 保存每个 fold
output_dir.mkdir(parents=True, exist_ok=True)
for i, split in enumerate(fold_splits):
fold_dir = output_dir / f"fold_{i}"
fold_dir.mkdir(parents=True, exist_ok=True)
# 只保留需要的列
train_df = split["train"][keep_cols].reset_index(drop=True)
val_df = split["val"][keep_cols].reset_index(drop=True)
test_df = split["test"][keep_cols].reset_index(drop=True)
# 保存
train_df.to_parquet(fold_dir / "train.parquet", index=False)
val_df.to_parquet(fold_dir / "val.parquet", index=False)
test_df.to_parquet(fold_dir / "test.parquet", index=False)
logger.success(f"Saved fold {i} to {fold_dir}")
# 保存列名配置
config_path = output_dir / "feature_columns.txt"
with open(config_path, "w") as f:
f.write("# Feature columns configuration\n\n")
f.write(f"# SMILES\n{SMILES_COL}\n\n")
f.write(f"# comp token [{len(COMP_COLS)}]\n")
f.write("\n".join(COMP_COLS) + "\n\n")
f.write(f"# phys token [{len(phys_cols)}]\n")
f.write("\n".join(phys_cols) + "\n\n")
f.write(f"# help token [{len(HELP_COLS)}]\n")
f.write("\n".join(HELP_COLS) + "\n\n")
f.write(f"# exp token [{len(exp_cols)}]\n")
f.write("\n".join(exp_cols) + "\n\n")
f.write("# Targets\n")
f.write("## Regression\n")
f.write("\n".join(TARGET_REGRESSION) + "\n")
f.write("## PDI classification\n")
f.write("\n".join(TARGET_CLASSIFICATION_PDI) + "\n")
f.write("## EE classification\n")
f.write("\n".join(TARGET_CLASSIFICATION_EE) + "\n")
f.write("## Toxic\n")
f.write(f"{TARGET_TOXIC}\n")
f.write("## Biodistribution\n")
f.write("\n".join(TARGET_BIODIST) + "\n")
logger.success(f"Saved feature config to {config_path}")
# 打印汇总
logger.info("\n" + "=" * 60)
logger.info("CV DATA PROCESSING COMPLETE")
logger.info("=" * 60)
logger.info(f"Output directory: {output_dir}")
logger.info(f"Number of folds: {n_folds}")
logger.info(f"Splitting method: Amine-based (column: {amine_col})")
logger.info(f"Random seed: {seed}")
if __name__ == "__main__":
app()

View File

@ -151,18 +151,18 @@ def get_feature_columns() -> List[str]:
@app.command()
def main(
data_dir: Path = EXTERNAL_DATA_DIR / "all_amine_split_for_LiON",
output_dir: Path = PROCESSED_DATA_DIR / "cv",
output_dir: Path = PROCESSED_DATA_DIR / "pretrain_cv",
n_folds: int = 5,
):
"""
处理 cross-validation 数据生成模型所需的 parquet 文件
输出结构:
- processed/cv/fold_0/train.parquet
- processed/cv/fold_0/valid.parquet
- processed/cv/fold_0/test.parquet
- processed/cv/fold_1/...
- processed/cv/feature_columns.txt
- processed/pretrain_cv/fold_0/train.parquet
- processed/pretrain_cv/fold_0/valid.parquet
- processed/pretrain_cv/fold_0/test.parquet
- processed/pretrain_cv/fold_1/...
- processed/pretrain_cv/feature_columns.txt
"""
logger.info(f"Processing CV data from {data_dir}")