Compare commits

...

4 Commits

Author SHA1 Message Date
RYDE-WORK
ac4246c2b7 Add train_cv(without pretrain) 2026-01-22 18:06:13 +08:00
RYDE-WORK
47bbb64c66 Add more metrics 2026-01-22 17:06:24 +08:00
RYDE-WORK
039be54c5a ... 2026-01-22 01:01:29 +08:00
RYDE-WORK
e6a5e5495a ... 2026-01-22 00:24:13 +08:00
54 changed files with 4073 additions and 401 deletions

View File

@ -74,6 +74,11 @@ data_pretrain: requirements
$(PYTHON_INTERPRETER) scripts/process_external.py
## Process CV data for cross-validation pretrain (external/all_amine_split_for_LiON -> processed/pretrain_cv)
.PHONY: data_pretrain_cv
data_pretrain_cv: requirements
$(PYTHON_INTERPRETER) scripts/process_external_cv.py
## Process internal data with amine-based CV splitting (interim -> processed/cv)
.PHONY: data_cv
data_cv: requirements
$(PYTHON_INTERPRETER) scripts/process_data_cv.py
@ -106,8 +111,8 @@ pretrain_cv: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv main $(MPNN_FLAG) $(DEVICE_FLAG)
## Evaluate CV pretrain models on test sets (auto-detects MPNN from checkpoint)
.PHONY: test_cv
test_cv: requirements
.PHONY: test_pretrain_cv
test_pretrain_cv: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv test $(DEVICE_FLAG)
## Train model (multi-task, from scratch)
@ -120,6 +125,22 @@ train: requirements
finetune: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --init-from-pretrain models/pretrain_delivery.pt $(FREEZE_FLAG) $(MPNN_FLAG) $(DEVICE_FLAG)
## Finetune with cross-validation on internal data (5-fold, amine-based split) with pretrained weights
.PHONY: finetune_cv
finetune_cv: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train_cv main --init-from-pretrain models/pretrain_delivery.pt $(FREEZE_FLAG) $(MPNN_FLAG) $(DEVICE_FLAG)
## Train with cross-validation on internal data only (5-fold, amine-based split)
.PHONY: train_cv
train_cv: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train_cv main $(FREEZE_FLAG) $(MPNN_FLAG) $(DEVICE_FLAG)
## Evaluate CV finetuned models on test sets (auto-detects MPNN from checkpoint)
.PHONY: test_cv
test_cv: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train_cv test $(DEVICE_FLAG)
## Train with hyperparameter tuning
.PHONY: tune
tune: requirements

View File

@ -1,9 +1,16 @@
# Feature columns configuration
# SMILES
smiles
# comp token [5]
Cationic_Lipid_to_mRNA_weight_ratio
Cationic_Lipid_Mol_Ratio
Phospholipid_Mol_Ratio
Cholesterol_Mol_Ratio
PEG_Lipid_Mol_Ratio
# phys token [12]
Purity_Pure
Purity_Crude
Mix_type_Microfluidic
@ -16,10 +23,14 @@ Target_or_delivered_gene_Peptide_barcode
Target_or_delivered_gene_hEPO
Target_or_delivered_gene_FVII
Target_or_delivered_gene_GFP
# help token [4]
Helper_lipid_ID_DOPE
Helper_lipid_ID_DOTAP
Helper_lipid_ID_DSPC
Helper_lipid_ID_MDOA
# exp token [32]
Model_type_A549
Model_type_BDMC
Model_type_BMDM
@ -52,4 +63,27 @@ Value_name_hEPO
Value_name_FVII_silencing
Value_name_GFP_delivery
Value_name_Discretized_luminescence
# Targets
## Regression
size
quantified_delivery
## PDI classification
PDI_0_0to0_2
PDI_0_2to0_3
PDI_0_3to0_4
PDI_0_4to0_5
## EE classification
Encapsulation_Efficiency_EE<50
Encapsulation_Efficiency_50<=EE<80
Encapsulation_Efficiency_80<EE<=100
## Toxic
toxic
## Biodistribution
Biodistribution_lymph_nodes
Biodistribution_heart
Biodistribution_liver
Biodistribution_spleen
Biodistribution_lung
Biodistribution_kidney
Biodistribution_muscle

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,55 @@
smiles
Cationic_Lipid_to_mRNA_weight_ratio
Cationic_Lipid_Mol_Ratio
Phospholipid_Mol_Ratio
Cholesterol_Mol_Ratio
PEG_Lipid_Mol_Ratio
Purity_Pure
Purity_Crude
Mix_type_Microfluidic
Mix_type_Pipetting
Cargo_type_mRNA
Cargo_type_pDNA
Cargo_type_siRNA
Target_or_delivered_gene_FFL
Target_or_delivered_gene_Peptide_barcode
Target_or_delivered_gene_hEPO
Target_or_delivered_gene_FVII
Target_or_delivered_gene_GFP
Helper_lipid_ID_DOPE
Helper_lipid_ID_DOTAP
Helper_lipid_ID_DSPC
Helper_lipid_ID_MDOA
Model_type_A549
Model_type_BDMC
Model_type_BMDM
Model_type_HBEC_ALI
Model_type_HEK293T
Model_type_HeLa
Model_type_IGROV1
Model_type_Mouse
Model_type_RAW264p7
Delivery_target_body
Delivery_target_dendritic_cell
Delivery_target_generic_cell
Delivery_target_liver
Delivery_target_lung
Delivery_target_lung_epithelium
Delivery_target_macrophage
Delivery_target_muscle
Delivery_target_spleen
Route_of_administration_in_vitro
Route_of_administration_intramuscular
Route_of_administration_intratracheal
Route_of_administration_intravenous
Batch_or_individual_or_barcoded_Barcoded
Batch_or_individual_or_barcoded_Individual
Value_name_log_luminescence
Value_name_luminescence
Value_name_FFL_silencing
Value_name_Peptide_abundance
Value_name_hEPO
Value_name_FVII_silencing
Value_name_GFP_delivery
Value_name_Discretized_luminescence
quantified_delivery

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,37 +1,51 @@
{
"loss_metrics": {
"loss": 2.8661977450052896,
"loss_size": 0.44916408757368725,
"loss_pdi": 0.5041926403840383,
"loss_ee": 0.9021427234013876,
"loss_delivery": 0.5761533578236898,
"loss_biodist": 0.4019051690896352,
"loss_toxic": 0.03263980595511384,
"acc_pdi": 0.7633587786259542,
"acc_ee": 0.6641221374045801,
"acc_toxic": 0.9702970297029703
"loss": 2.5374555587768555,
"loss_size": 0.1886825958887736,
"loss_pdi": 0.45798932512601215,
"loss_ee": 0.829658567905426,
"loss_delivery": 0.4857304096221924,
"loss_biodist": 0.5346279243628184,
"loss_toxic": 0.04076674363265435,
"acc_pdi": 0.7862595419847328,
"acc_ee": 0.6793893129770993,
"acc_toxic": 0.9801980198019802
},
"detailed_metrics": {
"size": {
"mse": 0.41126506251447736,
"rmse": 0.6412995107704959,
"mae": 0.41415552388095633,
"r2": -0.9333718010891026
"mse": 0.1669999969286325,
"rmse": 0.4086563310761654,
"mae": 0.26111859684375066,
"r2": 0.2149270281561566
},
"delivery": {
"mse": 0.6277965050686476,
"rmse": 0.7923361061245711,
"mae": 0.5387302115022443,
"r2": 0.24206702565575944
"mse": 0.5193460523366603,
"rmse": 0.7206566813238189,
"mae": 0.4828052782115008,
"r2": 0.37299826459145
},
"pdi": {
"accuracy": 0.7633587786259542
"accuracy": 0.7862595419847328,
"precision": 0.7282763532763532,
"recall": 0.6907738095238095,
"f1": 0.7041935483870968
},
"ee": {
"accuracy": 0.6641221374045801
"accuracy": 0.6793893129770993,
"precision": 0.612247574088644,
"recall": 0.6062951496388029,
"f1": 0.6069449904342585
},
"toxic": {
"accuracy": 0.9702970297029703
"accuracy": 0.9801980198019802,
"precision": 0.5,
"recall": 0.4900990099009901,
"f1": 0.495
},
"biodist": {
"n_samples": 101,
"kl_divergence": 0.2931957937514963,
"js_divergence": 0.07706768601895059
}
}
}

View File

@ -217,15 +217,31 @@ def test(
"""
import json
import numpy as np
from scipy.special import rel_entr
from sklearn.metrics import (
mean_squared_error,
mean_absolute_error,
r2_score,
accuracy_score,
classification_report,
precision_score,
recall_score,
f1_score,
)
from lnp_ml.modeling.trainer import validate
def kl_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-10) -> float:
"""计算 KL 散度 KL(p || q)"""
p = np.clip(p, eps, 1.0)
q = np.clip(q, eps, 1.0)
return float(np.sum(rel_entr(p, q), axis=-1).mean())
def js_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-10) -> float:
"""计算 JS 散度"""
p = np.clip(p, eps, 1.0)
q = np.clip(q, eps, 1.0)
m = 0.5 * (p + q)
return float(0.5 * (np.sum(rel_entr(p, m), axis=-1) + np.sum(rel_entr(q, m), axis=-1)).mean())
logger.info(f"Using device: {device}")
device_obj = torch.device(device)
@ -287,6 +303,9 @@ def test(
y_pred = np.array(predictions["pdi"])[mask]
results["detailed_metrics"]["pdi"] = {
"accuracy": float(accuracy_score(y_true, y_pred)),
"precision": float(precision_score(y_true, y_pred, average="macro", zero_division=0)),
"recall": float(recall_score(y_true, y_pred, average="macro", zero_division=0)),
"f1": float(f1_score(y_true, y_pred, average="macro", zero_division=0)),
}
# 分类指标EE
@ -299,6 +318,9 @@ def test(
y_pred = np.array(predictions["ee"])[mask]
results["detailed_metrics"]["ee"] = {
"accuracy": float(accuracy_score(y_true, y_pred)),
"precision": float(precision_score(y_true, y_pred, average="macro", zero_division=0)),
"recall": float(recall_score(y_true, y_pred, average="macro", zero_division=0)),
"f1": float(f1_score(y_true, y_pred, average="macro", zero_division=0)),
}
# 分类指标toxic
@ -309,6 +331,28 @@ def test(
y_pred = np.array(predictions["toxic"])[mask.values]
results["detailed_metrics"]["toxic"] = {
"accuracy": float(accuracy_score(y_true, y_pred)),
"precision": float(precision_score(y_true, y_pred, average="macro", zero_division=0)),
"recall": float(recall_score(y_true, y_pred, average="macro", zero_division=0)),
"f1": float(f1_score(y_true, y_pred, average="macro", zero_division=0)),
}
# 分布指标biodist
biodist_cols = [
"Biodistribution_lymph_nodes", "Biodistribution_heart", "Biodistribution_liver",
"Biodistribution_spleen", "Biodistribution_lung", "Biodistribution_kidney", "Biodistribution_muscle"
]
if all(c in test_df.columns for c in biodist_cols):
biodist_true = test_df[biodist_cols].values
biodist_pred = np.array(predictions["biodist"])
# mask: 有效样本是 sum > 0 且无 NaN
mask = (biodist_true.sum(axis=1) > 0) & (~np.isnan(biodist_true).any(axis=1))
if mask.any():
y_true = biodist_true[mask]
y_pred = biodist_pred[mask]
results["detailed_metrics"]["biodist"] = {
"n_samples": int(mask.sum()),
"kl_divergence": kl_divergence(y_true, y_pred),
"js_divergence": js_divergence(y_true, y_pred),
}
# 打印结果

View File

@ -271,7 +271,7 @@ def create_model(
@app.command()
def main(
data_dir: Path = PROCESSED_DATA_DIR / "cv",
data_dir: Path = PROCESSED_DATA_DIR / "pretrain_cv",
output_dir: Path = MODELS_DIR / "pretrain_cv",
# 模型参数
d_model: int = 256,
@ -322,7 +322,7 @@ def main(
if not fold_dirs:
logger.error(f"No fold_* directories found in {data_dir}")
logger.info("Please run 'make data_cv' first to process CV data.")
logger.info("Please run 'make data_pretrain_cv' first to process CV data.")
raise typer.Exit(1)
logger.info(f"Found {len(fold_dirs)} folds: {[d.name for d in fold_dirs]}")
@ -464,7 +464,7 @@ def main(
@app.command()
def test(
data_dir: Path = PROCESSED_DATA_DIR / "cv",
data_dir: Path = PROCESSED_DATA_DIR / "pretrain_cv",
model_dir: Path = MODELS_DIR / "pretrain_cv",
output_path: Path = MODELS_DIR / "pretrain_cv" / "test_results.json",
batch_size: int = 64,

720
lnp_ml/modeling/train_cv.py Normal file
View File

@ -0,0 +1,720 @@
"""Cross-Validation 训练脚本:在 5-fold 内部数据上进行多任务训练"""
import json
from pathlib import Path
from typing import Dict, List, Optional, Union
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from loguru import logger
from tqdm import tqdm
import typer
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import LNPDataset, collate_fn
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
from lnp_ml.modeling.trainer import (
train_epoch,
validate,
EarlyStopping,
LossWeights,
)
# MPNN ensemble 默认路径
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
    """Locate every MPNN ensemble checkpoint (``model.pt``) under ``base_dir``.

    Scans the conventional layout ``cv_*/fold_*/model_*/model.pt`` and returns
    the matches as sorted path strings.

    Raises:
        FileNotFoundError: if no checkpoint files are found.
    """
    checkpoints = sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt"))
    if len(checkpoints) == 0:
        raise FileNotFoundError(f"No model.pt files found in {base_dir}")
    return [str(ckpt) for ckpt in checkpoints]
# Typer application exposing the `main` (train) and `test` (evaluate) sub-commands.
app = typer.Typer()
def create_model(
    d_model: int = 256,
    num_heads: int = 8,
    n_attn_layers: int = 4,
    fusion_strategy: str = "attention",
    head_hidden_dim: int = 128,
    dropout: float = 0.1,
    mpnn_checkpoint: Optional[str] = None,
    mpnn_ensemble_paths: Optional[List[str]] = None,
    mpnn_device: str = "cpu",
) -> Union[LNPModel, LNPModelWithoutMPNN]:
    """Build an LNP model, attaching an MPNN encoder when any MPNN source is given.

    If neither ``mpnn_checkpoint`` nor ``mpnn_ensemble_paths`` is provided,
    a tabular-only ``LNPModelWithoutMPNN`` is constructed instead.
    """
    shared_kwargs = dict(
        d_model=d_model,
        num_heads=num_heads,
        n_attn_layers=n_attn_layers,
        fusion_strategy=fusion_strategy,
        head_hidden_dim=head_hidden_dim,
        dropout=dropout,
    )
    if mpnn_checkpoint is None and mpnn_ensemble_paths is None:
        return LNPModelWithoutMPNN(**shared_kwargs)
    return LNPModel(
        mpnn_checkpoint=mpnn_checkpoint,
        mpnn_ensemble_paths=mpnn_ensemble_paths,
        mpnn_device=mpnn_device,
        **shared_kwargs,
    )
def train_fold(
    fold_idx: int,
    train_loader: DataLoader,
    val_loader: DataLoader,
    model: nn.Module,
    device: torch.device,
    output_dir: Path,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    epochs: int = 100,
    patience: int = 15,
    loss_weights: Optional[LossWeights] = None,
    config: Optional[Dict] = None,
) -> Dict:
    """Train a single CV fold and persist its best checkpoint and history.

    Runs up to ``epochs`` train/validate cycles with AdamW, halving the LR
    after 5 stagnant epochs (ReduceLROnPlateau) and stopping early after
    ``patience`` epochs without validation-loss improvement. The best
    (lowest val-loss) weights are saved to
    ``output_dir/fold_{fold_idx}/model.pt`` alongside ``config``; per-epoch
    metrics go to ``history.json`` in the same directory.

    Args:
        fold_idx: index of this fold (logging and output paths).
        train_loader: DataLoader over the fold's training split.
        val_loader: DataLoader over the fold's validation split.
        model: model to train; moved to ``device`` here.
        device: torch device to train on.
        output_dir: root directory for per-fold outputs.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        epochs: maximum number of epochs.
        patience: early-stopping patience in epochs.
        loss_weights: optional per-task loss weights forwarded to the
            trainer helpers (semantics defined in lnp_ml.modeling.trainer).
        config: optional run configuration embedded in the checkpoint.

    Returns:
        Dict with ``fold_idx``, ``best_val_loss``, ``epochs_trained`` and
        ``final_train_loss`` (0 if no epoch completed).
    """
    logger.info(f"\n{'='*60}")
    logger.info(f"Training Fold {fold_idx}")
    logger.info(f"{'='*60}")
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Halve the LR after 5 epochs without val-loss improvement.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )
    early_stopping = EarlyStopping(patience=patience)
    history = {"train": [], "val": []}
    best_val_loss = float("inf")
    best_state = None
    for epoch in range(epochs):
        # One training pass, then evaluation on the validation split.
        train_metrics = train_epoch(model, train_loader, optimizer, device, loss_weights)
        val_metrics = validate(model, val_loader, device, loss_weights)
        current_lr = optimizer.param_groups[0]["lr"]
        # Log
        logger.info(
            f"Fold {fold_idx} Epoch {epoch+1}/{epochs} | "
            f"Train Loss: {train_metrics['loss']:.4f} | "
            f"Val Loss: {val_metrics['loss']:.4f} | "
            f"LR: {current_lr:.2e}"
        )
        history["train"].append(train_metrics)
        history["val"].append(val_metrics)
        # Learning rate scheduling keyed on validation loss.
        scheduler.step(val_metrics["loss"])
        # Track the best model; snapshot the state dict on CPU so later
        # training steps cannot mutate the saved tensors.
        if val_metrics["loss"] < best_val_loss:
            best_val_loss = val_metrics["loss"]
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            logger.info(f" -> New best model (val_loss={best_val_loss:.4f})")
        # Early stopping on validation loss.
        if early_stopping(val_metrics["loss"]):
            logger.info(f"Early stopping at epoch {epoch+1}")
            break
    # Save the best checkpoint for this fold.
    fold_output_dir = output_dir / f"fold_{fold_idx}"
    fold_output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = fold_output_dir / "model.pt"
    torch.save({
        "model_state_dict": best_state,
        "config": config,
        "best_val_loss": best_val_loss,
        "fold_idx": fold_idx,
    }, checkpoint_path)
    logger.success(f"Saved fold {fold_idx} model to {checkpoint_path}")
    # Save the per-epoch training history.
    history_path = fold_output_dir / "history.json"
    with open(history_path, "w") as f:
        json.dump(history, f, indent=2)
    return {
        "fold_idx": fold_idx,
        "best_val_loss": best_val_loss,
        "epochs_trained": len(history["train"]),
        "final_train_loss": history["train"][-1]["loss"] if history["train"] else 0,
    }
@app.command()
def main(
    data_dir: Path = PROCESSED_DATA_DIR / "cv",
    output_dir: Path = MODELS_DIR / "finetune_cv",
    # Model hyperparameters
    d_model: int = 256,
    num_heads: int = 8,
    n_attn_layers: int = 4,
    fusion_strategy: str = "attention",
    head_hidden_dim: int = 128,
    dropout: float = 0.1,
    # Optional MPNN encoder settings
    use_mpnn: bool = False,
    mpnn_checkpoint: Optional[str] = None,
    mpnn_ensemble_paths: Optional[str] = None,
    mpnn_device: str = "cpu",
    # Training hyperparameters
    batch_size: int = 32,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    epochs: int = 100,
    patience: int = 15,
    # Pretrained-weight initialization
    init_from_pretrain: Optional[Path] = None,
    load_delivery_head: bool = True,
    freeze_backbone: bool = False,
    # Device
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Train the LNP model with cross-validation (multi-task).

    Trains 5 independent models on the 5-fold internal data.
    Use --use-mpnn to enable the MPNN encoder.
    Use --init-from-pretrain to initialize from a pretraining checkpoint.
    Use --freeze-backbone to freeze the backbone and train only the heads.
    """
    logger.info(f"Using device: {device}")
    device = torch.device(device)
    # Discover the fold_* directories produced by the CV data-processing step.
    fold_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("fold_")])
    if not fold_dirs:
        logger.error(f"No fold_* directories found in {data_dir}")
        logger.info("Please run 'make data_cv' first to process CV data.")
        raise typer.Exit(1)
    logger.info(f"Found {len(fold_dirs)} folds: {[d.name for d in fold_dirs]}")
    output_dir.mkdir(parents=True, exist_ok=True)
    # Resolve MPNN configuration: explicit --mpnn-ensemble-paths (comma-separated)
    # wins; otherwise auto-detect an ensemble when --use-mpnn is set without a
    # single checkpoint.
    ensemble_paths_list = None
    if mpnn_ensemble_paths:
        ensemble_paths_list = mpnn_ensemble_paths.split(",")
    elif use_mpnn and mpnn_checkpoint is None:
        logger.info(f"Auto-detecting MPNN ensemble from {DEFAULT_MPNN_ENSEMBLE_DIR}")
        ensemble_paths_list = find_mpnn_ensemble_paths()
        logger.info(f"Found {len(ensemble_paths_list)} MPNN models")
    enable_mpnn = mpnn_checkpoint is not None or ensemble_paths_list is not None
    # Run configuration (also embedded in every fold checkpoint by train_fold).
    config = {
        "d_model": d_model,
        "num_heads": num_heads,
        "n_attn_layers": n_attn_layers,
        "fusion_strategy": fusion_strategy,
        "head_hidden_dim": head_hidden_dim,
        "dropout": dropout,
        "use_mpnn": enable_mpnn,
        "lr": lr,
        "weight_decay": weight_decay,
        "batch_size": batch_size,
        "epochs": epochs,
        "patience": patience,
        "init_from_pretrain": str(init_from_pretrain) if init_from_pretrain else None,
        "freeze_backbone": freeze_backbone,
    }
    # Save configuration for reproducibility.
    config_path = output_dir / "config.json"
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)
    logger.info(f"Saved config to {config_path}")
    # Load pretrained weights once (if requested); reused for every fold.
    pretrain_state = None
    if init_from_pretrain is not None:
        logger.info(f"Loading pretrain weights from {init_from_pretrain}")
        checkpoint = torch.load(init_from_pretrain, map_location="cpu")
        pretrain_config = checkpoint.get("config", {})
        # Refuse mismatched architectures rather than loading incompatible weights.
        if pretrain_config.get("d_model") != d_model:
            logger.warning(
                f"d_model mismatch: pretrain={pretrain_config.get('d_model')}, "
                f"current={d_model}. Skipping pretrain loading."
            )
        else:
            pretrain_state = checkpoint["model_state_dict"]
    # Train each fold with an independently initialized model.
    fold_results = []
    for fold_dir in tqdm(fold_dirs, desc="Training folds"):
        fold_idx = int(fold_dir.name.split("_")[1])
        # Load this fold's train/val splits.
        train_df = pd.read_parquet(fold_dir / "train.parquet")
        val_df = pd.read_parquet(fold_dir / "val.parquet")
        logger.info(f"\nFold {fold_idx}: train={len(train_df)}, val={len(val_df)}")
        # Build Dataset and DataLoader objects.
        train_dataset = LNPDataset(train_df)
        val_dataset = LNPDataset(val_df)
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
        )
        val_loader = DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
        )
        # Fresh model per fold (independent initialization).
        model = create_model(
            d_model=d_model,
            num_heads=num_heads,
            n_attn_layers=n_attn_layers,
            fusion_strategy=fusion_strategy,
            head_hidden_dim=head_hidden_dim,
            dropout=dropout,
            mpnn_checkpoint=mpnn_checkpoint,
            mpnn_ensemble_paths=ensemble_paths_list,
            mpnn_device=device.type,
        )
        # Initialize from pretraining, if weights were loaded above.
        if pretrain_state is not None:
            model.load_pretrain_weights(
                pretrain_state_dict=pretrain_state,
                load_delivery_head=load_delivery_head,
                strict=False,
            )
            logger.info(f"Loaded pretrain weights (backbone + delivery_head={load_delivery_head})")
        # Optionally freeze the backbone (token projector / cross-attention / fusion).
        if freeze_backbone:
            frozen_count = 0
            for name, param in model.named_parameters():
                if name.startswith(("token_projector.", "cross_attention.", "fusion.")):
                    param.requires_grad = False
                    frozen_count += 1
            logger.info(f"Frozen {frozen_count} parameter tensors")
        # Log parameter counts once (first fold only).
        if fold_idx == 0:
            n_params_total = sum(p.numel() for p in model.parameters())
            n_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
            logger.info(f"Model parameters: {n_params_total:,} total, {n_params_trainable:,} trainable")
        # Train this fold and collect its summary.
        result = train_fold(
            fold_idx=fold_idx,
            train_loader=train_loader,
            val_loader=val_loader,
            model=model,
            device=device,
            output_dir=output_dir,
            lr=lr,
            weight_decay=weight_decay,
            epochs=epochs,
            patience=patience,
            config=config,
        )
        fold_results.append(result)
    # Summarize across folds.
    logger.info("\n" + "=" * 60)
    logger.info("CROSS-VALIDATION TRAINING COMPLETE")
    logger.info("=" * 60)
    val_losses = [r["best_val_loss"] for r in fold_results]
    logger.info(f"\n[Per-Fold Results]")
    for r in fold_results:
        logger.info(
            f" Fold {r['fold_idx']}: "
            f"Val Loss={r['best_val_loss']:.4f}, "
            f"Epochs={r['epochs_trained']}"
        )
    logger.info(f"\n[Summary Statistics]")
    logger.info(f" Val Loss: {np.mean(val_losses):.4f} ± {np.std(val_losses):.4f}")
    # Persist CV results.
    cv_results = {
        "fold_results": fold_results,
        "summary": {
            "val_loss_mean": float(np.mean(val_losses)),
            "val_loss_std": float(np.std(val_losses)),
        },
        "config": config,
    }
    results_path = output_dir / "cv_results.json"
    with open(results_path, "w") as f:
        json.dump(cv_results, f, indent=2)
    logger.success(f"Saved CV results to {results_path}")
@app.command()
def test(
    data_dir: Path = PROCESSED_DATA_DIR / "cv",
    model_dir: Path = MODELS_DIR / "finetune_cv",
    output_path: Path = MODELS_DIR / "finetune_cv" / "test_results.json",
    batch_size: int = 64,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Evaluate the CV-trained models on the test sets.

    Each fold's model is evaluated on that fold's test split; per-fold
    metrics, cross-fold summary statistics (mean ± std) and pooled overall
    metrics are logged and written to ``output_path`` as JSON.
    """
    from scipy.special import rel_entr
    from sklearn.metrics import (
        mean_squared_error,
        mean_absolute_error,
        r2_score,
        accuracy_score,
        precision_score,
        recall_score,
        f1_score,
    )
    def kl_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-10) -> float:
        """KL divergence KL(p || q), summed over the last axis and averaged
        over rows; inputs are clipped to [eps, 1] to avoid log(0)."""
        p = np.clip(p, eps, 1.0)
        q = np.clip(q, eps, 1.0)
        return float(np.sum(rel_entr(p, q), axis=-1).mean())
    def js_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-10) -> float:
        """Jensen-Shannon divergence, summed over the last axis and averaged
        over rows; inputs are clipped to [eps, 1]."""
        p = np.clip(p, eps, 1.0)
        q = np.clip(q, eps, 1.0)
        m = 0.5 * (p + q)
        return float(0.5 * (np.sum(rel_entr(p, m), axis=-1) + np.sum(rel_entr(q, m), axis=-1)).mean())
    logger.info(f"Using device: {device}")
    device = torch.device(device)
    # Discover the fold_* directories containing per-fold test splits.
    fold_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("fold_")])
    if not fold_dirs:
        logger.error(f"No fold_* directories found in {data_dir}")
        raise typer.Exit(1)
    logger.info(f"Found {len(fold_dirs)} folds")
    fold_results = []
    # Pooled predictions/targets across all folds (for the "overall" metrics).
    all_preds = {
        "size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [], "biodist": []
    }
    all_targets = {
        "size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [], "biodist": []
    }
    for fold_dir in tqdm(fold_dirs, desc="Evaluating folds"):
        fold_idx = int(fold_dir.name.split("_")[1])
        model_path = model_dir / f"fold_{fold_idx}" / "model.pt"
        test_path = fold_dir / "test.parquet"
        if not model_path.exists():
            logger.warning(f"Fold {fold_idx}: model not found at {model_path}, skipping")
            continue
        if not test_path.exists():
            logger.warning(f"Fold {fold_idx}: test data not found at {test_path}, skipping")
            continue
        # Load the checkpoint and rebuild the model from its stored config.
        checkpoint = torch.load(model_path, map_location=device)
        config = checkpoint["config"]
        use_mpnn = config.get("use_mpnn", False)
        # Always re-resolve MPNN ensemble paths at evaluation time.
        if use_mpnn:
            mpnn_paths = find_mpnn_ensemble_paths()
        else:
            mpnn_paths = None
        model = create_model(
            d_model=config["d_model"],
            num_heads=config["num_heads"],
            n_attn_layers=config["n_attn_layers"],
            fusion_strategy=config["fusion_strategy"],
            head_hidden_dim=config["head_hidden_dim"],
            dropout=config["dropout"],
            mpnn_ensemble_paths=mpnn_paths,
            mpnn_device=device.type,
        )
        model.load_state_dict(checkpoint["model_state_dict"])
        model = model.to(device)
        model.eval()
        # Load this fold's test split.
        test_df = pd.read_parquet(test_path)
        test_dataset = LNPDataset(test_df)
        test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
        )
        # Collect this fold's predictions and targets, keyed by task.
        fold_preds = {k: [] for k in all_preds.keys()}
        fold_targets = {k: [] for k in all_targets.keys()}
        with torch.no_grad():
            pbar = tqdm(test_loader, desc=f"Fold {fold_idx} [Test]", leave=False)
            for batch in pbar:
                smiles = batch["smiles"]
                tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
                targets = batch["targets"]
                masks = batch["mask"]
                outputs = model(smiles, tabular)
                # Size (regression) — only rows where the target is present.
                if "size" in masks and masks["size"].any():
                    mask = masks["size"]
                    fold_preds["size"].extend(
                        outputs["size"].squeeze(-1)[mask].cpu().numpy().tolist()
                    )
                    fold_targets["size"].extend(
                        targets["size"][mask].cpu().numpy().tolist()
                    )
                # Delivery (regression)
                if "delivery" in masks and masks["delivery"].any():
                    mask = masks["delivery"]
                    fold_preds["delivery"].extend(
                        outputs["delivery"].squeeze(-1)[mask].cpu().numpy().tolist()
                    )
                    fold_targets["delivery"].extend(
                        targets["delivery"][mask].cpu().numpy().tolist()
                    )
                # PDI (classification): predicted class = argmax over logits.
                if "pdi" in masks and masks["pdi"].any():
                    mask = masks["pdi"]
                    pdi_preds = outputs["pdi"][mask].argmax(dim=-1).cpu().numpy()
                    pdi_targets = targets["pdi"][mask].cpu().numpy()
                    fold_preds["pdi"].extend(pdi_preds.tolist())
                    fold_targets["pdi"].extend(pdi_targets.tolist())
                # EE (classification)
                if "ee" in masks and masks["ee"].any():
                    mask = masks["ee"]
                    ee_preds = outputs["ee"][mask].argmax(dim=-1).cpu().numpy()
                    ee_targets = targets["ee"][mask].cpu().numpy()
                    fold_preds["ee"].extend(ee_preds.tolist())
                    fold_targets["ee"].extend(ee_targets.tolist())
                # Toxic (classification); targets cast to int for sklearn.
                if "toxic" in masks and masks["toxic"].any():
                    mask = masks["toxic"]
                    toxic_preds = outputs["toxic"][mask].argmax(dim=-1).cpu().numpy()
                    toxic_targets = targets["toxic"][mask].cpu().numpy().astype(int)
                    fold_preds["toxic"].extend(toxic_preds.tolist())
                    fold_targets["toxic"].extend(toxic_targets.tolist())
                # Biodistribution (per-organ distribution vectors)
                if "biodist" in masks and masks["biodist"].any():
                    mask = masks["biodist"]
                    biodist_preds = outputs["biodist"][mask].cpu().numpy()
                    biodist_targets = targets["biodist"][mask].cpu().numpy()
                    fold_preds["biodist"].extend(biodist_preds.tolist())
                    fold_targets["biodist"].extend(biodist_targets.tolist())
        # Compute this fold's metrics.
        fold_metrics = {"fold_idx": fold_idx, "n_samples": len(test_df)}
        # Regression tasks.
        for task in ["size", "delivery"]:
            if fold_preds[task]:
                p = np.array(fold_preds[task])
                t = np.array(fold_targets[task])
                fold_metrics[task] = {
                    "n": len(p),
                    "rmse": float(np.sqrt(mean_squared_error(t, p))),
                    "mae": float(mean_absolute_error(t, p)),
                    "r2": float(r2_score(t, p)),
                }
        # Classification tasks (macro-averaged precision/recall/F1).
        for task in ["pdi", "ee", "toxic"]:
            if fold_preds[task]:
                p = np.array(fold_preds[task])
                t = np.array(fold_targets[task])
                fold_metrics[task] = {
                    "n": len(p),
                    "accuracy": float(accuracy_score(t, p)),
                    "precision": float(precision_score(t, p, average="macro", zero_division=0)),
                    "recall": float(recall_score(t, p, average="macro", zero_division=0)),
                    "f1": float(f1_score(t, p, average="macro", zero_division=0)),
                }
        # Distribution task.
        if fold_preds["biodist"]:
            p = np.array(fold_preds["biodist"])
            t = np.array(fold_targets["biodist"])
            fold_metrics["biodist"] = {
                "n": len(p),
                "kl_divergence": kl_divergence(t, p),
                "js_divergence": js_divergence(t, p),
            }
        fold_results.append(fold_metrics)
        # Accumulate into the pooled (all-fold) collections.
        for task in all_preds.keys():
            all_preds[task].extend(fold_preds[task])
            all_targets[task].extend(fold_targets[task])
        # One-line summary for this fold.
        log_parts = [f"Fold {fold_idx}: n={len(test_df)}"]
        for task in ["delivery", "size"]:
            if task in fold_metrics and isinstance(fold_metrics[task], dict):
                log_parts.append(f"{task}_RMSE={fold_metrics[task]['rmse']:.4f}")
                log_parts.append(f"{task}_R²={fold_metrics[task]['r2']:.4f}")
        for task in ["pdi", "ee", "toxic"]:
            if task in fold_metrics and isinstance(fold_metrics[task], dict):
                log_parts.append(f"{task}_acc={fold_metrics[task]['accuracy']:.4f}")
                log_parts.append(f"{task}_f1={fold_metrics[task]['f1']:.4f}")
        if "biodist" in fold_metrics and isinstance(fold_metrics["biodist"], dict):
            log_parts.append(f"biodist_KL={fold_metrics['biodist']['kl_divergence']:.4f}")
            log_parts.append(f"biodist_JS={fold_metrics['biodist']['js_divergence']:.4f}")
        logger.info(", ".join(log_parts))
    # Cross-fold summary statistics (mean ± std of per-fold metrics).
    summary_stats = {}
    for task in ["size", "delivery"]:
        rmses = [r[task]["rmse"] for r in fold_results if task in r and isinstance(r[task], dict)]
        r2s = [r[task]["r2"] for r in fold_results if task in r and isinstance(r[task], dict)]
        if rmses:
            summary_stats[task] = {
                "rmse_mean": float(np.mean(rmses)),
                "rmse_std": float(np.std(rmses)),
                "r2_mean": float(np.mean(r2s)),
                "r2_std": float(np.std(r2s)),
            }
    for task in ["pdi", "ee", "toxic"]:
        accs = [r[task]["accuracy"] for r in fold_results if task in r and isinstance(r[task], dict)]
        f1s = [r[task]["f1"] for r in fold_results if task in r and isinstance(r[task], dict)]
        if accs:
            summary_stats[task] = {
                "accuracy_mean": float(np.mean(accs)),
                "accuracy_std": float(np.std(accs)),
                "f1_mean": float(np.mean(f1s)),
                "f1_std": float(np.std(f1s)),
            }
    # Distribution-task summary.
    kls = [r["biodist"]["kl_divergence"] for r in fold_results if "biodist" in r and isinstance(r["biodist"], dict)]
    jss = [r["biodist"]["js_divergence"] for r in fold_results if "biodist" in r and isinstance(r["biodist"], dict)]
    if kls:
        summary_stats["biodist"] = {
            "kl_mean": float(np.mean(kls)),
            "kl_std": float(np.std(kls)),
            "js_mean": float(np.mean(jss)),
            "js_std": float(np.std(jss)),
        }
    # Pooled overall metrics over all samples from all folds.
    overall = {}
    for task in ["size", "delivery"]:
        if all_preds[task]:
            p = np.array(all_preds[task])
            t = np.array(all_targets[task])
            overall[task] = {
                "n_samples": len(p),
                "mse": float(mean_squared_error(t, p)),
                "rmse": float(np.sqrt(mean_squared_error(t, p))),
                "mae": float(mean_absolute_error(t, p)),
                "r2": float(r2_score(t, p)),
            }
    for task in ["pdi", "ee", "toxic"]:
        if all_preds[task]:
            p = np.array(all_preds[task])
            t = np.array(all_targets[task])
            overall[task] = {
                "n_samples": len(p),
                "accuracy": float(accuracy_score(t, p)),
                "precision": float(precision_score(t, p, average="macro", zero_division=0)),
                "recall": float(recall_score(t, p, average="macro", zero_division=0)),
                "f1": float(f1_score(t, p, average="macro", zero_division=0)),
            }
    # Distribution task (pooled).
    if all_preds["biodist"]:
        p = np.array(all_preds["biodist"])
        t = np.array(all_targets["biodist"])
        overall["biodist"] = {
            "n_samples": len(p),
            "kl_divergence": kl_divergence(t, p),
            "js_divergence": js_divergence(t, p),
        }
    # Log summary and overall results.
    logger.info("\n" + "=" * 60)
    logger.info("CV TEST EVALUATION RESULTS")
    logger.info("=" * 60)
    logger.info(f"\n[Summary Statistics (across {len(fold_results)} folds)]")
    for task, stats in summary_stats.items():
        if "rmse_mean" in stats:
            logger.info(f" {task}: RMSE={stats['rmse_mean']:.4f}±{stats['rmse_std']:.4f}, R²={stats['r2_mean']:.4f}±{stats['r2_std']:.4f}")
        elif "accuracy_mean" in stats:
            logger.info(f" {task}: Accuracy={stats['accuracy_mean']:.4f}±{stats['accuracy_std']:.4f}, F1={stats['f1_mean']:.4f}±{stats['f1_std']:.4f}")
        elif "kl_mean" in stats:
            logger.info(f" {task}: KL={stats['kl_mean']:.4f}±{stats['kl_std']:.4f}, JS={stats['js_mean']:.4f}±{stats['js_std']:.4f}")
    logger.info(f"\n[Overall (all samples pooled)]")
    for task, metrics in overall.items():
        if "rmse" in metrics:
            logger.info(f" {task} (n={metrics['n_samples']}): RMSE={metrics['rmse']:.4f}, MAE={metrics['mae']:.4f}, R²={metrics['r2']:.4f}")
        elif "accuracy" in metrics:
            logger.info(f" {task} (n={metrics['n_samples']}): Accuracy={metrics['accuracy']:.4f}, Precision={metrics['precision']:.4f}, Recall={metrics['recall']:.4f}, F1={metrics['f1']:.4f}")
        elif "kl_divergence" in metrics:
            logger.info(f" {task} (n={metrics['n_samples']}): KL={metrics['kl_divergence']:.4f}, JS={metrics['js_divergence']:.4f}")
    # Persist results.
    results = {
        "fold_results": fold_results,
        "summary_stats": summary_stats,
        "overall": overall,
    }
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(results, f, indent=2)
    logger.success(f"\nSaved test results to {output_path}")
if __name__ == "__main__":
    # CLI entry point: dispatch to the `main` / `test` typer commands.
    app()

View File

@ -0,0 +1,16 @@
{
"d_model": 256,
"num_heads": 8,
"n_attn_layers": 4,
"fusion_strategy": "attention",
"head_hidden_dim": 128,
"dropout": 0.1,
"use_mpnn": true,
"lr": 0.0001,
"weight_decay": 1e-05,
"batch_size": 32,
"epochs": 100,
"patience": 15,
"init_from_pretrain": null,
"freeze_backbone": false
}

View File

@ -0,0 +1,54 @@
{
"fold_results": [
{
"fold_idx": 0,
"best_val_loss": 5.7676777839660645,
"epochs_trained": 24,
"final_train_loss": 1.4942118644714355
},
{
"fold_idx": 1,
"best_val_loss": 8.418675899505615,
"epochs_trained": 20,
"final_train_loss": 1.4902493238449097
},
{
"fold_idx": 2,
"best_val_loss": 3.5122547830854143,
"epochs_trained": 25,
"final_train_loss": 1.7609570423762004
},
{
"fold_idx": 3,
"best_val_loss": 3.165306806564331,
"epochs_trained": 21,
"final_train_loss": 2.0073827385902403
},
{
"fold_idx": 4,
"best_val_loss": 2.996154228846232,
"epochs_trained": 18,
"final_train_loss": 1.9732873006300493
}
],
"summary": {
"val_loss_mean": 4.772013900393532,
"val_loss_std": 2.0790222989111475
},
"config": {
"d_model": 256,
"num_heads": 8,
"n_attn_layers": 4,
"fusion_strategy": "attention",
"head_hidden_dim": 128,
"dropout": 0.1,
"use_mpnn": true,
"lr": 0.0001,
"weight_decay": 1e-05,
"batch_size": 32,
"epochs": 100,
"patience": 15,
"init_from_pretrain": null,
"freeze_backbone": false
}
}

View File

@ -0,0 +1,510 @@
{
"train": [
{
"loss": 17.65872812271118,
"loss_size": 12.601411867141724,
"loss_pdi": 1.3666706204414367,
"loss_ee": 1.0830313920974732,
"loss_delivery": 0.5962779104709626,
"loss_biodist": 1.3918164849281311,
"loss_toxic": 0.6195200622081757
},
{
"loss": 5.925264883041382,
"loss_size": 1.8580878481268883,
"loss_pdi": 1.1011681258678436,
"loss_ee": 0.971046245098114,
"loss_delivery": 0.5075224950909615,
"loss_biodist": 1.1051940202713013,
"loss_toxic": 0.38224617540836336
},
{
"loss": 3.4781792640686033,
"loss_size": 0.23610344529151917,
"loss_pdi": 0.8137399554252625,
"loss_ee": 0.9135127127170563,
"loss_delivery": 0.4596045270562172,
"loss_biodist": 0.8695587992668152,
"loss_toxic": 0.18565986081957817
},
{
"loss": 2.9488561868667604,
"loss_size": 0.23130029290914536,
"loss_pdi": 0.644479614496231,
"loss_ee": 0.8721524059772492,
"loss_delivery": 0.4146773874759674,
"loss_biodist": 0.646893310546875,
"loss_toxic": 0.13935319259762763
},
{
"loss": 2.6432241678237913,
"loss_size": 0.16843259893357754,
"loss_pdi": 0.5857123643159866,
"loss_ee": 0.8315786123275757,
"loss_delivery": 0.4049036353826523,
"loss_biodist": 0.5410242855548859,
"loss_toxic": 0.11157271154224872
},
{
"loss": 2.461507487297058,
"loss_size": 0.18602822050452233,
"loss_pdi": 0.5872043997049332,
"loss_ee": 0.8179578661918641,
"loss_delivery": 0.32779163047671317,
"loss_biodist": 0.45097417533397677,
"loss_toxic": 0.09155115596950054
},
{
"loss": 2.3792370796203612,
"loss_size": 0.2090120367705822,
"loss_pdi": 0.5358257800340652,
"loss_ee": 0.8088949501514435,
"loss_delivery": 0.3434994474053383,
"loss_biodist": 0.40993946194648745,
"loss_toxic": 0.07206540685147048
},
{
"loss": 2.207099366188049,
"loss_size": 0.1589151345193386,
"loss_pdi": 0.5283154606819153,
"loss_ee": 0.7723551869392395,
"loss_delivery": 0.35645291954278946,
"loss_biodist": 0.3404483631253242,
"loss_toxic": 0.05061229532584548
},
{
"loss": 2.1428971529006957,
"loss_size": 0.19335013553500174,
"loss_pdi": 0.5021985083818435,
"loss_ee": 0.7642539083957672,
"loss_delivery": 0.31821031123399734,
"loss_biodist": 0.32588216066360476,
"loss_toxic": 0.03900211993604898
},
{
"loss": 1.9874909400939942,
"loss_size": 0.1736245721578598,
"loss_pdi": 0.46206980347633364,
"loss_ee": 0.7373365700244904,
"loss_delivery": 0.29703493416309357,
"loss_biodist": 0.2863417714834213,
"loss_toxic": 0.031083252932876348
},
{
"loss": 1.9297520160675048,
"loss_size": 0.1635374441742897,
"loss_pdi": 0.4737923800945282,
"loss_ee": 0.7171129584312439,
"loss_delivery": 0.28808903992176055,
"loss_biodist": 0.25874830335378646,
"loss_toxic": 0.028471904620528222
},
{
"loss": 1.8647576332092286,
"loss_size": 0.14790172204375268,
"loss_pdi": 0.4427785277366638,
"loss_ee": 0.7089932143688202,
"loss_delivery": 0.30143058970570563,
"loss_biodist": 0.24234647750854493,
"loss_toxic": 0.021307120053097605
},
{
"loss": 1.7996623039245605,
"loss_size": 0.1429538145661354,
"loss_pdi": 0.45114057660102846,
"loss_ee": 0.681770408153534,
"loss_delivery": 0.2735618159174919,
"loss_biodist": 0.2338838443160057,
"loss_toxic": 0.01635184111073613
},
{
"loss": 1.7303769707679748,
"loss_size": 0.13725369721651076,
"loss_pdi": 0.43492600619792937,
"loss_ee": 0.6648448914289474,
"loss_delivery": 0.2714417055249214,
"loss_biodist": 0.20898159295320512,
"loss_toxic": 0.012929048202931882
},
{
"loss": 1.702065145969391,
"loss_size": 0.1783118523657322,
"loss_pdi": 0.4118753671646118,
"loss_ee": 0.640222480893135,
"loss_delivery": 0.2610591858625412,
"loss_biodist": 0.20058825612068176,
"loss_toxic": 0.01000797227025032
},
{
"loss": 1.6243244886398316,
"loss_size": 0.1371393844485283,
"loss_pdi": 0.3978125751018524,
"loss_ee": 0.6315451622009277,
"loss_delivery": 0.2618463449180126,
"loss_biodist": 0.18574777096509934,
"loss_toxic": 0.010233237966895103
},
{
"loss": 1.645119547843933,
"loss_size": 0.13622624576091766,
"loss_pdi": 0.4013118803501129,
"loss_ee": 0.639850401878357,
"loss_delivery": 0.2615354858338833,
"loss_biodist": 0.19717498123645782,
"loss_toxic": 0.009020529384724797
},
{
"loss": 1.5792422771453858,
"loss_size": 0.12063037976622581,
"loss_pdi": 0.40477685928344725,
"loss_ee": 0.6168571084737777,
"loss_delivery": 0.23877703920006751,
"loss_biodist": 0.1887524366378784,
"loss_toxic": 0.009448455832898616
},
{
"loss": 1.5701380014419555,
"loss_size": 0.12370488420128822,
"loss_pdi": 0.3944096490740776,
"loss_ee": 0.6204680263996124,
"loss_delivery": 0.2499392546713352,
"loss_biodist": 0.1741167649626732,
"loss_toxic": 0.00749938020016998
},
{
"loss": 1.5445807576179504,
"loss_size": 0.12085893377661705,
"loss_pdi": 0.4022176057100296,
"loss_ee": 0.6029386401176453,
"loss_delivery": 0.2460342638194561,
"loss_biodist": 0.16601160615682603,
"loss_toxic": 0.006519717467017472
},
{
"loss": 1.4764926195144654,
"loss_size": 0.11393929794430732,
"loss_pdi": 0.3614879995584488,
"loss_ee": 0.5874974340200424,
"loss_delivery": 0.2382828861474991,
"loss_biodist": 0.168075630068779,
"loss_toxic": 0.00720936032012105
},
{
"loss": 1.4663256525993347,
"loss_size": 0.10480817258358002,
"loss_pdi": 0.3699364930391312,
"loss_ee": 0.591068571805954,
"loss_delivery": 0.23481545299291612,
"loss_biodist": 0.1582734301686287,
"loss_toxic": 0.007423530006781221
},
{
"loss": 1.4797919273376465,
"loss_size": 0.11906521767377853,
"loss_pdi": 0.3831163257360458,
"loss_ee": 0.5810098886489868,
"loss_delivery": 0.22465722858905793,
"loss_biodist": 0.16469249799847602,
"loss_toxic": 0.007250743336044252
},
{
"loss": 1.4942118644714355,
"loss_size": 0.11249525547027588,
"loss_pdi": 0.3718418627977371,
"loss_ee": 0.5973137259483338,
"loss_delivery": 0.23963096588850022,
"loss_biodist": 0.16598810032010078,
"loss_toxic": 0.006941930414177478
}
],
"val": [
{
"loss": 13.683866500854492,
"loss_size": 5.657964706420898,
"loss_pdi": 1.1590962409973145,
"loss_ee": 1.0155898332595825,
"loss_delivery": 4.1429033279418945,
"loss_biodist": 1.128843069076538,
"loss_toxic": 0.579468846321106,
"acc_pdi": 0.7407407407407407,
"acc_ee": 0.6296296296296297,
"acc_toxic": 1.0
},
{
"loss": 7.161172866821289,
"loss_size": 0.1799931526184082,
"loss_pdi": 0.8303115963935852,
"loss_ee": 0.942605197429657,
"loss_delivery": 3.986294984817505,
"loss_biodist": 1.022797703742981,
"loss_toxic": 0.19917015731334686,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.554836273193359,
"loss_size": 0.04609166830778122,
"loss_pdi": 0.4924769997596741,
"loss_ee": 0.965587317943573,
"loss_delivery": 3.978637933731079,
"loss_biodist": 1.0135102272033691,
"loss_toxic": 0.05853228643536568,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.843129634857178,
"loss_size": 0.07650057226419449,
"loss_pdi": 0.43551138043403625,
"loss_ee": 0.9353340864181519,
"loss_delivery": 4.557775974273682,
"loss_biodist": 0.7909315228462219,
"loss_toxic": 0.047075945883989334,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.711758613586426,
"loss_size": 0.04316325858235359,
"loss_pdi": 0.41873815655708313,
"loss_ee": 1.0096691846847534,
"loss_delivery": 4.517927169799805,
"loss_biodist": 0.6788683533668518,
"loss_toxic": 0.04339226707816124,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.905030250549316,
"loss_size": 0.045318666845560074,
"loss_pdi": 0.38593801856040955,
"loss_ee": 1.0019593238830566,
"loss_delivery": 4.807835578918457,
"loss_biodist": 0.6247215867042542,
"loss_toxic": 0.039257097989320755,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5185185185185185,
"acc_toxic": 1.0
},
{
"loss": 6.417820930480957,
"loss_size": 0.05034356936812401,
"loss_pdi": 0.4149726331233978,
"loss_ee": 0.9869357943534851,
"loss_delivery": 4.405001640319824,
"loss_biodist": 0.533240556716919,
"loss_toxic": 0.02732720412313938,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.608631610870361,
"loss_size": 0.05222579464316368,
"loss_pdi": 0.4375711679458618,
"loss_ee": 1.0041171312332153,
"loss_delivery": 4.578192234039307,
"loss_biodist": 0.5125234723091125,
"loss_toxic": 0.02400212176144123,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5185185185185185,
"acc_toxic": 1.0
},
{
"loss": 5.7676777839660645,
"loss_size": 0.09589201211929321,
"loss_pdi": 0.3261733949184418,
"loss_ee": 0.9482788443565369,
"loss_delivery": 3.856112003326416,
"loss_biodist": 0.5298716425895691,
"loss_toxic": 0.011350298300385475,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5555555555555556,
"acc_toxic": 1.0
},
{
"loss": 6.920990943908691,
"loss_size": 0.05388057231903076,
"loss_pdi": 0.39705148339271545,
"loss_ee": 0.990842878818512,
"loss_delivery": 5.025243282318115,
"loss_biodist": 0.4346938133239746,
"loss_toxic": 0.019278930500149727,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5555555555555556,
"acc_toxic": 1.0
},
{
"loss": 5.798760890960693,
"loss_size": 0.09857960045337677,
"loss_pdi": 0.33329641819000244,
"loss_ee": 0.9614524245262146,
"loss_delivery": 4.000489711761475,
"loss_biodist": 0.39874210953712463,
"loss_toxic": 0.006200834643095732,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5555555555555556,
"acc_toxic": 1.0
},
{
"loss": 6.575327396392822,
"loss_size": 0.054699063301086426,
"loss_pdi": 0.33702051639556885,
"loss_ee": 0.9436452388763428,
"loss_delivery": 4.817119121551514,
"loss_biodist": 0.41582298278808594,
"loss_toxic": 0.007020703982561827,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5555555555555556,
"acc_toxic": 1.0
},
{
"loss": 5.989306449890137,
"loss_size": 0.09009546041488647,
"loss_pdi": 0.3044246733188629,
"loss_ee": 1.0130207538604736,
"loss_delivery": 4.140576362609863,
"loss_biodist": 0.4378862977027893,
"loss_toxic": 0.0033026484306901693,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5555555555555556,
"acc_toxic": 1.0
},
{
"loss": 6.383339881896973,
"loss_size": 0.1530081033706665,
"loss_pdi": 0.29700207710266113,
"loss_ee": 0.9943283796310425,
"loss_delivery": 4.564785480499268,
"loss_biodist": 0.37085360288619995,
"loss_toxic": 0.003362649120390415,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5185185185185185,
"acc_toxic": 1.0
},
{
"loss": 6.233416557312012,
"loss_size": 0.1473817676305771,
"loss_pdi": 0.2754640281200409,
"loss_ee": 0.9803684949874878,
"loss_delivery": 4.443488597869873,
"loss_biodist": 0.38424524664878845,
"loss_toxic": 0.0024682653602212667,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.094257354736328,
"loss_size": 0.10127364844083786,
"loss_pdi": 0.2960923910140991,
"loss_ee": 1.0121080875396729,
"loss_delivery": 4.2689008712768555,
"loss_biodist": 0.4132467210292816,
"loss_toxic": 0.0026356647722423077,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5555555555555556,
"acc_toxic": 1.0
},
{
"loss": 6.0315470695495605,
"loss_size": 0.13236114382743835,
"loss_pdi": 0.29554903507232666,
"loss_ee": 0.9912998080253601,
"loss_delivery": 4.2240777015686035,
"loss_biodist": 0.3861069977283478,
"loss_toxic": 0.002152827335521579,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.13291597366333,
"loss_size": 0.10603927820920944,
"loss_pdi": 0.30880627036094666,
"loss_ee": 1.0417256355285645,
"loss_delivery": 4.337818622589111,
"loss_biodist": 0.33598482608795166,
"loss_toxic": 0.002541647758334875,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5555555555555556,
"acc_toxic": 1.0
},
{
"loss": 5.918347358703613,
"loss_size": 0.11423231661319733,
"loss_pdi": 0.2779754102230072,
"loss_ee": 1.023812174797058,
"loss_delivery": 4.137387275695801,
"loss_biodist": 0.36283549666404724,
"loss_toxic": 0.002104171784594655,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5185185185185185,
"acc_toxic": 1.0
},
{
"loss": 6.354115962982178,
"loss_size": 0.1212025135755539,
"loss_pdi": 0.2848753333091736,
"loss_ee": 1.031553030014038,
"loss_delivery": 4.554471969604492,
"loss_biodist": 0.3598195016384125,
"loss_toxic": 0.0021939175203442574,
"acc_pdi": 0.8518518518518519,
"acc_ee": 0.5555555555555556,
"acc_toxic": 1.0
},
{
"loss": 5.881389141082764,
"loss_size": 0.1030399352312088,
"loss_pdi": 0.2791188657283783,
"loss_ee": 1.0205037593841553,
"loss_delivery": 4.111578464508057,
"loss_biodist": 0.3657107949256897,
"loss_toxic": 0.0014369667042046785,
"acc_pdi": 0.8518518518518519,
"acc_ee": 0.5555555555555556,
"acc_toxic": 1.0
},
{
"loss": 6.1028852462768555,
"loss_size": 0.10241233557462692,
"loss_pdi": 0.300007700920105,
"loss_ee": 1.0756882429122925,
"loss_delivery": 4.258440971374512,
"loss_biodist": 0.3646480441093445,
"loss_toxic": 0.0016881312476471066,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5185185185185185,
"acc_toxic": 1.0
},
{
"loss": 6.128824234008789,
"loss_size": 0.1437627077102661,
"loss_pdi": 0.29325851798057556,
"loss_ee": 1.0818182229995728,
"loss_delivery": 4.236568450927734,
"loss_biodist": 0.3719424605369568,
"loss_toxic": 0.001474093529395759,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5185185185185185,
"acc_toxic": 1.0
},
{
"loss": 6.055476188659668,
"loss_size": 0.13312266767024994,
"loss_pdi": 0.28571468591690063,
"loss_ee": 1.066524624824524,
"loss_delivery": 4.214193820953369,
"loss_biodist": 0.35442692041397095,
"loss_toxic": 0.0014939504908397794,
"acc_pdi": 0.8518518518518519,
"acc_ee": 0.5185185185185185,
"acc_toxic": 1.0
}
]
}

Binary file not shown.

View File

@ -0,0 +1,426 @@
{
"train": [
{
"loss": 21.963233947753906,
"loss_size": 16.82633171081543,
"loss_pdi": 1.2230936765670777,
"loss_ee": 1.0703922033309936,
"loss_delivery": 1.0690569162368775,
"loss_biodist": 1.1534382343292235,
"loss_toxic": 0.6209211587905884
},
{
"loss": 13.145495796203614,
"loss_size": 8.676862716674805,
"loss_pdi": 1.0655134558677672,
"loss_ee": 0.8999906063079834,
"loss_delivery": 0.8303895950317383,
"loss_biodist": 1.122723388671875,
"loss_toxic": 0.5500160694122315
},
{
"loss": 7.351448345184326,
"loss_size": 3.415665292739868,
"loss_pdi": 0.8565655469894409,
"loss_ee": 0.7837236523628235,
"loss_delivery": 0.8804788589477539,
"loss_biodist": 1.011645209789276,
"loss_toxic": 0.40336963534355164
},
{
"loss": 4.39948205947876,
"loss_size": 0.9713698267936707,
"loss_pdi": 0.6989291191101075,
"loss_ee": 0.6805540442466735,
"loss_delivery": 0.7624839186668396,
"loss_biodist": 0.9798830866813659,
"loss_toxic": 0.3062621414661407
},
{
"loss": 3.375754451751709,
"loss_size": 0.24608666747808455,
"loss_pdi": 0.5557448148727417,
"loss_ee": 0.6684133768081665,
"loss_delivery": 0.7611681580543518,
"loss_biodist": 0.919653308391571,
"loss_toxic": 0.22468801140785216
},
{
"loss": 2.9307605743408205,
"loss_size": 0.1106911577284336,
"loss_pdi": 0.5004462003707886,
"loss_ee": 0.6227471172809601,
"loss_delivery": 0.6758030593395233,
"loss_biodist": 0.8190896153450012,
"loss_toxic": 0.20198351740837098
},
{
"loss": 2.731675052642822,
"loss_size": 0.13740637749433518,
"loss_pdi": 0.4836215674877167,
"loss_ee": 0.5896897256374359,
"loss_delivery": 0.5866121172904968,
"loss_biodist": 0.7556124567985535,
"loss_toxic": 0.17873288169503213
},
{
"loss": 2.4887039184570314,
"loss_size": 0.12009606957435608,
"loss_pdi": 0.4361336886882782,
"loss_ee": 0.597134268283844,
"loss_delivery": 0.5648026138544082,
"loss_biodist": 0.6326960444450378,
"loss_toxic": 0.13784122765064238
},
{
"loss": 2.1680586099624635,
"loss_size": 0.12401954531669616,
"loss_pdi": 0.40216060280799865,
"loss_ee": 0.5528951227664948,
"loss_delivery": 0.42899617552757263,
"loss_biodist": 0.5442585527896882,
"loss_toxic": 0.1157285787165165
},
{
"loss": 2.1059993267059327,
"loss_size": 0.13299092650413513,
"loss_pdi": 0.38143277168273926,
"loss_ee": 0.5274551689624787,
"loss_delivery": 0.47739412933588027,
"loss_biodist": 0.4953398108482361,
"loss_toxic": 0.0913865402340889
},
{
"loss": 1.9570286750793457,
"loss_size": 0.1426382303237915,
"loss_pdi": 0.38325140476226804,
"loss_ee": 0.49524895548820497,
"loss_delivery": 0.42715947031974794,
"loss_biodist": 0.4287752747535706,
"loss_toxic": 0.07995530962944031
},
{
"loss": 1.8469573497772216,
"loss_size": 0.14165955781936646,
"loss_pdi": 0.36685559153556824,
"loss_ee": 0.4988661766052246,
"loss_delivery": 0.36661114990711213,
"loss_biodist": 0.39747334718704225,
"loss_toxic": 0.07549156174063683
},
{
"loss": 1.6980855226516725,
"loss_size": 0.11332993358373641,
"loss_pdi": 0.350938493013382,
"loss_ee": 0.47553136944770813,
"loss_delivery": 0.30049399137496946,
"loss_biodist": 0.3953311860561371,
"loss_toxic": 0.062460555136203764
},
{
"loss": 1.743706512451172,
"loss_size": 0.12467859983444214,
"loss_pdi": 0.3706244468688965,
"loss_ee": 0.4802402436733246,
"loss_delivery": 0.36484516113996507,
"loss_biodist": 0.3557030588388443,
"loss_toxic": 0.04761496149003506
},
{
"loss": 1.7470735549926757,
"loss_size": 0.10215002745389938,
"loss_pdi": 0.3553147315979004,
"loss_ee": 0.4548905730247498,
"loss_delivery": 0.4480485826730728,
"loss_biodist": 0.3265932142734528,
"loss_toxic": 0.06007647253572941
},
{
"loss": 1.7687433004379272,
"loss_size": 0.10528398901224137,
"loss_pdi": 0.35497177839279176,
"loss_ee": 0.4946293234825134,
"loss_delivery": 0.44853600263595583,
"loss_biodist": 0.3113987982273102,
"loss_toxic": 0.053923492506146434
},
{
"loss": 1.573294997215271,
"loss_size": 0.11145550012588501,
"loss_pdi": 0.33941014409065245,
"loss_ee": 0.42823529839515684,
"loss_delivery": 0.34292849004268644,
"loss_biodist": 0.3095307767391205,
"loss_toxic": 0.04173475466668606
},
{
"loss": 1.482050108909607,
"loss_size": 0.13211917281150817,
"loss_pdi": 0.31831381320953367,
"loss_ee": 0.4258797198534012,
"loss_delivery": 0.26612227857112886,
"loss_biodist": 0.30344046354293824,
"loss_toxic": 0.03617466017603874
},
{
"loss": 1.5079625368118286,
"loss_size": 0.1129397764801979,
"loss_pdi": 0.3118207275867462,
"loss_ee": 0.4255594819784164,
"loss_delivery": 0.30544502288103104,
"loss_biodist": 0.31328115463256834,
"loss_toxic": 0.03891638442873955
},
{
"loss": 1.4902493238449097,
"loss_size": 0.09879767149686813,
"loss_pdi": 0.333440762758255,
"loss_ee": 0.430321592092514,
"loss_delivery": 0.3070627197623253,
"loss_biodist": 0.28984564244747163,
"loss_toxic": 0.030780918896198273
}
],
"val": [
{
"loss": 24.328961690266926,
"loss_size": 15.672358830769857,
"loss_pdi": 1.268057902654012,
"loss_ee": 1.0569811463356018,
"loss_delivery": 4.617272272706032,
"loss_biodist": 1.0806464751561482,
"loss_toxic": 0.633646289507548,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.5894736842105263,
"acc_toxic": 0.8939393939393939
},
{
"loss": 17.03301429748535,
"loss_size": 8.649629751841227,
"loss_pdi": 1.165820797284444,
"loss_ee": 0.9437925020853678,
"loss_delivery": 4.629274984200795,
"loss_biodist": 1.0683060089747112,
"loss_toxic": 0.5761909882227579,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 11.635572751363119,
"loss_size": 3.504341204961141,
"loss_pdi": 1.0855141083399455,
"loss_ee": 0.8674407601356506,
"loss_delivery": 4.705501407384872,
"loss_biodist": 1.0376905004183452,
"loss_toxic": 0.43508487939834595,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 9.058362166086832,
"loss_size": 1.0461570421854656,
"loss_pdi": 1.070031762123108,
"loss_ee": 0.8463932275772095,
"loss_delivery": 4.791346887747447,
"loss_biodist": 0.9781110286712646,
"loss_toxic": 0.32632239659627277,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.418675899505615,
"loss_size": 0.3764288102587064,
"loss_pdi": 1.0916812817255657,
"loss_ee": 0.8714254101117452,
"loss_delivery": 4.8696667949358625,
"loss_biodist": 0.9307892719904581,
"loss_toxic": 0.278684730331103,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.51748021443685,
"loss_size": 0.33909208327531815,
"loss_pdi": 1.103804111480713,
"loss_ee": 0.8707688599824905,
"loss_delivery": 5.0624091029167175,
"loss_biodist": 0.8743396997451782,
"loss_toxic": 0.2670666699608167,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.701509237289429,
"loss_size": 0.38883806640903157,
"loss_pdi": 1.0901564558347066,
"loss_ee": 0.8219001442193985,
"loss_delivery": 5.329233412941297,
"loss_biodist": 0.808117667833964,
"loss_toxic": 0.2632630293567975,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.602253516515097,
"loss_size": 0.399209912866354,
"loss_pdi": 1.035650501648585,
"loss_ee": 0.8119546920061111,
"loss_delivery": 5.297288862367471,
"loss_biodist": 0.8003136416276296,
"loss_toxic": 0.2578362462421258,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.610430796941122,
"loss_size": 0.3884888291358948,
"loss_pdi": 0.9680223266283671,
"loss_ee": 0.8063104202349981,
"loss_delivery": 5.504999443888664,
"loss_biodist": 0.7153328458468119,
"loss_toxic": 0.22727691816786924,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.894750118255615,
"loss_size": 0.38015256201227504,
"loss_pdi": 0.9849910040696462,
"loss_ee": 0.8192636320988337,
"loss_delivery": 5.8433875640233355,
"loss_biodist": 0.6525928874810537,
"loss_toxic": 0.21436312049627304,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.684672435124716,
"loss_size": 0.39142270882924396,
"loss_pdi": 0.9926454623540243,
"loss_ee": 0.8487897912661234,
"loss_delivery": 5.675399616360664,
"loss_biodist": 0.5763055086135864,
"loss_toxic": 0.2001086367915074,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.468807538350424,
"loss_size": 0.37035099665323895,
"loss_pdi": 0.9933059811592102,
"loss_ee": 0.8365495651960373,
"loss_delivery": 5.39086152613163,
"loss_biodist": 0.6555034021536509,
"loss_toxic": 0.22223659542699656,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.48760732014974,
"loss_size": 0.3547621878484885,
"loss_pdi": 1.008083571990331,
"loss_ee": 0.8507340376575788,
"loss_delivery": 5.329072058200836,
"loss_biodist": 0.7051869928836823,
"loss_toxic": 0.23976873668531576,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.534255782763163,
"loss_size": 0.35214799270033836,
"loss_pdi": 1.0083338419596355,
"loss_ee": 0.8703259030977885,
"loss_delivery": 5.4809657235940294,
"loss_biodist": 0.6066243648529053,
"loss_toxic": 0.21585797673712173,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6842105263157895,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.59092911084493,
"loss_size": 0.3520332872867584,
"loss_pdi": 0.9944024880727133,
"loss_ee": 0.8839219162861506,
"loss_delivery": 5.593439628680547,
"loss_biodist": 0.562449519832929,
"loss_toxic": 0.20468231476843357,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.581690510114035,
"loss_size": 0.3468632685641448,
"loss_pdi": 1.0153752664724986,
"loss_ee": 0.884696863591671,
"loss_delivery": 5.548932209610939,
"loss_biodist": 0.576594889163971,
"loss_toxic": 0.20922777770708004,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.60028068224589,
"loss_size": 0.34553587809205055,
"loss_pdi": 1.0314316948254902,
"loss_ee": 0.8696443388859431,
"loss_delivery": 5.513105024894078,
"loss_biodist": 0.6187789390484492,
"loss_toxic": 0.2217849005634586,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6947368421052632,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.721842130025228,
"loss_size": 0.3432792164385319,
"loss_pdi": 1.044082870086034,
"loss_ee": 0.8888355021675428,
"loss_delivery": 5.590167284011841,
"loss_biodist": 0.6300752113262812,
"loss_toxic": 0.22540184513976178,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.821967244148254,
"loss_size": 0.3423520748813947,
"loss_pdi": 1.0627215206623077,
"loss_ee": 0.9012102037668228,
"loss_delivery": 5.6443866689999895,
"loss_biodist": 0.6428664823373159,
"loss_toxic": 0.2284308553983768,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6631578947368421,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.798149506251017,
"loss_size": 0.3439513569076856,
"loss_pdi": 1.074403668443362,
"loss_ee": 0.9037297517061234,
"loss_delivery": 5.62445667386055,
"loss_biodist": 0.6279164751370748,
"loss_toxic": 0.22369086369872093,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6631578947368421,
"acc_toxic": 0.8939393939393939
}
]
}

Binary file not shown.

View File

@ -0,0 +1,531 @@
{
"train": [
{
"loss": 22.972569465637207,
"loss_size": 17.22947899500529,
"loss_pdi": 1.3510672052701314,
"loss_ee": 1.0506827433904011,
"loss_delivery": 1.5721548050642014,
"loss_biodist": 1.1304322481155396,
"loss_toxic": 0.6387530366579691
},
{
"loss": 12.718681335449219,
"loss_size": 7.335077285766602,
"loss_pdi": 1.198062241077423,
"loss_ee": 0.97108127673467,
"loss_delivery": 1.639556477467219,
"loss_biodist": 1.0746847093105316,
"loss_toxic": 0.50021959344546
},
{
"loss": 6.867454210917155,
"loss_size": 2.3153140544891357,
"loss_pdi": 1.0159071584542592,
"loss_ee": 0.853508859872818,
"loss_delivery": 1.2862873176733653,
"loss_biodist": 1.0416639745235443,
"loss_toxic": 0.35477257271607715
},
{
"loss": 4.856432318687439,
"loss_size": 0.5409951706727346,
"loss_pdi": 0.8652523259321848,
"loss_ee": 0.7771940131982168,
"loss_delivery": 1.413562481602033,
"loss_biodist": 0.9977987806002299,
"loss_toxic": 0.26162934054931003
},
{
"loss": 4.253215591112773,
"loss_size": 0.2641367167234421,
"loss_pdi": 0.739859402179718,
"loss_ee": 0.7256686190764109,
"loss_delivery": 1.4241955528656642,
"loss_biodist": 0.8935903211434683,
"loss_toxic": 0.2057649294535319
},
{
"loss": 3.8961705764134726,
"loss_size": 0.2962125514944394,
"loss_pdi": 0.682400623957316,
"loss_ee": 0.6820215880870819,
"loss_delivery": 1.2787245536843936,
"loss_biodist": 0.7812575101852417,
"loss_toxic": 0.1755537080268065
},
{
"loss": 3.4790991942087808,
"loss_size": 0.3047281603018443,
"loss_pdi": 0.6409291823705038,
"loss_ee": 0.6178905169169108,
"loss_delivery": 1.1121559316913288,
"loss_biodist": 0.6434484819571177,
"loss_toxic": 0.15994682783881822
},
{
"loss": 3.2075613339742026,
"loss_size": 0.3421506683031718,
"loss_pdi": 0.5879766543706259,
"loss_ee": 0.5811398377021154,
"loss_delivery": 1.0462109719713528,
"loss_biodist": 0.520307645201683,
"loss_toxic": 0.12977550799647966
},
{
"loss": 2.861353278160095,
"loss_size": 0.2742840300003688,
"loss_pdi": 0.5437282969554266,
"loss_ee": 0.5531725088755289,
"loss_delivery": 0.9213679246604443,
"loss_biodist": 0.4499489863713582,
"loss_toxic": 0.11885150956610839
},
{
"loss": 2.6909215847651162,
"loss_size": 0.23881135260065398,
"loss_pdi": 0.5229279547929764,
"loss_ee": 0.5285524874925613,
"loss_delivery": 0.8911051253477732,
"loss_biodist": 0.4015616128842036,
"loss_toxic": 0.10796305599311988
},
{
"loss": 2.5927247206370034,
"loss_size": 0.27356760079662007,
"loss_pdi": 0.5166990955670675,
"loss_ee": 0.5059170673290888,
"loss_delivery": 0.8377179056406021,
"loss_biodist": 0.3519642899433772,
"loss_toxic": 0.10685871541500092
},
{
"loss": 2.3971973856290183,
"loss_size": 0.2688147674004237,
"loss_pdi": 0.4851151605447133,
"loss_ee": 0.47870688637097675,
"loss_delivery": 0.7584750155607859,
"loss_biodist": 0.3166690344611804,
"loss_toxic": 0.08941652067005634
},
{
"loss": 2.2271180947621665,
"loss_size": 0.2559296215573947,
"loss_pdi": 0.467803418636322,
"loss_ee": 0.4819647620121638,
"loss_delivery": 0.6487737223505974,
"loss_biodist": 0.2930952211221059,
"loss_toxic": 0.079551310899357
},
{
"loss": 2.1467134952545166,
"loss_size": 0.2658323546250661,
"loss_pdi": 0.47287177046140033,
"loss_ee": 0.4580538024504979,
"loss_delivery": 0.6110207016269366,
"loss_biodist": 0.26590356479088467,
"loss_toxic": 0.07303123424450557
},
{
"loss": 2.0699684421221414,
"loss_size": 0.23655260602633157,
"loss_pdi": 0.46446068088213605,
"loss_ee": 0.43884341915448505,
"loss_delivery": 0.5945644030968348,
"loss_biodist": 0.26856863250335056,
"loss_toxic": 0.06697871504972379
},
{
"loss": 2.012367367744446,
"loss_size": 0.20358355715870857,
"loss_pdi": 0.44864421089490253,
"loss_ee": 0.4260970900456111,
"loss_delivery": 0.6111055202782154,
"loss_biodist": 0.24829111248254776,
"loss_toxic": 0.07464585608492295
},
{
"loss": 1.9354575673739116,
"loss_size": 0.19155597686767578,
"loss_pdi": 0.43001438677310944,
"loss_ee": 0.4029633104801178,
"loss_delivery": 0.5866967861851057,
"loss_biodist": 0.26284457246462506,
"loss_toxic": 0.06138256782044967
},
{
"loss": 1.9248821139335632,
"loss_size": 0.19836385796467462,
"loss_pdi": 0.43165912727514905,
"loss_ee": 0.4223821411530177,
"loss_delivery": 0.5774712382505337,
"loss_biodist": 0.23008103668689728,
"loss_toxic": 0.06492467441906531
},
{
"loss": 1.7986130317052205,
"loss_size": 0.1977602814634641,
"loss_pdi": 0.4213625093301137,
"loss_ee": 0.3969506522019704,
"loss_delivery": 0.4972396679222584,
"loss_biodist": 0.22815552850564322,
"loss_toxic": 0.05714430411656698
},
{
"loss": 1.8008437156677246,
"loss_size": 0.20143492271502814,
"loss_pdi": 0.4257240394751231,
"loss_ee": 0.3939937750498454,
"loss_delivery": 0.4996156108876069,
"loss_biodist": 0.22945881386597952,
"loss_toxic": 0.05061656702309847
},
{
"loss": 1.8123606244723003,
"loss_size": 0.23175274084011713,
"loss_pdi": 0.41065867245197296,
"loss_ee": 0.38645289838314056,
"loss_delivery": 0.5105274474869171,
"loss_biodist": 0.22759289046128592,
"loss_toxic": 0.04537593169758717
},
{
"loss": 1.85766206185023,
"loss_size": 0.2237167110045751,
"loss_pdi": 0.4198872745037079,
"loss_ee": 0.39036936064561206,
"loss_delivery": 0.5356200908621153,
"loss_biodist": 0.2327907457947731,
"loss_toxic": 0.055277835965777435
},
{
"loss": 1.7299150824546814,
"loss_size": 0.16866947089632353,
"loss_pdi": 0.40182044357061386,
"loss_ee": 0.37123599648475647,
"loss_delivery": 0.5183743331581354,
"loss_biodist": 0.22459317495425543,
"loss_toxic": 0.04522162117063999
},
{
"loss": 1.8115381598472595,
"loss_size": 0.21021889025966325,
"loss_pdi": 0.3938516428073247,
"loss_ee": 0.3856282929579417,
"loss_delivery": 0.5463737193495035,
"loss_biodist": 0.22467835744222006,
"loss_toxic": 0.05078731415172418
},
{
"loss": 1.7609570423762004,
"loss_size": 0.20672637100021043,
"loss_pdi": 0.3868243644634883,
"loss_ee": 0.3775654385487239,
"loss_delivery": 0.5160359914104143,
"loss_biodist": 0.22730938345193863,
"loss_toxic": 0.04649555139864484
}
],
"val": [
{
"loss": 19.896042142595565,
"loss_size": 14.947636876787458,
"loss_pdi": 1.3514722074781145,
"loss_ee": 1.0372784308024816,
"loss_delivery": 0.5157596128327506,
"loss_biodist": 1.3665738276072912,
"loss_toxic": 0.6773212381771633,
"acc_pdi": 0.22564102564102564,
"acc_ee": 0.4512820512820513,
"acc_toxic": 0.7073170731707317
},
{
"loss": 10.277108192443848,
"loss_size": 5.728530270712716,
"loss_pdi": 1.2047701733452933,
"loss_ee": 1.013599353177207,
"loss_delivery": 0.5329383058207375,
"loss_biodist": 1.3288453817367554,
"loss_toxic": 0.4684244394302368,
"acc_pdi": 0.358974358974359,
"acc_ee": 0.4461538461538462,
"acc_toxic": 1.0
},
{
"loss": 4.947442190987723,
"loss_size": 0.9055690084184919,
"loss_pdi": 0.9325166344642639,
"loss_ee": 1.071527932371412,
"loss_delivery": 0.5525430504764829,
"loss_biodist": 1.2514750446592058,
"loss_toxic": 0.23381062703473227,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 4.072668756757464,
"loss_size": 0.18621658267719404,
"loss_pdi": 0.7640595691544669,
"loss_ee": 1.2155327456338065,
"loss_delivery": 0.5523517067943301,
"loss_biodist": 1.214196733066014,
"loss_toxic": 0.14031140506267548,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 4.053883859089443,
"loss_size": 0.24764748875583922,
"loss_pdi": 0.6859285916600909,
"loss_ee": 1.3019903557641166,
"loss_delivery": 0.5384655041354043,
"loss_biodist": 1.191110406603132,
"loss_toxic": 0.08874151536396571,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 3.9850601128169467,
"loss_size": 0.18243093735405377,
"loss_pdi": 0.6606386282614299,
"loss_ee": 1.2955879313605172,
"loss_delivery": 0.6222422846726009,
"loss_biodist": 1.1586603011403764,
"loss_toxic": 0.06550001353025436,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 3.6691461631229947,
"loss_size": 0.18682856378810747,
"loss_pdi": 0.6505168399640492,
"loss_ee": 1.2279302733285087,
"loss_delivery": 0.5776888344969068,
"loss_biodist": 0.963392470564161,
"loss_toxic": 0.06278917672378677,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 3.6420267650059293,
"loss_size": 0.1952320017984935,
"loss_pdi": 0.6510439600263324,
"loss_ee": 1.1954282522201538,
"loss_delivery": 0.7644538623946053,
"loss_biodist": 0.7812093198299408,
"loss_toxic": 0.05465935756053243,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.35384615384615387,
"acc_toxic": 1.0
},
{
"loss": 3.707563672746931,
"loss_size": 0.2168926394411496,
"loss_pdi": 0.6468359615121569,
"loss_ee": 1.2361225570951189,
"loss_delivery": 0.8388645563806806,
"loss_biodist": 0.7232611009052822,
"loss_toxic": 0.04558686592749187,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.37435897435897436,
"acc_toxic": 1.0
},
{
"loss": 3.5122547830854143,
"loss_size": 0.2614529473440988,
"loss_pdi": 0.6344352598701205,
"loss_ee": 1.2337199449539185,
"loss_delivery": 0.7557642417294639,
"loss_biodist": 0.5974612619195666,
"loss_toxic": 0.029421218537858555,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.38461538461538464,
"acc_toxic": 1.0
},
{
"loss": 4.085699932915824,
"loss_size": 0.2200212436062949,
"loss_pdi": 0.6225023801837649,
"loss_ee": 1.1934180770601546,
"loss_delivery": 1.4556497505732946,
"loss_biodist": 0.5685178296906608,
"loss_toxic": 0.02559054403432778,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 3.8555148329053606,
"loss_size": 0.25200862703578814,
"loss_pdi": 0.6161004496472222,
"loss_ee": 1.2155306509562902,
"loss_delivery": 1.1912458751882826,
"loss_biodist": 0.5578728743961879,
"loss_toxic": 0.022756420208939483,
"acc_pdi": 0.7128205128205128,
"acc_ee": 0.4153846153846154,
"acc_toxic": 1.0
},
{
"loss": 4.016897848674229,
"loss_size": 0.24823468178510666,
"loss_pdi": 0.6098369508981705,
"loss_ee": 1.2048260143824987,
"loss_delivery": 1.3509162408964974,
"loss_biodist": 0.5822246244975499,
"loss_toxic": 0.020859350051198686,
"acc_pdi": 0.717948717948718,
"acc_ee": 0.4153846153846154,
"acc_toxic": 1.0
},
{
"loss": 3.9899745328085765,
"loss_size": 0.27859273659331457,
"loss_pdi": 0.6051683936800275,
"loss_ee": 1.1875721216201782,
"loss_delivery": 1.369072552238192,
"loss_biodist": 0.5321473862443652,
"loss_toxic": 0.017421354805784568,
"acc_pdi": 0.717948717948718,
"acc_ee": 0.40512820512820513,
"acc_toxic": 1.0
},
{
"loss": 4.112551552908761,
"loss_size": 0.23502862347023828,
"loss_pdi": 0.6127274334430695,
"loss_ee": 1.2102909428732735,
"loss_delivery": 1.5001615626471383,
"loss_biodist": 0.5411617543016162,
"loss_toxic": 0.013181165459432773,
"acc_pdi": 0.7025641025641025,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 4.031796966280256,
"loss_size": 0.27996559121779035,
"loss_pdi": 0.6218498295971325,
"loss_ee": 1.264663577079773,
"loss_delivery": 1.225053642477308,
"loss_biodist": 0.6304309921605247,
"loss_toxic": 0.009833223053387232,
"acc_pdi": 0.7128205128205128,
"acc_ee": 0.4153846153846154,
"acc_toxic": 1.0
},
{
"loss": 4.108100175857544,
"loss_size": 0.2646343271647181,
"loss_pdi": 0.6244613996573857,
"loss_ee": 1.2785721336092268,
"loss_delivery": 1.2817653375012534,
"loss_biodist": 0.6491539776325226,
"loss_toxic": 0.009512946475297213,
"acc_pdi": 0.7025641025641025,
"acc_ee": 0.4,
"acc_toxic": 1.0
},
{
"loss": 4.237571648188999,
"loss_size": 0.24608339369297028,
"loss_pdi": 0.6387959729347911,
"loss_ee": 1.263997495174408,
"loss_delivery": 1.3677963316440582,
"loss_biodist": 0.7113959235804421,
"loss_toxic": 0.00950257752888969,
"acc_pdi": 0.7025641025641025,
"acc_ee": 0.40512820512820513,
"acc_toxic": 1.0
},
{
"loss": 4.677373443331037,
"loss_size": 0.25485377439430784,
"loss_pdi": 0.6646602579525539,
"loss_ee": 1.2957249539239066,
"loss_delivery": 1.609152581010546,
"loss_biodist": 0.8443670613425118,
"loss_toxic": 0.008614770535911833,
"acc_pdi": 0.7025641025641025,
"acc_ee": 0.4,
"acc_toxic": 1.0
},
{
"loss": 4.661478791918073,
"loss_size": 0.22112409876925604,
"loss_pdi": 0.6840873499001775,
"loss_ee": 1.345719371523176,
"loss_delivery": 1.5750049437795366,
"loss_biodist": 0.8278610365731376,
"loss_toxic": 0.0076820029810603175,
"acc_pdi": 0.6974358974358974,
"acc_ee": 0.39487179487179486,
"acc_toxic": 1.0
},
{
"loss": 4.7805344717843195,
"loss_size": 0.23743291412081038,
"loss_pdi": 0.6911627639617238,
"loss_ee": 1.3796877009528024,
"loss_delivery": 1.6191245743206568,
"loss_biodist": 0.8459444258894239,
"loss_toxic": 0.00718201809961881,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4,
"acc_toxic": 1.0
},
{
"loss": 5.050536905016218,
"loss_size": 0.2340430138366563,
"loss_pdi": 0.7009050281984466,
"loss_ee": 1.3757387569972448,
"loss_delivery": 1.760018629687173,
"loss_biodist": 0.9723425933292934,
"loss_toxic": 0.007489000846232686,
"acc_pdi": 0.7128205128205128,
"acc_ee": 0.39487179487179486,
"acc_toxic": 1.0
},
{
"loss": 5.172625916344779,
"loss_size": 0.23549717558281763,
"loss_pdi": 0.6980976568801063,
"loss_ee": 1.357451047216143,
"loss_delivery": 1.893591480595725,
"loss_biodist": 0.9802024279321943,
"loss_toxic": 0.0077861944612647805,
"acc_pdi": 0.717948717948718,
"acc_ee": 0.38974358974358975,
"acc_toxic": 1.0
},
{
"loss": 5.048826490129743,
"loss_size": 0.2420537450483867,
"loss_pdi": 0.7013733184763363,
"loss_ee": 1.353548560823713,
"loss_delivery": 1.7931698901312692,
"loss_biodist": 0.9509889696325574,
"loss_toxic": 0.007692053714501006,
"acc_pdi": 0.7128205128205128,
"acc_ee": 0.38974358974358975,
"acc_toxic": 1.0
},
{
"loss": 4.951304980686733,
"loss_size": 0.24394649054322923,
"loss_pdi": 0.7197230202811105,
"loss_ee": 1.3789095027106149,
"loss_delivery": 1.6561105762209212,
"loss_biodist": 0.9460095167160034,
"loss_toxic": 0.006605847106714334,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.38974358974358975,
"acc_toxic": 1.0
}
]
}

Binary file not shown.

View File

@ -0,0 +1,447 @@
{
"train": [
{
"loss": 19.317236709594727,
"loss_size": 14.108779287338256,
"loss_pdi": 1.2223044037818909,
"loss_ee": 1.1201724767684937,
"loss_delivery": 1.0989456713199615,
"loss_biodist": 1.2243955612182618,
"loss_toxic": 0.5426396489143371
},
{
"loss": 7.058736562728882,
"loss_size": 2.407193088531494,
"loss_pdi": 1.0396445631980895,
"loss_ee": 1.043432891368866,
"loss_delivery": 1.183875671029091,
"loss_biodist": 1.0435532987117768,
"loss_toxic": 0.3410369783639908
},
{
"loss": 4.30960967540741,
"loss_size": 0.3954987198114395,
"loss_pdi": 0.8135693371295929,
"loss_ee": 0.9796469569206238,
"loss_delivery": 1.0894740536808967,
"loss_biodist": 0.8407172799110413,
"loss_toxic": 0.1907033085823059
},
{
"loss": 3.814995551109314,
"loss_size": 0.3033123552799225,
"loss_pdi": 0.7126355469226837,
"loss_ee": 0.9527663111686706,
"loss_delivery": 1.0963637247681617,
"loss_biodist": 0.613499540090561,
"loss_toxic": 0.13641793020069598
},
{
"loss": 3.4925455331802366,
"loss_size": 0.31617238447070123,
"loss_pdi": 0.6569374144077301,
"loss_ee": 0.9212668299674988,
"loss_delivery": 1.0250366538763047,
"loss_biodist": 0.4527473896741867,
"loss_toxic": 0.12038486637175083
},
{
"loss": 3.255272912979126,
"loss_size": 0.27022836059331895,
"loss_pdi": 0.6289663434028625,
"loss_ee": 0.9047561466693879,
"loss_delivery": 0.9742588266730309,
"loss_biodist": 0.3838836058974266,
"loss_toxic": 0.09317965283989907
},
{
"loss": 3.281973719596863,
"loss_size": 0.26578598394989966,
"loss_pdi": 0.5881433010101318,
"loss_ee": 0.8700660288333892,
"loss_delivery": 1.1519657298922539,
"loss_biodist": 0.3240764126181602,
"loss_toxic": 0.08193630240857601
},
{
"loss": 2.7810576915740968,
"loss_size": 0.2505718767642975,
"loss_pdi": 0.5593925356864929,
"loss_ee": 0.8196510195732116,
"loss_delivery": 0.7932070523500443,
"loss_biodist": 0.27653754949569703,
"loss_toxic": 0.08169759791344404
},
{
"loss": 2.644732141494751,
"loss_size": 0.27979295402765275,
"loss_pdi": 0.5457461476325989,
"loss_ee": 0.7845215618610382,
"loss_delivery": 0.722122372686863,
"loss_biodist": 0.2437703028321266,
"loss_toxic": 0.06877880096435547
},
{
"loss": 2.5743841886520387,
"loss_size": 0.21236803606152535,
"loss_pdi": 0.5281321376562118,
"loss_ee": 0.7772053182125092,
"loss_delivery": 0.7842913195490837,
"loss_biodist": 0.20931598618626596,
"loss_toxic": 0.0630713876336813
},
{
"loss": 2.493379771709442,
"loss_size": 0.2545281477272511,
"loss_pdi": 0.514763566851616,
"loss_ee": 0.7416582465171814,
"loss_delivery": 0.7315813854336739,
"loss_biodist": 0.18844463676214218,
"loss_toxic": 0.062403830140829085
},
{
"loss": 2.3714203119277952,
"loss_size": 0.21288565024733544,
"loss_pdi": 0.5149440914392471,
"loss_ee": 0.7432775914669036,
"loss_delivery": 0.6615208894014358,
"loss_biodist": 0.17799324095249175,
"loss_toxic": 0.06079882858321071
},
{
"loss": 2.3138927936553957,
"loss_size": 0.22406778559088708,
"loss_pdi": 0.5060430943965912,
"loss_ee": 0.7270951688289642,
"loss_delivery": 0.6268678307533264,
"loss_biodist": 0.17946239709854125,
"loss_toxic": 0.050356499617919326
},
{
"loss": 2.2404407501220702,
"loss_size": 0.23460092321038245,
"loss_pdi": 0.4892877459526062,
"loss_ee": 0.6908941030502319,
"loss_delivery": 0.6124202072620392,
"loss_biodist": 0.16842604279518128,
"loss_toxic": 0.044811736792325974
},
{
"loss": 2.2448294520378114,
"loss_size": 0.21119624376296997,
"loss_pdi": 0.479864901304245,
"loss_ee": 0.6906192302703857,
"loss_delivery": 0.6555144399404526,
"loss_biodist": 0.16310803219676018,
"loss_toxic": 0.044526621932163835
},
{
"loss": 2.1580574989318846,
"loss_size": 0.18697498068213464,
"loss_pdi": 0.48660930395126345,
"loss_ee": 0.6810935467481614,
"loss_delivery": 0.6051739566028118,
"loss_biodist": 0.15406969040632248,
"loss_toxic": 0.04413598729297519
},
{
"loss": 2.114891529083252,
"loss_size": 0.17799586579203605,
"loss_pdi": 0.4589719235897064,
"loss_ee": 0.6686563313007354,
"loss_delivery": 0.6179293170571327,
"loss_biodist": 0.1526280902326107,
"loss_toxic": 0.03870999766513705
},
{
"loss": 2.1680126667022703,
"loss_size": 0.18272313922643663,
"loss_pdi": 0.47693236321210863,
"loss_ee": 0.6723115026950837,
"loss_delivery": 0.6574018053710461,
"loss_biodist": 0.14306045994162558,
"loss_toxic": 0.03558342705946416
},
{
"loss": 2.0243090748786927,
"loss_size": 0.19451010078191758,
"loss_pdi": 0.46296934187412264,
"loss_ee": 0.6654580652713775,
"loss_delivery": 0.5195972554385662,
"loss_biodist": 0.1416195034980774,
"loss_toxic": 0.04015482016839087
},
{
"loss": 1.980038857460022,
"loss_size": 0.19992023780941964,
"loss_pdi": 0.4373833805322647,
"loss_ee": 0.6562270969152451,
"loss_delivery": 0.5170416861772538,
"loss_biodist": 0.13248837292194365,
"loss_toxic": 0.03697808152064681
},
{
"loss": 2.0073827385902403,
"loss_size": 0.17545675858855247,
"loss_pdi": 0.43559625148773196,
"loss_ee": 0.6394164443016053,
"loss_delivery": 0.5809337809681893,
"loss_biodist": 0.13504885137081146,
"loss_toxic": 0.040930699557065964
}
],
"val": [
{
"loss": 10.945204257965088,
"loss_size": 6.681218147277832,
"loss_pdi": 1.0216107964515686,
"loss_ee": 1.0486068725585938,
"loss_delivery": 0.4687899202108383,
"loss_biodist": 1.215298354625702,
"loss_toxic": 0.5096809715032578,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.6470588235294118,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.0456085205078125,
"loss_size": 0.3470493406057358,
"loss_pdi": 0.7843169867992401,
"loss_ee": 0.8336820006370544,
"loss_delivery": 0.42519159615039825,
"loss_biodist": 1.1528617143630981,
"loss_toxic": 0.5025068372488022,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.553251266479492,
"loss_size": 0.07565776817500591,
"loss_pdi": 0.5630811750888824,
"loss_ee": 0.6947644650936127,
"loss_delivery": 0.4020952582359314,
"loss_biodist": 1.1898333430290222,
"loss_toxic": 0.627819336950779,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.2876256704330444,
"loss_size": 0.056399866938591,
"loss_pdi": 0.579200953245163,
"loss_ee": 0.5947848558425903,
"loss_delivery": 0.4561047703027725,
"loss_biodist": 1.0357274413108826,
"loss_toxic": 0.5654077678918839,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.2625988721847534,
"loss_size": 0.11296019703149796,
"loss_pdi": 0.5352367609739304,
"loss_ee": 0.6021667718887329,
"loss_delivery": 0.5046610683202744,
"loss_biodist": 1.0080225467681885,
"loss_toxic": 0.4995514266192913,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.165306806564331,
"loss_size": 0.09038551151752472,
"loss_pdi": 0.5058617442846298,
"loss_ee": 0.6476156711578369,
"loss_delivery": 0.43027013540267944,
"loss_biodist": 1.0445645153522491,
"loss_toxic": 0.4466092698276043,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.2408690452575684,
"loss_size": 0.22243183851242065,
"loss_pdi": 0.4985402673482895,
"loss_ee": 0.571824312210083,
"loss_delivery": 0.43825456500053406,
"loss_biodist": 0.9937507510185242,
"loss_toxic": 0.5160673335194588,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.5194804668426514,
"loss_size": 0.36968255043029785,
"loss_pdi": 0.4991031885147095,
"loss_ee": 0.5797468274831772,
"loss_delivery": 0.5644859671592712,
"loss_biodist": 0.9723091125488281,
"loss_toxic": 0.5341527052223682,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.539685606956482,
"loss_size": 0.38313066959381104,
"loss_pdi": 0.528433233499527,
"loss_ee": 0.5810057669878006,
"loss_delivery": 0.44039086997509,
"loss_biodist": 1.0017918348312378,
"loss_toxic": 0.6049331650137901,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.9403724670410156,
"loss_size": 0.6225972771644592,
"loss_pdi": 0.5688649713993073,
"loss_ee": 0.6205386221408844,
"loss_delivery": 0.6095166206359863,
"loss_biodist": 0.9419751763343811,
"loss_toxic": 0.576879795640707,
"acc_pdi": 0.803921568627451,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.8653494119644165,
"loss_size": 0.6294703483581543,
"loss_pdi": 0.5615053772926331,
"loss_ee": 0.6072992980480194,
"loss_delivery": 0.47824281454086304,
"loss_biodist": 0.964938759803772,
"loss_toxic": 0.6238927394151688,
"acc_pdi": 0.8431372549019608,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.109289169311523,
"loss_size": 0.753549188375473,
"loss_pdi": 0.6232334971427917,
"loss_ee": 0.6752453744411469,
"loss_delivery": 0.4541686922311783,
"loss_biodist": 1.0001222491264343,
"loss_toxic": 0.6029700562357903,
"acc_pdi": 0.7450980392156863,
"acc_ee": 0.7254901960784313,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.084217548370361,
"loss_size": 0.6689053475856781,
"loss_pdi": 0.5947604179382324,
"loss_ee": 0.6819752305746078,
"loss_delivery": 0.5174736380577087,
"loss_biodist": 0.9870622158050537,
"loss_toxic": 0.6340407878160477,
"acc_pdi": 0.7843137254901961,
"acc_ee": 0.7647058823529411,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.762814164161682,
"loss_size": 0.5682831406593323,
"loss_pdi": 0.5777421444654465,
"loss_ee": 0.7156199663877487,
"loss_delivery": 0.44971026480197906,
"loss_biodist": 0.9156049191951752,
"loss_toxic": 0.5358536541461945,
"acc_pdi": 0.7450980392156863,
"acc_ee": 0.7058823529411765,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.147223711013794,
"loss_size": 0.6644828915596008,
"loss_pdi": 0.5911359935998917,
"loss_ee": 0.713784396648407,
"loss_delivery": 0.500703439116478,
"loss_biodist": 1.0310384333133698,
"loss_toxic": 0.6460786163806915,
"acc_pdi": 0.7647058823529411,
"acc_ee": 0.7058823529411765,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.157853841781616,
"loss_size": 0.7414849102497101,
"loss_pdi": 0.630668580532074,
"loss_ee": 0.6938402056694031,
"loss_delivery": 0.50765261054039,
"loss_biodist": 0.9891600012779236,
"loss_toxic": 0.5950475558638573,
"acc_pdi": 0.7450980392156863,
"acc_ee": 0.7450980392156863,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.2473719120025635,
"loss_size": 0.8058429956436157,
"loss_pdi": 0.5982940196990967,
"loss_ee": 0.7209844589233398,
"loss_delivery": 0.5253763496875763,
"loss_biodist": 0.973820835351944,
"loss_toxic": 0.6230533868074417,
"acc_pdi": 0.803921568627451,
"acc_ee": 0.6862745098039216,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.90485680103302,
"loss_size": 0.5400884747505188,
"loss_pdi": 0.5671974420547485,
"loss_ee": 0.7085212767124176,
"loss_delivery": 0.5078988373279572,
"loss_biodist": 0.9940473735332489,
"loss_toxic": 0.5871035009622574,
"acc_pdi": 0.803921568627451,
"acc_ee": 0.7450980392156863,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.193094968795776,
"loss_size": 0.764777421951294,
"loss_pdi": 0.5734306275844574,
"loss_ee": 0.7070393562316895,
"loss_delivery": 0.5335722267627716,
"loss_biodist": 1.0060182809829712,
"loss_toxic": 0.6082571670413017,
"acc_pdi": 0.8235294117647058,
"acc_ee": 0.7450980392156863,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.225732326507568,
"loss_size": 0.7807798981666565,
"loss_pdi": 0.57969930768013,
"loss_ee": 0.7046914398670197,
"loss_delivery": 0.5619150400161743,
"loss_biodist": 1.0033797025680542,
"loss_toxic": 0.5952669233083725,
"acc_pdi": 0.7843137254901961,
"acc_ee": 0.7450980392156863,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.122166633605957,
"loss_size": 0.727344423532486,
"loss_pdi": 0.5855642706155777,
"loss_ee": 0.7065156400203705,
"loss_delivery": 0.5201583206653595,
"loss_biodist": 0.9953437149524689,
"loss_toxic": 0.5872401967644691,
"acc_pdi": 0.803921568627451,
"acc_ee": 0.7450980392156863,
"acc_toxic": 0.851063829787234
}
]
}

Binary file not shown.

View File

@ -0,0 +1,384 @@
{
"train": [
{
"loss": 19.51876787705855,
"loss_size": 14.430041096427225,
"loss_pdi": 1.3163191405209629,
"loss_ee": 1.080450177192688,
"loss_delivery": 0.8366337662393396,
"loss_biodist": 1.2393495603041216,
"loss_toxic": 0.6159739873626016
},
{
"loss": 7.450059110468084,
"loss_size": 3.0973785898902197,
"loss_pdi": 1.0574429847977378,
"loss_ee": 0.972686382857236,
"loss_delivery": 0.8131051341241057,
"loss_biodist": 1.0609121160073713,
"loss_toxic": 0.4485339197245511
},
{
"loss": 4.450224074450406,
"loss_size": 0.4227097576314753,
"loss_pdi": 0.8444447679953142,
"loss_ee": 0.9470934163440358,
"loss_delivery": 1.0848771997473456,
"loss_biodist": 0.8853748061440208,
"loss_toxic": 0.26572414352135226
},
{
"loss": 3.697209119796753,
"loss_size": 0.27599087357521057,
"loss_pdi": 0.7228363969109275,
"loss_ee": 0.9080322655764493,
"loss_delivery": 0.8549216186458414,
"loss_biodist": 0.7190906730565158,
"loss_toxic": 0.21633729677308688
},
{
"loss": 3.5547448938543145,
"loss_size": 0.29429094425656577,
"loss_pdi": 0.6620089682665738,
"loss_ee": 0.8595339439131997,
"loss_delivery": 0.9978644847869873,
"loss_biodist": 0.5649206800894304,
"loss_toxic": 0.17612605541944504
},
{
"loss": 3.069189115004106,
"loss_size": 0.30328830602494156,
"loss_pdi": 0.6169473230838776,
"loss_ee": 0.8132463910362937,
"loss_delivery": 0.72551099143245,
"loss_biodist": 0.46287189017642627,
"loss_toxic": 0.14732415906407617
},
{
"loss": 3.1349260156804863,
"loss_size": 0.27535233172503387,
"loss_pdi": 0.5893999202684923,
"loss_ee": 0.807813747362657,
"loss_delivery": 0.9538742283528502,
"loss_biodist": 0.4018077091737227,
"loss_toxic": 0.1066781035201116
},
{
"loss": 2.6963415037501943,
"loss_size": 0.276088840582154,
"loss_pdi": 0.5490301495248621,
"loss_ee": 0.757920276034962,
"loss_delivery": 0.6706068068742752,
"loss_biodist": 0.34504769336093555,
"loss_toxic": 0.09764780781485817
},
{
"loss": 2.418043158271096,
"loss_size": 0.25799565559083765,
"loss_pdi": 0.5286295684901151,
"loss_ee": 0.73835120417855,
"loss_delivery": 0.5078353543173183,
"loss_biodist": 0.2939920425415039,
"loss_toxic": 0.09123929352922873
},
{
"loss": 2.294130650433627,
"loss_size": 0.20914554325017062,
"loss_pdi": 0.5159178945151243,
"loss_ee": 0.7331724437800321,
"loss_delivery": 0.50414734875614,
"loss_biodist": 0.2559647980061444,
"loss_toxic": 0.07578264227644964
},
{
"loss": 2.260723189874129,
"loss_size": 0.21194299920038742,
"loss_pdi": 0.51129734787074,
"loss_ee": 0.71018939668482,
"loss_delivery": 0.506003974513574,
"loss_biodist": 0.24507361785932022,
"loss_toxic": 0.07621594924818385
},
{
"loss": 2.180013732476668,
"loss_size": 0.21782933243296362,
"loss_pdi": 0.5094848789952018,
"loss_ee": 0.6991291533816945,
"loss_delivery": 0.4610004154118625,
"loss_biodist": 0.22516351396387274,
"loss_toxic": 0.06740644057704644
},
{
"loss": 2.131091995672746,
"loss_size": 0.22081551971760663,
"loss_pdi": 0.4984923790801655,
"loss_ee": 0.6744041009382769,
"loss_delivery": 0.44286114688624034,
"loss_biodist": 0.2276345125653527,
"loss_toxic": 0.0668843225999312
},
{
"loss": 2.075855114243247,
"loss_size": 0.1967273937030272,
"loss_pdi": 0.4761518023230813,
"loss_ee": 0.6580501876094125,
"loss_delivery": 0.4651151258837093,
"loss_biodist": 0.22098628905686465,
"loss_toxic": 0.05882429365407337
},
{
"loss": 2.070832209153609,
"loss_size": 0.22619221427223898,
"loss_pdi": 0.46735330332409253,
"loss_ee": 0.658088050105355,
"loss_delivery": 0.46364551173015073,
"loss_biodist": 0.1992648494514552,
"loss_toxic": 0.0562882690097798
},
{
"loss": 2.0497749502008613,
"loss_size": 0.2018005665053021,
"loss_pdi": 0.44933846592903137,
"loss_ee": 0.6409396637569774,
"loss_delivery": 0.4817944710904902,
"loss_biodist": 0.21351950141516599,
"loss_toxic": 0.0623823038556359
},
{
"loss": 1.998817953196439,
"loss_size": 0.19020642136985605,
"loss_pdi": 0.4579390991817821,
"loss_ee": 0.6345709101720289,
"loss_delivery": 0.45773128284649417,
"loss_biodist": 0.20633393864740024,
"loss_toxic": 0.052036324177276
},
{
"loss": 1.9732873006300493,
"loss_size": 0.18214904246005145,
"loss_pdi": 0.46314583312381397,
"loss_ee": 0.6480948545716025,
"loss_delivery": 0.43798652359030465,
"loss_biodist": 0.19162344119765543,
"loss_toxic": 0.05028762956234542
}
],
"val": [
{
"loss": 11.670351346333822,
"loss_size": 7.447449843088786,
"loss_pdi": 1.082938313484192,
"loss_ee": 0.9422469735145569,
"loss_delivery": 0.7185012102127075,
"loss_biodist": 0.945093979438146,
"loss_toxic": 0.5341211756070455,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.8333333333333334,
"acc_toxic": 1.0
},
{
"loss": 3.7281323273976645,
"loss_size": 0.4031377931435903,
"loss_pdi": 0.7136062383651733,
"loss_ee": 0.7757165829340616,
"loss_delivery": 0.6889261901378632,
"loss_biodist": 0.8803721169630686,
"loss_toxic": 0.2663734555244446,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.8333333333333334,
"acc_toxic": 1.0
},
{
"loss": 2.996154228846232,
"loss_size": 0.08968868106603622,
"loss_pdi": 0.5251734455426534,
"loss_ee": 0.705430785814921,
"loss_delivery": 0.7259814739227295,
"loss_biodist": 0.8580058316389719,
"loss_toxic": 0.09187404563029607,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.8333333333333334,
"acc_toxic": 1.0
},
{
"loss": 3.0241867701212564,
"loss_size": 0.06373066206773122,
"loss_pdi": 0.5075281461079916,
"loss_ee": 0.6890556613604227,
"loss_delivery": 0.8740459084510803,
"loss_biodist": 0.8290574749310812,
"loss_toxic": 0.060768917202949524,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.8333333333333334,
"acc_toxic": 1.0
},
{
"loss": 3.2901877562204995,
"loss_size": 0.20215384662151337,
"loss_pdi": 0.5433482428391775,
"loss_ee": 0.7523069183031718,
"loss_delivery": 0.9290379285812378,
"loss_biodist": 0.8056556979815165,
"loss_toxic": 0.057684975365797676,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.8333333333333334,
"acc_toxic": 1.0
},
{
"loss": 3.11362091700236,
"loss_size": 0.11133595556020737,
"loss_pdi": 0.5679469505945841,
"loss_ee": 0.8252793351809183,
"loss_delivery": 0.8343843619028727,
"loss_biodist": 0.7218613227208456,
"loss_toxic": 0.052812947581211724,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.8181818181818182,
"acc_toxic": 1.0
},
{
"loss": 3.247321446736654,
"loss_size": 0.09277657171090443,
"loss_pdi": 0.6567125717798868,
"loss_ee": 1.0444208979606628,
"loss_delivery": 0.7062844236691793,
"loss_biodist": 0.6986102362473806,
"loss_toxic": 0.04851680745681127,
"acc_pdi": 0.8181818181818182,
"acc_ee": 0.45454545454545453,
"acc_toxic": 1.0
},
{
"loss": 3.168424208958944,
"loss_size": 0.05177713930606842,
"loss_pdi": 0.5932339330514272,
"loss_ee": 0.968136191368103,
"loss_delivery": 0.7594618300596873,
"loss_biodist": 0.7631273567676544,
"loss_toxic": 0.032687741021315254,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.45454545454545453,
"acc_toxic": 1.0
},
{
"loss": 3.1226733525594077,
"loss_size": 0.16442706187566122,
"loss_pdi": 0.4861932198206584,
"loss_ee": 0.8927785356839498,
"loss_delivery": 0.809323231379191,
"loss_biodist": 0.7391951779524485,
"loss_toxic": 0.030756143853068352,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.4696969696969697,
"acc_toxic": 1.0
},
{
"loss": 3.4750285943349204,
"loss_size": 0.07188746457298596,
"loss_pdi": 0.635328451792399,
"loss_ee": 1.0510863463083904,
"loss_delivery": 0.9019280473391215,
"loss_biodist": 0.7850770453612009,
"loss_toxic": 0.02972123461465041,
"acc_pdi": 0.8181818181818182,
"acc_ee": 0.3787878787878788,
"acc_toxic": 1.0
},
{
"loss": 3.3856334686279297,
"loss_size": 0.0857744167248408,
"loss_pdi": 0.5498887598514557,
"loss_ee": 0.9292206565539042,
"loss_delivery": 1.003569980462392,
"loss_biodist": 0.7962689697742462,
"loss_toxic": 0.020910644593338173,
"acc_pdi": 0.8333333333333334,
"acc_ee": 0.3787878787878788,
"acc_toxic": 1.0
},
{
"loss": 3.594546397527059,
"loss_size": 0.058230139315128326,
"loss_pdi": 0.7094775040944418,
"loss_ee": 1.134681224822998,
"loss_delivery": 0.8755488395690918,
"loss_biodist": 0.7928893665472666,
"loss_toxic": 0.023719362293680508,
"acc_pdi": 0.6363636363636364,
"acc_ee": 0.21212121212121213,
"acc_toxic": 1.0
},
{
"loss": 3.34331480662028,
"loss_size": 0.09952588627735774,
"loss_pdi": 0.5444782872994741,
"loss_ee": 0.9434934655825297,
"loss_delivery": 0.9477877616882324,
"loss_biodist": 0.7897172272205353,
"loss_toxic": 0.018312191901107628,
"acc_pdi": 0.7878787878787878,
"acc_ee": 0.36363636363636365,
"acc_toxic": 1.0
},
{
"loss": 3.4212222894032798,
"loss_size": 0.08121616393327713,
"loss_pdi": 0.5517565310001373,
"loss_ee": 1.0685155193010967,
"loss_delivery": 0.874200721581777,
"loss_biodist": 0.828464408715566,
"loss_toxic": 0.017068898615737755,
"acc_pdi": 0.8181818181818182,
"acc_ee": 0.25757575757575757,
"acc_toxic": 1.0
},
{
"loss": 3.6395487785339355,
"loss_size": 0.07888962080081303,
"loss_pdi": 0.5913220842679342,
"loss_ee": 1.0437468489011128,
"loss_delivery": 1.0660852392514546,
"loss_biodist": 0.8421931266784668,
"loss_toxic": 0.01731194742023945,
"acc_pdi": 0.7878787878787878,
"acc_ee": 0.2727272727272727,
"acc_toxic": 1.0
},
{
"loss": 3.5140305360158286,
"loss_size": 0.0599971575041612,
"loss_pdi": 0.5561938285827637,
"loss_ee": 1.0674984057744343,
"loss_delivery": 0.9739653070767721,
"loss_biodist": 0.8400343159834543,
"loss_toxic": 0.01634151643762986,
"acc_pdi": 0.7424242424242424,
"acc_ee": 0.22727272727272727,
"acc_toxic": 1.0
},
{
"loss": 3.628753344217936,
"loss_size": 0.08887146785855293,
"loss_pdi": 0.582503984371821,
"loss_ee": 1.114095131556193,
"loss_delivery": 0.9745903412501017,
"loss_biodist": 0.8516232868035635,
"loss_toxic": 0.017069284183283646,
"acc_pdi": 0.7121212121212122,
"acc_ee": 0.19696969696969696,
"acc_toxic": 1.0
},
{
"loss": 3.6391177972157798,
"loss_size": 0.08282352735598882,
"loss_pdi": 0.581497848033905,
"loss_ee": 1.141685386498769,
"loss_delivery": 0.9548555612564087,
"loss_biodist": 0.8643284440040588,
"loss_toxic": 0.013927079737186432,
"acc_pdi": 0.7424242424242424,
"acc_ee": 0.16666666666666666,
"acc_toxic": 1.0
}
]
}

Binary file not shown.

View File

@ -0,0 +1,294 @@
{
"fold_results": [
{
"fold_idx": 0,
"n_samples": 95,
"size": {
"n": 95,
"rmse": 0.5909209144067168,
"mae": 0.376253614927593,
"r2": 0.005712927161997228
},
"delivery": {
"n": 66,
"rmse": 1.3280577883458438,
"mae": 0.5195405159964028,
"r2": 0.03195999739694366
},
"pdi": {
"n": 95,
"accuracy": 0.6105263157894737,
"precision": 0.20350877192982456,
"recall": 0.3333333333333333,
"f1": 0.25272331154684097
},
"ee": {
"n": 95,
"accuracy": 0.6736842105263158,
"precision": 0.22456140350877193,
"recall": 0.3333333333333333,
"f1": 0.26834381551362685
},
"toxic": {
"n": 66,
"accuracy": 0.8939393939393939,
"precision": 0.44696969696969696,
"recall": 0.5,
"f1": 0.472
},
"biodist": {
"n": 66,
"kl_divergence": 0.851655784204727,
"js_divergence": 0.21404831573756974
}
},
{
"fold_idx": 1,
"n_samples": 195,
"size": {
"n": 193,
"rmse": 0.4425801645813746,
"mae": 0.26432527161632796,
"r2": -0.026225211870033682
},
"delivery": {
"n": 123,
"rmse": 0.7771322048436382,
"mae": 0.6133777339870822,
"r2": -0.128644776760948
},
"pdi": {
"n": 195,
"accuracy": 0.7076923076923077,
"precision": 0.35384615384615387,
"recall": 0.5,
"f1": 0.4144144144144144
},
"ee": {
"n": 195,
"accuracy": 0.4205128205128205,
"precision": 0.14017094017094017,
"recall": 0.3333333333333333,
"f1": 0.19735258724428398
},
"toxic": {
"n": 123,
"accuracy": 1.0,
"precision": 1.0,
"recall": 1.0,
"f1": 1.0
},
"biodist": {
"n": 123,
"kl_divergence": 0.9336461102028436,
"js_divergence": 0.24870266224462317
}
},
{
"fold_idx": 2,
"n_samples": 51,
"size": {
"n": 51,
"rmse": 0.6473513298834871,
"mae": 0.5600235602434944,
"r2": -9.27515642706235
},
"delivery": {
"n": 44,
"rmse": 0.7721077356414991,
"mae": 0.6167582499593581,
"r2": -0.4822886602727561
},
"pdi": {
"n": 51,
"accuracy": 0.8823529411764706,
"precision": 0.29411764705882354,
"recall": 0.3333333333333333,
"f1": 0.3125
},
"ee": {
"n": 51,
"accuracy": 0.8431372549019608,
"precision": 0.28104575163398693,
"recall": 0.3333333333333333,
"f1": 0.3049645390070922
},
"toxic": {
"n": 47,
"accuracy": 0.851063829787234,
"precision": 0.425531914893617,
"recall": 0.5,
"f1": 0.4597701149425288
},
"biodist": {
"n": 45,
"kl_divergence": 1.1049896129018548,
"js_divergence": 0.25485248115851133
}
},
{
"fold_idx": 3,
"n_samples": 66,
"size": {
"n": 66,
"rmse": 0.2407212117920812,
"mae": 0.19363613562150436,
"r2": -0.11204941379936861
},
"delivery": {
"n": 62,
"rmse": 1.0041711455927012,
"mae": 0.7132550483914993,
"r2": -0.63265374674746
},
"pdi": {
"n": 66,
"accuracy": 0.8484848484848485,
"precision": 0.42424242424242425,
"recall": 0.5,
"f1": 0.4590163934426229
},
"ee": {
"n": 66,
"accuracy": 0.8181818181818182,
"precision": 0.27692307692307694,
"recall": 0.32727272727272727,
"f1": 0.3
},
"toxic": {
"n": 62,
"accuracy": 1.0,
"precision": 1.0,
"recall": 1.0,
"f1": 1.0
},
"biodist": {
"n": 62,
"kl_divergence": 0.9677978984139058,
"js_divergence": 0.2020309307244639
}
},
{
"fold_idx": 4,
"n_samples": 27,
"size": {
"n": 27,
"rmse": 0.23392834445509142,
"mae": 0.19066280788845485,
"r2": -0.2667651950955112
},
"delivery": {
"n": 15,
"rmse": 1.9603892288630869,
"mae": 1.3892907698949177,
"r2": -0.29760739742916287
},
"pdi": {
"n": 27,
"accuracy": 0.8888888888888888,
"precision": 0.4444444444444444,
"recall": 0.5,
"f1": 0.47058823529411764
},
"ee": {
"n": 27,
"accuracy": 0.5925925925925926,
"precision": 0.19753086419753085,
"recall": 0.3333333333333333,
"f1": 0.24806201550387597
},
"toxic": {
"n": 15,
"accuracy": 1.0,
"precision": 1.0,
"recall": 1.0,
"f1": 1.0
},
"biodist": {
"n": 15,
"kl_divergence": 0.9389607012315264,
"js_divergence": 0.2470218476598176
}
}
],
"summary_stats": {
"size": {
"rmse_mean": 0.43110039302375025,
"rmse_std": 0.17179051271013462,
"r2_mean": -1.9348966641330534,
"r2_std": 3.6713441784129
},
"delivery": {
"rmse_mean": 1.1683716206573538,
"rmse_std": 0.4449374578352648,
"r2_mean": -0.30184691676267666,
"r2_std": 0.23809090378746706
},
"pdi": {
"accuracy_mean": 0.7875890604063979,
"accuracy_std": 0.11016791908756088,
"f1_mean": 0.3818484709395992,
"f1_std": 0.08529090446864619
},
"ee": {
"accuracy_mean": 0.6696217393431015,
"accuracy_std": 0.15503740047242787,
"f1_mean": 0.2637445914537758,
"f1_std": 0.039213602228007696
},
"toxic": {
"accuracy_mean": 0.9490006447453256,
"accuracy_std": 0.06391582554207781,
"f1_mean": 0.7863540229885058,
"f1_std": 0.26169039387919035
},
"biodist": {
"kl_mean": 0.9594100213909715,
"kl_std": 0.08240959093662605,
"js_mean": 0.23333124750499712,
"js_std": 0.021158533549255752
}
},
"overall": {
"size": {
"n_samples": 432,
"mse": 0.22604480336185886,
"rmse": 0.47544169291497657,
"mae": 0.3084443360567093,
"r2": -0.2313078534105617
},
"delivery": {
"n_samples": 310,
"mse": 1.0873755440675295,
"rmse": 1.0427730069710903,
"mae": 0.6513989447841361,
"r2": -0.09443640807387799
},
"pdi": {
"n_samples": 434,
"accuracy": 0.7396313364055299,
"precision": 0.18490783410138248,
"recall": 0.25,
"f1": 0.21258278145695364
},
"ee": {
"n_samples": 434,
"accuracy": 0.5967741935483871,
"precision": 0.1993841416474211,
"recall": 0.33205128205128204,
"f1": 0.24915824915824913
},
"toxic": {
"n_samples": 313,
"accuracy": 0.9552715654952076,
"precision": 0.4776357827476038,
"recall": 0.5,
"f1": 0.48856209150326796
},
"biodist": {
"n_samples": 311,
"kl_divergence": 0.9481034280166569,
"js_divergence": 0.23285280825310384
}
}
}

Binary file not shown.

View File

@ -1,310 +1,206 @@
{
"train": [
{
"loss": 0.8244676398801744,
"n_samples": 8721
"loss": 0.7730368412685099,
"n_samples": 6783
},
{
"loss": 0.6991508170533461,
"n_samples": 8721
"loss": 0.658895703010919,
"n_samples": 6783
},
{
"loss": 0.6388374940987616,
"n_samples": 8721
"loss": 0.6059015260392299,
"n_samples": 6783
},
{
"loss": 0.6008581508669937,
"n_samples": 8721
"loss": 0.5744731174349416,
"n_samples": 6783
},
{
"loss": 0.584832567446085,
"n_samples": 8721
"loss": 0.5452056020458733,
"n_samples": 6783
},
{
"loss": 0.5481657371815157,
"n_samples": 8721
"loss": 0.5138543470936083,
"n_samples": 6783
},
{
"loss": 0.5368926340308079,
"n_samples": 8721
"loss": 0.4885380559178135,
"n_samples": 6783
},
{
"loss": 0.5210388793613561,
"n_samples": 8721
"loss": 0.47587182296687974,
"n_samples": 6783
},
{
"loss": 0.49758357966374045,
"n_samples": 8721
"loss": 0.4671051038255316,
"n_samples": 6783
},
{
"loss": 0.49256294099457043,
"n_samples": 8721
"loss": 0.46794115915756107,
"n_samples": 6783
},
{
"loss": 0.4697267088016886,
"n_samples": 8721
"loss": 0.4293930456997915,
"n_samples": 6783
},
{
"loss": 0.45763822707571084,
"n_samples": 8721
"loss": 0.42624105651716415,
"n_samples": 6783
},
{
"loss": 0.4495221330627172,
"n_samples": 8721
"loss": 0.4131358770446828,
"n_samples": 6783
},
{
"loss": 0.446159594079631,
"n_samples": 8721
"loss": 0.3946074267790835,
"n_samples": 6783
},
{
"loss": 0.4327090857889029,
"n_samples": 8721
"loss": 0.3898155013755344,
"n_samples": 6783
},
{
"loss": 0.4249273364101852,
"n_samples": 8721
"loss": 0.37861797005733383,
"n_samples": 6783
},
{
"loss": 0.4216959138704459,
"n_samples": 8721
"loss": 0.3775682858392304,
"n_samples": 6783
},
{
"loss": 0.416526201182502,
"n_samples": 8721
"loss": 0.3800349080262064,
"n_samples": 6783
},
{
"loss": 0.40368679039741573,
"n_samples": 8721
"loss": 0.36302345173031675,
"n_samples": 6783
},
{
"loss": 0.4051084730032182,
"n_samples": 8721
"loss": 0.3429561740842766,
"n_samples": 6783
},
{
"loss": 0.38971701020385785,
"n_samples": 8721
"loss": 0.3445638883004898,
"n_samples": 6783
},
{
"loss": 0.39155546386038786,
"n_samples": 8721
"loss": 0.318970229203733,
"n_samples": 6783
},
{
"loss": 0.37976963541784114,
"n_samples": 8721
"loss": 0.30179278279904437,
"n_samples": 6783
},
{
"loss": 0.36484339719805037,
"n_samples": 8721
"loss": 0.2887343142006437,
"n_samples": 6783
},
{
"loss": 0.36232607571196496,
"n_samples": 8721
},
{
"loss": 0.3345973272380199,
"n_samples": 8721
},
{
"loss": 0.31767916518768957,
"n_samples": 8721
},
{
"loss": 0.32065429246052457,
"n_samples": 8721
},
{
"loss": 0.3171297926146043,
"n_samples": 8721
},
{
"loss": 0.3122120894173009,
"n_samples": 8721
},
{
"loss": 0.3135035038404461,
"n_samples": 8721
},
{
"loss": 0.2987745178222875,
"n_samples": 8721
},
{
"loss": 0.2914867957853393,
"n_samples": 8721
},
{
"loss": 0.2983839795507705,
"n_samples": 8721
},
{
"loss": 0.2826709597875678,
"n_samples": 8721
},
{
"loss": 0.2731766632569382,
"n_samples": 8721
},
{
"loss": 0.27726896305742266,
"n_samples": 8721
},
{
"loss": 0.27864557847067956,
"n_samples": 8721
"loss": 0.29240367556855545,
"n_samples": 6783
}
],
"val": [
{
"loss": 0.7601077516012517,
"n_samples": 969
"loss": 0.7350345371841441,
"n_samples": 2907
},
{
"loss": 0.7119935319611901,
"n_samples": 969
"loss": 0.7165568811318536,
"n_samples": 2907
},
{
"loss": 0.6461842978148269,
"n_samples": 969
"loss": 0.7251406249862214,
"n_samples": 2907
},
{
"loss": 0.7006978391063226,
"n_samples": 969
"loss": 0.6836505264587159,
"n_samples": 2907
},
{
"loss": 0.6533874032943979,
"n_samples": 969
"loss": 0.6747132955771933,
"n_samples": 2907
},
{
"loss": 0.6413641451743611,
"n_samples": 969
"loss": 0.6691136244936912,
"n_samples": 2907
},
{
"loss": 0.6168395132979742,
"n_samples": 969
"loss": 0.6337480902323249,
"n_samples": 2907
},
{
"loss": 0.6095251602162025,
"n_samples": 969
"loss": 0.6600317959527934,
"n_samples": 2907
},
{
"loss": 0.5887809592626905,
"n_samples": 969
"loss": 0.6439923948855346,
"n_samples": 2907
},
{
"loss": 0.5655298325376368,
"n_samples": 969
"loss": 0.643800035575267,
"n_samples": 2907
},
{
"loss": 0.5809201743872788,
"n_samples": 969
"loss": 0.6181512585221839,
"n_samples": 2907
},
{
"loss": 0.5897585974912033,
"n_samples": 969
"loss": 0.6442458634939151,
"n_samples": 2907
},
{
"loss": 0.5732012489662573,
"n_samples": 969
"loss": 0.6344759362359862,
"n_samples": 2907
},
{
"loss": 0.5607388911786094,
"n_samples": 969
"loss": 0.6501405371457472,
"n_samples": 2907
},
{
"loss": 0.5717580675371414,
"n_samples": 969
"loss": 0.6098835162990152,
"n_samples": 2907
},
{
"loss": 0.5553950037657291,
"n_samples": 969
"loss": 0.6366627322138894,
"n_samples": 2907
},
{
"loss": 0.5778171792857049,
"n_samples": 969
"loss": 0.6171610150646417,
"n_samples": 2907
},
{
"loss": 0.5602665468127734,
"n_samples": 969
"loss": 0.6358801012273748,
"n_samples": 2907
},
{
"loss": 0.5475307451359259,
"n_samples": 969
"loss": 0.6239976831059871,
"n_samples": 2907
},
{
"loss": 0.551515599314827,
"n_samples": 969
"loss": 0.6683828232827201,
"n_samples": 2907
},
{
"loss": 0.5755438121541243,
"n_samples": 969
"loss": 0.6655785786478143,
"n_samples": 2907
},
{
"loss": 0.5798238261811381,
"n_samples": 969
"loss": 0.6152775046503088,
"n_samples": 2907
},
{
"loss": 0.5739961433828923,
"n_samples": 969
"loss": 0.6202247662153858,
"n_samples": 2907
},
{
"loss": 0.5742599932540312,
"n_samples": 969
"loss": 0.648199727435189,
"n_samples": 2907
},
{
"loss": 0.5834948123885382,
"n_samples": 969
},
{
"loss": 0.554078846570139,
"n_samples": 969
},
{
"loss": 0.5714933996322354,
"n_samples": 969
},
{
"loss": 0.5384107524350331,
"n_samples": 969
},
{
"loss": 0.570854394451568,
"n_samples": 969
},
{
"loss": 0.5767292551642478,
"n_samples": 969
},
{
"loss": 0.5660079547556808,
"n_samples": 969
},
{
"loss": 0.5608972411514312,
"n_samples": 969
},
{
"loss": 0.5620947442987263,
"n_samples": 969
},
{
"loss": 0.5706970894361305,
"n_samples": 969
},
{
"loss": 0.5702376298690974,
"n_samples": 969
},
{
"loss": 0.5758474825259579,
"n_samples": 969
},
{
"loss": 0.5673816067284844,
"n_samples": 969
},
{
"loss": 0.5671441179879925,
"n_samples": 969
"loss": 0.6473217075085124,
"n_samples": 2907
}
]
}

View File

@ -1,232 +1,224 @@
"""处理 cross-validation 数据脚本:将 CV splits 转换为模型所需的 parquet 格式"""
"""内部数据 Cross-Validation 划分脚本:基于 Amine 的分组划分"""
from pathlib import Path
from typing import Dict, List, Tuple
from typing import List
import numpy as np
import pandas as pd
import typer
from loguru import logger
from lnp_ml.config import EXTERNAL_DATA_DIR, PROCESSED_DATA_DIR
from lnp_ml.config import INTERIM_DATA_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import (
LNPDatasetConfig,
process_dataframe,
SMILES_COL,
COMP_COLS,
HELP_COLS,
TARGET_REGRESSION,
TARGET_CLASSIFICATION_PDI,
TARGET_CLASSIFICATION_EE,
TARGET_TOXIC,
TARGET_BIODIST,
get_phys_cols,
get_exp_cols,
EXP_ONEHOT_SPECS,
PHYS_ONEHOT_SPECS,
)
app = typer.Typer()
# CV extra_x 列名到模型列名的映射
CV_COL_MAPPING = {
# Batch_or_individual_or_barcoded -> Sample_organization_type (for Value_name related)
"Batch_or_individual_or_barcoded_Barcoded": "Batch_or_individual_or_barcoded_Barcoded",
"Batch_or_individual_or_barcoded_Individual": "Batch_or_individual_or_barcoded_Individual",
# Helper_lipid_ID_None 不在模型中使用,忽略
}
def load_cv_split(cv_dir: Path) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
def amine_based_cv_split(
df: pd.DataFrame,
n_folds: int = 5,
seed: int = 42,
amine_col: str = "Amine",
) -> List[dict]:
"""
加载单个 CV split 的数据
基于 Amine 列进行 Cross-Validation 划分
步骤
1. amine_col 分组
2. 打乱分组顺序
3. 将分组 round-robin 分配到 n_folds 个容器
4. 对于每个 fold i
- validation = container[i]
- test = container[(i+1) % n_folds]
- train = 其余所有
Args:
cv_dir: CV split 目录 cv_0/
df: 输入 DataFrame
n_folds: 折数
seed: 随机种子
amine_col: 用于分组的列名
Returns:
(train_df, valid_df, test_df) 合并后的 DataFrame
List of dicts每个 dict 包含 train_df, val_df, test_df
"""
splits = {}
for split_name in ["train", "valid", "test"]:
# 加载主数据smiles, quantified_delivery
main_path = cv_dir / f"{split_name}.csv"
extra_x_path = cv_dir / f"{split_name}_extra_x.csv"
metadata_path = cv_dir / f"{split_name}_metadata.csv"
# 获取唯一的 amine 并打乱
unique_amines = df[amine_col].unique()
rng = np.random.RandomState(seed)
rng.shuffle(unique_amines)
if not main_path.exists():
raise FileNotFoundError(f"Missing {main_path}")
logger.info(f"Found {len(unique_amines)} unique amines")
main_df = pd.read_csv(main_path)
# Round-robin 分配到 n_folds 个容器
containers = [[] for _ in range(n_folds)]
for i, amine in enumerate(unique_amines):
containers[i % n_folds].append(amine)
# 加载 extra_x已 one-hot 编码的特征)
if extra_x_path.exists():
extra_x_df = pd.read_csv(extra_x_path)
# 确保行数一致
assert len(main_df) == len(extra_x_df), f"Row count mismatch: {len(main_df)} vs {len(extra_x_df)}"
# 合并(按行索引)
df = pd.concat([main_df, extra_x_df], axis=1)
else:
df = main_df
logger.warning(f" {split_name}_extra_x.csv not found, using main data only")
# 打印每个容器的大小
for i, container in enumerate(containers):
container_samples = df[df[amine_col].isin(container)]
logger.info(f" Container {i}: {len(container)} amines, {len(container_samples)} samples")
# 可选:从 metadata 获取额外信息
if metadata_path.exists():
metadata_df = pd.read_csv(metadata_path)
# 提取需要的列(如 Purity, Mix_type, Value_name 等)
for col in ["Purity", "Mix_type", "Value_name", "Target_or_delivered_gene"]:
if col in metadata_df.columns and col not in df.columns:
df[col] = metadata_df[col]
# 生成每个 fold 的数据
fold_splits = []
for i in range(n_folds):
val_amines = set(containers[i])
test_amines = set(containers[(i + 1) % n_folds])
train_amines = set()
for j in range(n_folds):
if j != i and j != (i + 1) % n_folds:
train_amines.update(containers[j])
splits[split_name] = df
train_df = df[df[amine_col].isin(train_amines)].reset_index(drop=True)
val_df = df[df[amine_col].isin(val_amines)].reset_index(drop=True)
test_df = df[df[amine_col].isin(test_amines)].reset_index(drop=True)
return splits["train"], splits["valid"], splits["test"]
fold_splits.append({
"train": train_df,
"val": val_df,
"test": test_df,
})
logger.info(
f"Fold {i}: train={len(train_df)} ({len(train_amines)} amines), "
f"val={len(val_df)} ({len(val_amines)} amines), "
f"test={len(test_df)} ({len(test_amines)} amines)"
)
def process_cv_dataframe(df: pd.DataFrame) -> pd.DataFrame:
"""
处理 CV 数据的 DataFrame对齐到模型所需的列格式
CV 数据的 extra_x 已经包含大部分 one-hot 编码但需要
1. 添加缺失的 one-hot 设为 0
2. metadata 中生成 phys token one-hot Purity, Mix_type, Cargo_type, Target_or_delivered_gene
3. 生成 Value_name one-hot
"""
df = df.copy()
# 1. 处理 comp 列
for col in COMP_COLS:
if col in df.columns:
df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0.0)
else:
df[col] = 0.0
# 2. 处理 help 列
for col in HELP_COLS:
if col not in df.columns:
df[col] = 0.0
else:
df[col] = df[col].fillna(0.0).astype(float)
# 3. 处理 phys token 的 one-hot 列
for col, values in PHYS_ONEHOT_SPECS.items():
for v in values:
onehot_col = f"{col}_{v}"
if onehot_col not in df.columns:
# 尝试从原始列生成
if col in df.columns:
df[onehot_col] = (df[col] == v).astype(float)
else:
df[onehot_col] = 0.0
else:
df[onehot_col] = df[onehot_col].fillna(0.0).astype(float)
# 4. 处理 exp token 的 one-hot 列
for col, values in EXP_ONEHOT_SPECS.items():
for v in values:
onehot_col = f"{col}_{v}"
if onehot_col not in df.columns:
# 尝试从原始列生成
if col in df.columns:
df[onehot_col] = (df[col] == v).astype(float)
else:
df[onehot_col] = 0.0
else:
df[onehot_col] = df[onehot_col].fillna(0.0).astype(float)
# 5. 处理 quantified_delivery
if "quantified_delivery" in df.columns:
df["quantified_delivery"] = pd.to_numeric(df["quantified_delivery"], errors="coerce")
return df
def get_feature_columns() -> List[str]:
"""获取所有特征列名"""
config = LNPDatasetConfig()
return (
["smiles"]
+ config.comp_cols
+ config.phys_cols
+ config.help_cols
+ config.exp_cols
+ ["quantified_delivery"]
)
return fold_splits
@app.command()
def main(
data_dir: Path = EXTERNAL_DATA_DIR / "all_amine_split_for_LiON",
input_path: Path = INTERIM_DATA_DIR / "internal_corrected.csv",
output_dir: Path = PROCESSED_DATA_DIR / "cv",
n_folds: int = 5,
seed: int = 42,
amine_col: str = "Amine",
):
"""
处理 cross-validation 数据生成模型所需的 parquet 文件
基于 Amine 分组进行 Cross-Validation 数据划分
采用类似 scaffold splitting 的思路将相同 Amine 的数据放在同一组
确保训练集和测试集之间没有 Amine 泄露
划分比例约为 train:val:test 3:1:1
输出结构:
- processed/cv/fold_0/train.parquet
- processed/cv/fold_0/valid.parquet
- processed/cv/fold_0/val.parquet
- processed/cv/fold_0/test.parquet
- processed/cv/fold_1/...
- processed/cv/feature_columns.txt
"""
logger.info(f"Processing CV data from {data_dir}")
logger.info(f"Loading data from {input_path}")
df = pd.read_csv(input_path)
logger.info(f"Loaded {len(df)} samples")
# 获取所有 cv_* 目录
cv_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("cv_")])
if len(cv_dirs) == 0:
logger.error(f"No cv_* directories found in {data_dir}")
# 检查 amine 列是否存在
if amine_col not in df.columns:
logger.error(f"Column '{amine_col}' not found in data. Available columns: {list(df.columns)}")
raise typer.Exit(1)
if len(cv_dirs) != n_folds:
logger.warning(f"Expected {n_folds} folds, found {len(cv_dirs)}")
# 处理数据列对齐、one-hot 生成等)
logger.info("Processing dataframe...")
df = process_dataframe(df)
logger.info(f"Found {len(cv_dirs)} folds: {[d.name for d in cv_dirs]}")
# 确保 Amine 列仍然存在process_dataframe 可能不会保留它)
# 重新加载原始数据获取 Amine 列
original_df = pd.read_csv(input_path)
if amine_col in original_df.columns and amine_col not in df.columns:
df[amine_col] = original_df[amine_col].values
feature_cols = get_feature_columns()
# 定义要保留的列
phys_cols = get_phys_cols()
exp_cols = get_exp_cols()
keep_cols = (
[SMILES_COL]
+ COMP_COLS
+ phys_cols
+ HELP_COLS
+ exp_cols
+ TARGET_REGRESSION
+ TARGET_CLASSIFICATION_PDI
+ TARGET_CLASSIFICATION_EE
+ [TARGET_TOXIC]
+ TARGET_BIODIST
)
# 只保留存在的列
keep_cols = [c for c in keep_cols if c in df.columns]
# 进行 CV 划分
logger.info(f"\nPerforming {n_folds}-fold amine-based CV split (seed={seed})...")
fold_splits = amine_based_cv_split(df, n_folds=n_folds, seed=seed, amine_col=amine_col)
# 保存每个 fold
output_dir.mkdir(parents=True, exist_ok=True)
for i, cv_dir in enumerate(cv_dirs):
fold_name = f"fold_{i}"
fold_output_dir = output_dir / fold_name
fold_output_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Processing {cv_dir.name} -> {fold_name}")
# 加载数据
train_df, valid_df, test_df = load_cv_split(cv_dir)
logger.info(f" Loaded: train={len(train_df)}, valid={len(valid_df)}, test={len(test_df)}")
# 处理数据
train_df = process_cv_dataframe(train_df)
valid_df = process_cv_dataframe(valid_df)
test_df = process_cv_dataframe(test_df)
# 确保所有列存在
for col in feature_cols:
for df in [train_df, valid_df, test_df]:
if col not in df.columns:
df[col] = 0.0 if col != "smiles" else ""
for i, split in enumerate(fold_splits):
fold_dir = output_dir / f"fold_{i}"
fold_dir.mkdir(parents=True, exist_ok=True)
# 只保留需要的列
train_df = train_df[feature_cols]
valid_df = valid_df[feature_cols]
test_df = test_df[feature_cols]
train_df = split["train"][keep_cols].reset_index(drop=True)
val_df = split["val"][keep_cols].reset_index(drop=True)
test_df = split["test"][keep_cols].reset_index(drop=True)
# 保存
train_df.to_parquet(fold_output_dir / "train.parquet", index=False)
valid_df.to_parquet(fold_output_dir / "valid.parquet", index=False)
test_df.to_parquet(fold_output_dir / "test.parquet", index=False)
train_df.to_parquet(fold_dir / "train.parquet", index=False)
val_df.to_parquet(fold_dir / "val.parquet", index=False)
test_df.to_parquet(fold_dir / "test.parquet", index=False)
logger.success(f" Saved to {fold_output_dir}")
logger.success(f"Saved fold {i} to {fold_dir}")
# 保存特征列配置
cols_path = output_dir / "feature_columns.txt"
with open(cols_path, "w") as f:
f.write("\n".join(feature_cols))
logger.success(f"Saved feature columns to {cols_path}")
# 保存列名配置
config_path = output_dir / "feature_columns.txt"
with open(config_path, "w") as f:
f.write("# Feature columns configuration\n\n")
f.write(f"# SMILES\n{SMILES_COL}\n\n")
f.write(f"# comp token [{len(COMP_COLS)}]\n")
f.write("\n".join(COMP_COLS) + "\n\n")
f.write(f"# phys token [{len(phys_cols)}]\n")
f.write("\n".join(phys_cols) + "\n\n")
f.write(f"# help token [{len(HELP_COLS)}]\n")
f.write("\n".join(HELP_COLS) + "\n\n")
f.write(f"# exp token [{len(exp_cols)}]\n")
f.write("\n".join(exp_cols) + "\n\n")
f.write("# Targets\n")
f.write("## Regression\n")
f.write("\n".join(TARGET_REGRESSION) + "\n")
f.write("## PDI classification\n")
f.write("\n".join(TARGET_CLASSIFICATION_PDI) + "\n")
f.write("## EE classification\n")
f.write("\n".join(TARGET_CLASSIFICATION_EE) + "\n")
f.write("## Toxic\n")
f.write(f"{TARGET_TOXIC}\n")
f.write("## Biodistribution\n")
f.write("\n".join(TARGET_BIODIST) + "\n")
logger.success(f"Saved feature config to {config_path}")
# 打印汇总
logger.info("\n" + "=" * 60)
logger.info("CV DATA PROCESSING COMPLETE")
logger.info("=" * 60)
logger.info(f"Output directory: {output_dir}")
logger.info(f"Number of folds: {len(cv_dirs)}")
logger.info(f"Number of folds: {n_folds}")
logger.info(f"Splitting method: Amine-based (column: {amine_col})")
logger.info(f"Random seed: {seed}")
if __name__ == "__main__":

View File

@ -0,0 +1,234 @@
"""Process cross-validation data: convert external CV splits into the parquet files the model expects."""
from pathlib import Path
from typing import Dict, List, Tuple
import pandas as pd
import typer
from loguru import logger
from lnp_ml.config import EXTERNAL_DATA_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import (
    LNPDatasetConfig,
    COMP_COLS,
    HELP_COLS,
    get_phys_cols,
    get_exp_cols,
    EXP_ONEHOT_SPECS,
    PHYS_ONEHOT_SPECS,
)

app = typer.Typer()

# Mapping from CV extra_x column names to model column names.
# NOTE(review): currently an identity mapping; kept as a hook for future renames.
CV_COL_MAPPING = {
    # Batch_or_individual_or_barcoded -> Sample_organization_type (for Value_name related)
    "Batch_or_individual_or_barcoded_Barcoded": "Batch_or_individual_or_barcoded_Barcoded",
    "Batch_or_individual_or_barcoded_Individual": "Batch_or_individual_or_barcoded_Individual",
    # Helper_lipid_ID_None is not used by the model; ignored.
}
def load_cv_split(cv_dir: Path) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Load one CV split directory (e.g. ``cv_0/``) into merged DataFrames.

    For each of the train/valid/test splits this reads:
      * ``{split}.csv``          -- required; main data (smiles, quantified_delivery)
      * ``{split}_extra_x.csv``  -- optional; pre-one-hot-encoded features, joined
        column-wise (row order must match the main CSV)
      * ``{split}_metadata.csv`` -- optional; selected columns are copied in
        only when not already present

    Args:
        cv_dir: directory containing the split CSV files.

    Returns:
        Tuple of (train_df, valid_df, test_df).

    Raises:
        FileNotFoundError: when a required ``{split}.csv`` is missing.
    """
    loaded = {}
    for name in ("train", "valid", "test"):
        base_csv = cv_dir / f"{name}.csv"
        if not base_csv.exists():
            raise FileNotFoundError(f"Missing {base_csv}")
        frame = pd.read_csv(base_csv)

        # Optional pre-encoded extra features: concatenate by row index.
        extra_csv = cv_dir / f"{name}_extra_x.csv"
        if extra_csv.exists():
            extra = pd.read_csv(extra_csv)
            # Row counts must line up for a positional column join.
            assert len(frame) == len(extra), f"Row count mismatch: {len(frame)} vs {len(extra)}"
            frame = pd.concat([frame, extra], axis=1)
        else:
            logger.warning(f" {name}_extra_x.csv not found, using main data only")

        # Optional metadata: pull in a fixed set of columns when absent.
        meta_csv = cv_dir / f"{name}_metadata.csv"
        if meta_csv.exists():
            meta = pd.read_csv(meta_csv)
            for col in ("Purity", "Mix_type", "Value_name", "Target_or_delivered_gene"):
                if col in meta.columns and col not in frame.columns:
                    frame[col] = meta[col]

        loaded[name] = frame

    return loaded["train"], loaded["valid"], loaded["test"]
def process_cv_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Align a CV DataFrame to the column layout the model expects.

    The extra_x files already carry most one-hot encodings; this fills the gaps:
      1. composition columns coerced to numeric (missing column -> 0.0)
      2. helper-lipid columns filled with 0.0 when absent
      3./4. every expected phys/exp one-hot column is guaranteed to exist --
         derived from the raw categorical column when possible, else 0.0
      5. ``quantified_delivery`` coerced to numeric (unparsable -> NaN)

    Returns a processed copy; the input DataFrame is not mutated.
    """
    out = df.copy()

    # 1. Composition columns: force numeric, default missing values/columns to 0.0.
    for col in COMP_COLS:
        if col in out.columns:
            out[col] = pd.to_numeric(out[col], errors="coerce").fillna(0.0)
        else:
            out[col] = 0.0

    # 2. Helper-lipid columns: ensure present and float-typed.
    for col in HELP_COLS:
        out[col] = out[col].fillna(0.0).astype(float) if col in out.columns else 0.0

    # 3. + 4. Guarantee all phys-token and exp-token one-hot columns exist.
    for spec in (PHYS_ONEHOT_SPECS, EXP_ONEHOT_SPECS):
        for col, values in spec.items():
            for v in values:
                onehot = f"{col}_{v}"
                if onehot in out.columns:
                    out[onehot] = out[onehot].fillna(0.0).astype(float)
                elif col in out.columns:
                    # Derive the indicator from the raw categorical column.
                    out[onehot] = (out[col] == v).astype(float)
                else:
                    out[onehot] = 0.0

    # 5. Target column: coerce to numeric; unparsable values become NaN.
    if "quantified_delivery" in out.columns:
        out["quantified_delivery"] = pd.to_numeric(out["quantified_delivery"], errors="coerce")

    return out
def get_feature_columns() -> List[str]:
    """Return the full ordered list of feature columns the model consumes.

    Layout: smiles, composition columns, phys columns, helper columns,
    exp columns, then the ``quantified_delivery`` target.
    """
    cfg = LNPDatasetConfig()
    cols: List[str] = ["smiles"]
    cols.extend(cfg.comp_cols)
    cols.extend(cfg.phys_cols)
    cols.extend(cfg.help_cols)
    cols.extend(cfg.exp_cols)
    cols.append("quantified_delivery")
    return cols
@app.command()
def main(
    data_dir: Path = EXTERNAL_DATA_DIR / "all_amine_split_for_LiON",
    output_dir: Path = PROCESSED_DATA_DIR / "pretrain_cv",
    n_folds: int = 5,
):
    """Process cross-validation data and produce the parquet files the model needs.

    Reads every ``cv_*`` sub-directory of ``data_dir`` (one per fold), merges
    and normalizes each split via load_cv_split/process_cv_dataframe, and
    writes one parquet per split.

    Output layout:
      - processed/pretrain_cv/fold_0/train.parquet
      - processed/pretrain_cv/fold_0/valid.parquet
      - processed/pretrain_cv/fold_0/test.parquet
      - processed/pretrain_cv/fold_1/...
      - processed/pretrain_cv/feature_columns.txt
    """
    logger.info(f"Processing CV data from {data_dir}")

    # Collect all cv_* fold directories; their sorted order defines fold_0..fold_N.
    cv_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("cv_")])
    if len(cv_dirs) == 0:
        logger.error(f"No cv_* directories found in {data_dir}")
        raise typer.Exit(1)
    # A count mismatch is tolerated (warning only); all found folds are processed.
    if len(cv_dirs) != n_folds:
        logger.warning(f"Expected {n_folds} folds, found {len(cv_dirs)}")
    logger.info(f"Found {len(cv_dirs)} folds: {[d.name for d in cv_dirs]}")

    feature_cols = get_feature_columns()
    output_dir.mkdir(parents=True, exist_ok=True)

    for i, cv_dir in enumerate(cv_dirs):
        fold_name = f"fold_{i}"
        fold_output_dir = output_dir / fold_name
        fold_output_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Processing {cv_dir.name} -> {fold_name}")

        # Load the raw split CSVs (main + optional extra_x/metadata).
        train_df, valid_df, test_df = load_cv_split(cv_dir)
        logger.info(f" Loaded: train={len(train_df)}, valid={len(valid_df)}, test={len(test_df)}")

        # Normalize columns / generate missing one-hots for each split.
        train_df = process_cv_dataframe(train_df)
        valid_df = process_cv_dataframe(valid_df)
        test_df = process_cv_dataframe(test_df)

        # Ensure every expected feature column exists (empty string for smiles,
        # 0.0 for everything else).
        for col in feature_cols:
            for df in [train_df, valid_df, test_df]:
                if col not in df.columns:
                    df[col] = 0.0 if col != "smiles" else ""

        # Keep only the model's feature columns, in canonical order.
        train_df = train_df[feature_cols]
        valid_df = valid_df[feature_cols]
        test_df = test_df[feature_cols]

        # Persist the processed splits.
        train_df.to_parquet(fold_output_dir / "train.parquet", index=False)
        valid_df.to_parquet(fold_output_dir / "valid.parquet", index=False)
        test_df.to_parquet(fold_output_dir / "test.parquet", index=False)

        logger.success(f" Saved to {fold_output_dir}")

    # Record the feature-column order alongside the folds, one name per line.
    cols_path = output_dir / "feature_columns.txt"
    with open(cols_path, "w") as f:
        f.write("\n".join(feature_cols))
    logger.success(f"Saved feature columns to {cols_path}")

    logger.info("\n" + "=" * 60)
    logger.info("CV DATA PROCESSING COMPLETE")
    logger.info("=" * 60)
    logger.info(f"Output directory: {output_dir}")
    logger.info(f"Number of folds: {len(cv_dirs)}")


if __name__ == "__main__":
    app()