Compare commits

..

No commits in common. "ac4246c2b7e17ee5a91f5f6d2c5551795f448dc1" and "e123fc8f3ebcb57d086a7d382afcf986c372257e" have entirely different histories.

54 changed files with 401 additions and 4073 deletions

View File

@ -74,11 +74,6 @@ data_pretrain: requirements
$(PYTHON_INTERPRETER) scripts/process_external.py
## Process CV data for cross-validation pretrain (external/all_amine_split_for_LiON -> processed/cv)
.PHONY: data_pretrain_cv
data_pretrain_cv: requirements
$(PYTHON_INTERPRETER) scripts/process_external_cv.py
## Process internal data with amine-based CV splitting (interim -> processed/cv)
.PHONY: data_cv
data_cv: requirements
$(PYTHON_INTERPRETER) scripts/process_data_cv.py
@ -111,8 +106,8 @@ pretrain_cv: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv main $(MPNN_FLAG) $(DEVICE_FLAG)
## Evaluate CV pretrain models on test sets (auto-detects MPNN from checkpoint)
.PHONY: test_pretrain_cv
test_pretrain_cv: requirements
.PHONY: test_cv
test_cv: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv test $(DEVICE_FLAG)
## Train model (multi-task, from scratch)
@ -125,22 +120,6 @@ train: requirements
finetune: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --init-from-pretrain models/pretrain_delivery.pt $(FREEZE_FLAG) $(MPNN_FLAG) $(DEVICE_FLAG)
## Finetune with cross-validation on internal data (5-fold, amine-based split) with pretrained weights
.PHONY: finetune_cv
finetune_cv: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train_cv main --init-from-pretrain models/pretrain_delivery.pt $(FREEZE_FLAG) $(MPNN_FLAG) $(DEVICE_FLAG)
## Train with cross-validation on internal data only (5-fold, amine-based split)
.PHONY: train_cv
train_cv: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train_cv main $(FREEZE_FLAG) $(MPNN_FLAG) $(DEVICE_FLAG)
## Evaluate CV finetuned models on test sets (auto-detects MPNN from checkpoint)
.PHONY: test_cv
test_cv: requirements
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train_cv test $(DEVICE_FLAG)
## Train with hyperparameter tuning
.PHONY: tune
tune: requirements

View File

@ -1,16 +1,9 @@
# Feature columns configuration
# SMILES
smiles
# comp token [5]
Cationic_Lipid_to_mRNA_weight_ratio
Cationic_Lipid_Mol_Ratio
Phospholipid_Mol_Ratio
Cholesterol_Mol_Ratio
PEG_Lipid_Mol_Ratio
# phys token [12]
Purity_Pure
Purity_Crude
Mix_type_Microfluidic
@ -23,14 +16,10 @@ Target_or_delivered_gene_Peptide_barcode
Target_or_delivered_gene_hEPO
Target_or_delivered_gene_FVII
Target_or_delivered_gene_GFP
# help token [4]
Helper_lipid_ID_DOPE
Helper_lipid_ID_DOTAP
Helper_lipid_ID_DSPC
Helper_lipid_ID_MDOA
# exp token [32]
Model_type_A549
Model_type_BDMC
Model_type_BMDM
@ -63,27 +52,4 @@ Value_name_hEPO
Value_name_FVII_silencing
Value_name_GFP_delivery
Value_name_Discretized_luminescence
# Targets
## Regression
size
quantified_delivery
## PDI classification
PDI_0_0to0_2
PDI_0_2to0_3
PDI_0_3to0_4
PDI_0_4to0_5
## EE classification
Encapsulation_Efficiency_EE<50
Encapsulation_Efficiency_50<=EE<80
Encapsulation_Efficiency_80<EE<=100
## Toxic
toxic
## Biodistribution
Biodistribution_lymph_nodes
Biodistribution_heart
Biodistribution_liver
Biodistribution_spleen
Biodistribution_lung
Biodistribution_kidney
Biodistribution_muscle
quantified_delivery

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,55 +0,0 @@
smiles
Cationic_Lipid_to_mRNA_weight_ratio
Cationic_Lipid_Mol_Ratio
Phospholipid_Mol_Ratio
Cholesterol_Mol_Ratio
PEG_Lipid_Mol_Ratio
Purity_Pure
Purity_Crude
Mix_type_Microfluidic
Mix_type_Pipetting
Cargo_type_mRNA
Cargo_type_pDNA
Cargo_type_siRNA
Target_or_delivered_gene_FFL
Target_or_delivered_gene_Peptide_barcode
Target_or_delivered_gene_hEPO
Target_or_delivered_gene_FVII
Target_or_delivered_gene_GFP
Helper_lipid_ID_DOPE
Helper_lipid_ID_DOTAP
Helper_lipid_ID_DSPC
Helper_lipid_ID_MDOA
Model_type_A549
Model_type_BDMC
Model_type_BMDM
Model_type_HBEC_ALI
Model_type_HEK293T
Model_type_HeLa
Model_type_IGROV1
Model_type_Mouse
Model_type_RAW264p7
Delivery_target_body
Delivery_target_dendritic_cell
Delivery_target_generic_cell
Delivery_target_liver
Delivery_target_lung
Delivery_target_lung_epithelium
Delivery_target_macrophage
Delivery_target_muscle
Delivery_target_spleen
Route_of_administration_in_vitro
Route_of_administration_intramuscular
Route_of_administration_intratracheal
Route_of_administration_intravenous
Batch_or_individual_or_barcoded_Barcoded
Batch_or_individual_or_barcoded_Individual
Value_name_log_luminescence
Value_name_luminescence
Value_name_FFL_silencing
Value_name_Peptide_abundance
Value_name_hEPO
Value_name_FVII_silencing
Value_name_GFP_delivery
Value_name_Discretized_luminescence
quantified_delivery

View File

@ -1,51 +1,37 @@
{
"loss_metrics": {
"loss": 2.5374555587768555,
"loss_size": 0.1886825958887736,
"loss_pdi": 0.45798932512601215,
"loss_ee": 0.829658567905426,
"loss_delivery": 0.4857304096221924,
"loss_biodist": 0.5346279243628184,
"loss_toxic": 0.04076674363265435,
"acc_pdi": 0.7862595419847328,
"acc_ee": 0.6793893129770993,
"acc_toxic": 0.9801980198019802
"loss": 2.8661977450052896,
"loss_size": 0.44916408757368725,
"loss_pdi": 0.5041926403840383,
"loss_ee": 0.9021427234013876,
"loss_delivery": 0.5761533578236898,
"loss_biodist": 0.4019051690896352,
"loss_toxic": 0.03263980595511384,
"acc_pdi": 0.7633587786259542,
"acc_ee": 0.6641221374045801,
"acc_toxic": 0.9702970297029703
},
"detailed_metrics": {
"size": {
"mse": 0.1669999969286325,
"rmse": 0.4086563310761654,
"mae": 0.26111859684375066,
"r2": 0.2149270281561566
"mse": 0.41126506251447736,
"rmse": 0.6412995107704959,
"mae": 0.41415552388095633,
"r2": -0.9333718010891026
},
"delivery": {
"mse": 0.5193460523366603,
"rmse": 0.7206566813238189,
"mae": 0.4828052782115008,
"r2": 0.37299826459145
"mse": 0.6277965050686476,
"rmse": 0.7923361061245711,
"mae": 0.5387302115022443,
"r2": 0.24206702565575944
},
"pdi": {
"accuracy": 0.7862595419847328,
"precision": 0.7282763532763532,
"recall": 0.6907738095238095,
"f1": 0.7041935483870968
"accuracy": 0.7633587786259542
},
"ee": {
"accuracy": 0.6793893129770993,
"precision": 0.612247574088644,
"recall": 0.6062951496388029,
"f1": 0.6069449904342585
"accuracy": 0.6641221374045801
},
"toxic": {
"accuracy": 0.9801980198019802,
"precision": 0.5,
"recall": 0.4900990099009901,
"f1": 0.495
},
"biodist": {
"n_samples": 101,
"kl_divergence": 0.2931957937514963,
"js_divergence": 0.07706768601895059
"accuracy": 0.9702970297029703
}
}
}

View File

@ -217,31 +217,15 @@ def test(
"""
import json
import numpy as np
from scipy.special import rel_entr
from sklearn.metrics import (
mean_squared_error,
mean_absolute_error,
r2_score,
accuracy_score,
precision_score,
recall_score,
f1_score,
classification_report,
)
from lnp_ml.modeling.trainer import validate
def kl_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-10) -> float:
"""计算 KL 散度 KL(p || q)"""
p = np.clip(p, eps, 1.0)
q = np.clip(q, eps, 1.0)
return float(np.sum(rel_entr(p, q), axis=-1).mean())
def js_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-10) -> float:
"""计算 JS 散度"""
p = np.clip(p, eps, 1.0)
q = np.clip(q, eps, 1.0)
m = 0.5 * (p + q)
return float(0.5 * (np.sum(rel_entr(p, m), axis=-1) + np.sum(rel_entr(q, m), axis=-1)).mean())
logger.info(f"Using device: {device}")
device_obj = torch.device(device)
@ -303,9 +287,6 @@ def test(
y_pred = np.array(predictions["pdi"])[mask]
results["detailed_metrics"]["pdi"] = {
"accuracy": float(accuracy_score(y_true, y_pred)),
"precision": float(precision_score(y_true, y_pred, average="macro", zero_division=0)),
"recall": float(recall_score(y_true, y_pred, average="macro", zero_division=0)),
"f1": float(f1_score(y_true, y_pred, average="macro", zero_division=0)),
}
# 分类指标EE
@ -318,9 +299,6 @@ def test(
y_pred = np.array(predictions["ee"])[mask]
results["detailed_metrics"]["ee"] = {
"accuracy": float(accuracy_score(y_true, y_pred)),
"precision": float(precision_score(y_true, y_pred, average="macro", zero_division=0)),
"recall": float(recall_score(y_true, y_pred, average="macro", zero_division=0)),
"f1": float(f1_score(y_true, y_pred, average="macro", zero_division=0)),
}
# 分类指标toxic
@ -331,28 +309,6 @@ def test(
y_pred = np.array(predictions["toxic"])[mask.values]
results["detailed_metrics"]["toxic"] = {
"accuracy": float(accuracy_score(y_true, y_pred)),
"precision": float(precision_score(y_true, y_pred, average="macro", zero_division=0)),
"recall": float(recall_score(y_true, y_pred, average="macro", zero_division=0)),
"f1": float(f1_score(y_true, y_pred, average="macro", zero_division=0)),
}
# 分布指标biodist
biodist_cols = [
"Biodistribution_lymph_nodes", "Biodistribution_heart", "Biodistribution_liver",
"Biodistribution_spleen", "Biodistribution_lung", "Biodistribution_kidney", "Biodistribution_muscle"
]
if all(c in test_df.columns for c in biodist_cols):
biodist_true = test_df[biodist_cols].values
biodist_pred = np.array(predictions["biodist"])
# mask: 有效样本是 sum > 0 且无 NaN
mask = (biodist_true.sum(axis=1) > 0) & (~np.isnan(biodist_true).any(axis=1))
if mask.any():
y_true = biodist_true[mask]
y_pred = biodist_pred[mask]
results["detailed_metrics"]["biodist"] = {
"n_samples": int(mask.sum()),
"kl_divergence": kl_divergence(y_true, y_pred),
"js_divergence": js_divergence(y_true, y_pred),
}
# 打印结果

View File

@ -271,7 +271,7 @@ def create_model(
@app.command()
def main(
data_dir: Path = PROCESSED_DATA_DIR / "pretrain_cv",
data_dir: Path = PROCESSED_DATA_DIR / "cv",
output_dir: Path = MODELS_DIR / "pretrain_cv",
# 模型参数
d_model: int = 256,
@ -322,7 +322,7 @@ def main(
if not fold_dirs:
logger.error(f"No fold_* directories found in {data_dir}")
logger.info("Please run 'make data_pretrain_cv' first to process CV data.")
logger.info("Please run 'make data_cv' first to process CV data.")
raise typer.Exit(1)
logger.info(f"Found {len(fold_dirs)} folds: {[d.name for d in fold_dirs]}")
@ -464,7 +464,7 @@ def main(
@app.command()
def test(
data_dir: Path = PROCESSED_DATA_DIR / "pretrain_cv",
data_dir: Path = PROCESSED_DATA_DIR / "cv",
model_dir: Path = MODELS_DIR / "pretrain_cv",
output_path: Path = MODELS_DIR / "pretrain_cv" / "test_results.json",
batch_size: int = 64,

View File

@ -1,720 +0,0 @@
"""Cross-Validation 训练脚本:在 5-fold 内部数据上进行多任务训练"""
import json
from pathlib import Path
from typing import Dict, List, Optional, Union
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from loguru import logger
from tqdm import tqdm
import typer
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import LNPDataset, collate_fn
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
from lnp_ml.modeling.trainer import (
train_epoch,
validate,
EarlyStopping,
LossWeights,
)
# MPNN ensemble 默认路径
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
"""自动查找 MPNN ensemble 的 model.pt 文件。"""
model_paths = sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt"))
if not model_paths:
raise FileNotFoundError(f"No model.pt files found in {base_dir}")
return [str(p) for p in model_paths]
app = typer.Typer()
def create_model(
d_model: int = 256,
num_heads: int = 8,
n_attn_layers: int = 4,
fusion_strategy: str = "attention",
head_hidden_dim: int = 128,
dropout: float = 0.1,
mpnn_checkpoint: Optional[str] = None,
mpnn_ensemble_paths: Optional[List[str]] = None,
mpnn_device: str = "cpu",
) -> Union[LNPModel, LNPModelWithoutMPNN]:
"""创建模型(支持可选的 MPNN encoder"""
use_mpnn = mpnn_checkpoint is not None or mpnn_ensemble_paths is not None
if use_mpnn:
return LNPModel(
d_model=d_model,
num_heads=num_heads,
n_attn_layers=n_attn_layers,
fusion_strategy=fusion_strategy,
head_hidden_dim=head_hidden_dim,
dropout=dropout,
mpnn_checkpoint=mpnn_checkpoint,
mpnn_ensemble_paths=mpnn_ensemble_paths,
mpnn_device=mpnn_device,
)
else:
return LNPModelWithoutMPNN(
d_model=d_model,
num_heads=num_heads,
n_attn_layers=n_attn_layers,
fusion_strategy=fusion_strategy,
head_hidden_dim=head_hidden_dim,
dropout=dropout,
)
def train_fold(
fold_idx: int,
train_loader: DataLoader,
val_loader: DataLoader,
model: nn.Module,
device: torch.device,
output_dir: Path,
lr: float = 1e-4,
weight_decay: float = 1e-5,
epochs: int = 100,
patience: int = 15,
loss_weights: Optional[LossWeights] = None,
config: Optional[Dict] = None,
) -> Dict:
"""训练单个 fold"""
logger.info(f"\n{'='*60}")
logger.info(f"Training Fold {fold_idx}")
logger.info(f"{'='*60}")
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", factor=0.5, patience=5
)
early_stopping = EarlyStopping(patience=patience)
history = {"train": [], "val": []}
best_val_loss = float("inf")
best_state = None
for epoch in range(epochs):
# Train
train_metrics = train_epoch(model, train_loader, optimizer, device, loss_weights)
# Validate
val_metrics = validate(model, val_loader, device, loss_weights)
current_lr = optimizer.param_groups[0]["lr"]
# Log
logger.info(
f"Fold {fold_idx} Epoch {epoch+1}/{epochs} | "
f"Train Loss: {train_metrics['loss']:.4f} | "
f"Val Loss: {val_metrics['loss']:.4f} | "
f"LR: {current_lr:.2e}"
)
history["train"].append(train_metrics)
history["val"].append(val_metrics)
# Learning rate scheduling
scheduler.step(val_metrics["loss"])
# Save best model
if val_metrics["loss"] < best_val_loss:
best_val_loss = val_metrics["loss"]
best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
logger.info(f" -> New best model (val_loss={best_val_loss:.4f})")
# Early stopping
if early_stopping(val_metrics["loss"]):
logger.info(f"Early stopping at epoch {epoch+1}")
break
# 保存最佳模型
fold_output_dir = output_dir / f"fold_{fold_idx}"
fold_output_dir.mkdir(parents=True, exist_ok=True)
checkpoint_path = fold_output_dir / "model.pt"
torch.save({
"model_state_dict": best_state,
"config": config,
"best_val_loss": best_val_loss,
"fold_idx": fold_idx,
}, checkpoint_path)
logger.success(f"Saved fold {fold_idx} model to {checkpoint_path}")
# 保存训练历史
history_path = fold_output_dir / "history.json"
with open(history_path, "w") as f:
json.dump(history, f, indent=2)
return {
"fold_idx": fold_idx,
"best_val_loss": best_val_loss,
"epochs_trained": len(history["train"]),
"final_train_loss": history["train"][-1]["loss"] if history["train"] else 0,
}
@app.command()
def main(
data_dir: Path = PROCESSED_DATA_DIR / "cv",
output_dir: Path = MODELS_DIR / "finetune_cv",
# 模型参数
d_model: int = 256,
num_heads: int = 8,
n_attn_layers: int = 4,
fusion_strategy: str = "attention",
head_hidden_dim: int = 128,
dropout: float = 0.1,
# MPNN 参数(可选)
use_mpnn: bool = False,
mpnn_checkpoint: Optional[str] = None,
mpnn_ensemble_paths: Optional[str] = None,
mpnn_device: str = "cpu",
# 训练参数
batch_size: int = 32,
lr: float = 1e-4,
weight_decay: float = 1e-5,
epochs: int = 100,
patience: int = 15,
# 预训练权重加载
init_from_pretrain: Optional[Path] = None,
load_delivery_head: bool = True,
freeze_backbone: bool = False,
# 设备
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
"""
基于 Cross-Validation 训练 LNP 模型多任务
5-fold 内部数据上训练 5 个模型
使用 --use-mpnn 启用 MPNN encoder
使用 --init-from-pretrain 从预训练 checkpoint 初始化
使用 --freeze-backbone 冻结 backbone只训练 heads
"""
logger.info(f"Using device: {device}")
device = torch.device(device)
# 查找所有 fold 目录
fold_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("fold_")])
if not fold_dirs:
logger.error(f"No fold_* directories found in {data_dir}")
logger.info("Please run 'make data_cv' first to process CV data.")
raise typer.Exit(1)
logger.info(f"Found {len(fold_dirs)} folds: {[d.name for d in fold_dirs]}")
output_dir.mkdir(parents=True, exist_ok=True)
# 解析 MPNN 配置
ensemble_paths_list = None
if mpnn_ensemble_paths:
ensemble_paths_list = mpnn_ensemble_paths.split(",")
elif use_mpnn and mpnn_checkpoint is None:
logger.info(f"Auto-detecting MPNN ensemble from {DEFAULT_MPNN_ENSEMBLE_DIR}")
ensemble_paths_list = find_mpnn_ensemble_paths()
logger.info(f"Found {len(ensemble_paths_list)} MPNN models")
enable_mpnn = mpnn_checkpoint is not None or ensemble_paths_list is not None
# 模型配置
config = {
"d_model": d_model,
"num_heads": num_heads,
"n_attn_layers": n_attn_layers,
"fusion_strategy": fusion_strategy,
"head_hidden_dim": head_hidden_dim,
"dropout": dropout,
"use_mpnn": enable_mpnn,
"lr": lr,
"weight_decay": weight_decay,
"batch_size": batch_size,
"epochs": epochs,
"patience": patience,
"init_from_pretrain": str(init_from_pretrain) if init_from_pretrain else None,
"freeze_backbone": freeze_backbone,
}
# 保存配置
config_path = output_dir / "config.json"
with open(config_path, "w") as f:
json.dump(config, f, indent=2)
logger.info(f"Saved config to {config_path}")
# 加载预训练权重(如果指定)
pretrain_state = None
if init_from_pretrain is not None:
logger.info(f"Loading pretrain weights from {init_from_pretrain}")
checkpoint = torch.load(init_from_pretrain, map_location="cpu")
pretrain_config = checkpoint.get("config", {})
if pretrain_config.get("d_model") != d_model:
logger.warning(
f"d_model mismatch: pretrain={pretrain_config.get('d_model')}, "
f"current={d_model}. Skipping pretrain loading."
)
else:
pretrain_state = checkpoint["model_state_dict"]
# 训练每个 fold
fold_results = []
for fold_dir in tqdm(fold_dirs, desc="Training folds"):
fold_idx = int(fold_dir.name.split("_")[1])
# 加载数据
train_df = pd.read_parquet(fold_dir / "train.parquet")
val_df = pd.read_parquet(fold_dir / "val.parquet")
logger.info(f"\nFold {fold_idx}: train={len(train_df)}, val={len(val_df)}")
# 创建 Dataset 和 DataLoader
train_dataset = LNPDataset(train_df)
val_dataset = LNPDataset(val_df)
train_loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
)
val_loader = DataLoader(
val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
)
# 创建新模型(每个 fold 独立初始化)
model = create_model(
d_model=d_model,
num_heads=num_heads,
n_attn_layers=n_attn_layers,
fusion_strategy=fusion_strategy,
head_hidden_dim=head_hidden_dim,
dropout=dropout,
mpnn_checkpoint=mpnn_checkpoint,
mpnn_ensemble_paths=ensemble_paths_list,
mpnn_device=device.type,
)
# 加载预训练权重
if pretrain_state is not None:
model.load_pretrain_weights(
pretrain_state_dict=pretrain_state,
load_delivery_head=load_delivery_head,
strict=False,
)
logger.info(f"Loaded pretrain weights (backbone + delivery_head={load_delivery_head})")
# 冻结 backbone如果指定
if freeze_backbone:
frozen_count = 0
for name, param in model.named_parameters():
if name.startswith(("token_projector.", "cross_attention.", "fusion.")):
param.requires_grad = False
frozen_count += 1
logger.info(f"Frozen {frozen_count} parameter tensors")
# 打印模型信息(仅第一个 fold
if fold_idx == 0:
n_params_total = sum(p.numel() for p in model.parameters())
n_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"Model parameters: {n_params_total:,} total, {n_params_trainable:,} trainable")
# 训练
result = train_fold(
fold_idx=fold_idx,
train_loader=train_loader,
val_loader=val_loader,
model=model,
device=device,
output_dir=output_dir,
lr=lr,
weight_decay=weight_decay,
epochs=epochs,
patience=patience,
config=config,
)
fold_results.append(result)
# 汇总结果
logger.info("\n" + "=" * 60)
logger.info("CROSS-VALIDATION TRAINING COMPLETE")
logger.info("=" * 60)
val_losses = [r["best_val_loss"] for r in fold_results]
logger.info(f"\n[Per-Fold Results]")
for r in fold_results:
logger.info(
f" Fold {r['fold_idx']}: "
f"Val Loss={r['best_val_loss']:.4f}, "
f"Epochs={r['epochs_trained']}"
)
logger.info(f"\n[Summary Statistics]")
logger.info(f" Val Loss: {np.mean(val_losses):.4f} ± {np.std(val_losses):.4f}")
# 保存 CV 结果
cv_results = {
"fold_results": fold_results,
"summary": {
"val_loss_mean": float(np.mean(val_losses)),
"val_loss_std": float(np.std(val_losses)),
},
"config": config,
}
results_path = output_dir / "cv_results.json"
with open(results_path, "w") as f:
json.dump(cv_results, f, indent=2)
logger.success(f"Saved CV results to {results_path}")
@app.command()
def test(
data_dir: Path = PROCESSED_DATA_DIR / "cv",
model_dir: Path = MODELS_DIR / "finetune_cv",
output_path: Path = MODELS_DIR / "finetune_cv" / "test_results.json",
batch_size: int = 64,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
"""
在测试集上评估 CV 训练的模型
使用每个 fold 的模型在对应的测试集上评估然后汇总结果
"""
from scipy.special import rel_entr
from sklearn.metrics import (
mean_squared_error,
mean_absolute_error,
r2_score,
accuracy_score,
precision_score,
recall_score,
f1_score,
)
def kl_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-10) -> float:
"""计算 KL 散度 KL(p || q)"""
p = np.clip(p, eps, 1.0)
q = np.clip(q, eps, 1.0)
return float(np.sum(rel_entr(p, q), axis=-1).mean())
def js_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-10) -> float:
"""计算 JS 散度"""
p = np.clip(p, eps, 1.0)
q = np.clip(q, eps, 1.0)
m = 0.5 * (p + q)
return float(0.5 * (np.sum(rel_entr(p, m), axis=-1) + np.sum(rel_entr(q, m), axis=-1)).mean())
logger.info(f"Using device: {device}")
device = torch.device(device)
# 查找所有 fold 目录
fold_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("fold_")])
if not fold_dirs:
logger.error(f"No fold_* directories found in {data_dir}")
raise typer.Exit(1)
logger.info(f"Found {len(fold_dirs)} folds")
fold_results = []
# 用于汇总所有 fold 的预测
all_preds = {
"size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [], "biodist": []
}
all_targets = {
"size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [], "biodist": []
}
for fold_dir in tqdm(fold_dirs, desc="Evaluating folds"):
fold_idx = int(fold_dir.name.split("_")[1])
model_path = model_dir / f"fold_{fold_idx}" / "model.pt"
test_path = fold_dir / "test.parquet"
if not model_path.exists():
logger.warning(f"Fold {fold_idx}: model not found at {model_path}, skipping")
continue
if not test_path.exists():
logger.warning(f"Fold {fold_idx}: test data not found at {test_path}, skipping")
continue
# 加载模型
checkpoint = torch.load(model_path, map_location=device)
config = checkpoint["config"]
use_mpnn = config.get("use_mpnn", False)
# 总是重新查找 MPNN 路径
if use_mpnn:
mpnn_paths = find_mpnn_ensemble_paths()
else:
mpnn_paths = None
model = create_model(
d_model=config["d_model"],
num_heads=config["num_heads"],
n_attn_layers=config["n_attn_layers"],
fusion_strategy=config["fusion_strategy"],
head_hidden_dim=config["head_hidden_dim"],
dropout=config["dropout"],
mpnn_ensemble_paths=mpnn_paths,
mpnn_device=device.type,
)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(device)
model.eval()
# 加载测试数据
test_df = pd.read_parquet(test_path)
test_dataset = LNPDataset(test_df)
test_loader = DataLoader(
test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
)
# 收集当前 fold 的预测
fold_preds = {k: [] for k in all_preds.keys()}
fold_targets = {k: [] for k in all_targets.keys()}
with torch.no_grad():
pbar = tqdm(test_loader, desc=f"Fold {fold_idx} [Test]", leave=False)
for batch in pbar:
smiles = batch["smiles"]
tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
targets = batch["targets"]
masks = batch["mask"]
outputs = model(smiles, tabular)
# Size
if "size" in masks and masks["size"].any():
mask = masks["size"]
fold_preds["size"].extend(
outputs["size"].squeeze(-1)[mask].cpu().numpy().tolist()
)
fold_targets["size"].extend(
targets["size"][mask].cpu().numpy().tolist()
)
# Delivery
if "delivery" in masks and masks["delivery"].any():
mask = masks["delivery"]
fold_preds["delivery"].extend(
outputs["delivery"].squeeze(-1)[mask].cpu().numpy().tolist()
)
fold_targets["delivery"].extend(
targets["delivery"][mask].cpu().numpy().tolist()
)
# PDI (classification)
if "pdi" in masks and masks["pdi"].any():
mask = masks["pdi"]
pdi_preds = outputs["pdi"][mask].argmax(dim=-1).cpu().numpy()
pdi_targets = targets["pdi"][mask].cpu().numpy()
fold_preds["pdi"].extend(pdi_preds.tolist())
fold_targets["pdi"].extend(pdi_targets.tolist())
# EE (classification)
if "ee" in masks and masks["ee"].any():
mask = masks["ee"]
ee_preds = outputs["ee"][mask].argmax(dim=-1).cpu().numpy()
ee_targets = targets["ee"][mask].cpu().numpy()
fold_preds["ee"].extend(ee_preds.tolist())
fold_targets["ee"].extend(ee_targets.tolist())
# Toxic (classification)
if "toxic" in masks and masks["toxic"].any():
mask = masks["toxic"]
toxic_preds = outputs["toxic"][mask].argmax(dim=-1).cpu().numpy()
toxic_targets = targets["toxic"][mask].cpu().numpy().astype(int)
fold_preds["toxic"].extend(toxic_preds.tolist())
fold_targets["toxic"].extend(toxic_targets.tolist())
# Biodist (distribution)
if "biodist" in masks and masks["biodist"].any():
mask = masks["biodist"]
biodist_preds = outputs["biodist"][mask].cpu().numpy()
biodist_targets = targets["biodist"][mask].cpu().numpy()
fold_preds["biodist"].extend(biodist_preds.tolist())
fold_targets["biodist"].extend(biodist_targets.tolist())
# 计算当前 fold 的指标
fold_metrics = {"fold_idx": fold_idx, "n_samples": len(test_df)}
# 回归任务指标
for task in ["size", "delivery"]:
if fold_preds[task]:
p = np.array(fold_preds[task])
t = np.array(fold_targets[task])
fold_metrics[task] = {
"n": len(p),
"rmse": float(np.sqrt(mean_squared_error(t, p))),
"mae": float(mean_absolute_error(t, p)),
"r2": float(r2_score(t, p)),
}
# 分类任务指标
for task in ["pdi", "ee", "toxic"]:
if fold_preds[task]:
p = np.array(fold_preds[task])
t = np.array(fold_targets[task])
fold_metrics[task] = {
"n": len(p),
"accuracy": float(accuracy_score(t, p)),
"precision": float(precision_score(t, p, average="macro", zero_division=0)),
"recall": float(recall_score(t, p, average="macro", zero_division=0)),
"f1": float(f1_score(t, p, average="macro", zero_division=0)),
}
# 分布任务指标
if fold_preds["biodist"]:
p = np.array(fold_preds["biodist"])
t = np.array(fold_targets["biodist"])
fold_metrics["biodist"] = {
"n": len(p),
"kl_divergence": kl_divergence(t, p),
"js_divergence": js_divergence(t, p),
}
fold_results.append(fold_metrics)
# 汇总到全局
for task in all_preds.keys():
all_preds[task].extend(fold_preds[task])
all_targets[task].extend(fold_targets[task])
# 打印当前 fold 结果
log_parts = [f"Fold {fold_idx}: n={len(test_df)}"]
for task in ["delivery", "size"]:
if task in fold_metrics and isinstance(fold_metrics[task], dict):
log_parts.append(f"{task}_RMSE={fold_metrics[task]['rmse']:.4f}")
log_parts.append(f"{task}_R²={fold_metrics[task]['r2']:.4f}")
for task in ["pdi", "ee", "toxic"]:
if task in fold_metrics and isinstance(fold_metrics[task], dict):
log_parts.append(f"{task}_acc={fold_metrics[task]['accuracy']:.4f}")
log_parts.append(f"{task}_f1={fold_metrics[task]['f1']:.4f}")
if "biodist" in fold_metrics and isinstance(fold_metrics["biodist"], dict):
log_parts.append(f"biodist_KL={fold_metrics['biodist']['kl_divergence']:.4f}")
log_parts.append(f"biodist_JS={fold_metrics['biodist']['js_divergence']:.4f}")
logger.info(", ".join(log_parts))
# 计算跨 fold 汇总统计
summary_stats = {}
for task in ["size", "delivery"]:
rmses = [r[task]["rmse"] for r in fold_results if task in r and isinstance(r[task], dict)]
r2s = [r[task]["r2"] for r in fold_results if task in r and isinstance(r[task], dict)]
if rmses:
summary_stats[task] = {
"rmse_mean": float(np.mean(rmses)),
"rmse_std": float(np.std(rmses)),
"r2_mean": float(np.mean(r2s)),
"r2_std": float(np.std(r2s)),
}
for task in ["pdi", "ee", "toxic"]:
accs = [r[task]["accuracy"] for r in fold_results if task in r and isinstance(r[task], dict)]
f1s = [r[task]["f1"] for r in fold_results if task in r and isinstance(r[task], dict)]
if accs:
summary_stats[task] = {
"accuracy_mean": float(np.mean(accs)),
"accuracy_std": float(np.std(accs)),
"f1_mean": float(np.mean(f1s)),
"f1_std": float(np.std(f1s)),
}
# 分布任务汇总
kls = [r["biodist"]["kl_divergence"] for r in fold_results if "biodist" in r and isinstance(r["biodist"], dict)]
jss = [r["biodist"]["js_divergence"] for r in fold_results if "biodist" in r and isinstance(r["biodist"], dict)]
if kls:
summary_stats["biodist"] = {
"kl_mean": float(np.mean(kls)),
"kl_std": float(np.std(kls)),
"js_mean": float(np.mean(jss)),
"js_std": float(np.std(jss)),
}
# 计算整体 pooled 指标
overall = {}
for task in ["size", "delivery"]:
if all_preds[task]:
p = np.array(all_preds[task])
t = np.array(all_targets[task])
overall[task] = {
"n_samples": len(p),
"mse": float(mean_squared_error(t, p)),
"rmse": float(np.sqrt(mean_squared_error(t, p))),
"mae": float(mean_absolute_error(t, p)),
"r2": float(r2_score(t, p)),
}
for task in ["pdi", "ee", "toxic"]:
if all_preds[task]:
p = np.array(all_preds[task])
t = np.array(all_targets[task])
overall[task] = {
"n_samples": len(p),
"accuracy": float(accuracy_score(t, p)),
"precision": float(precision_score(t, p, average="macro", zero_division=0)),
"recall": float(recall_score(t, p, average="macro", zero_division=0)),
"f1": float(f1_score(t, p, average="macro", zero_division=0)),
}
# 分布任务
if all_preds["biodist"]:
p = np.array(all_preds["biodist"])
t = np.array(all_targets["biodist"])
overall["biodist"] = {
"n_samples": len(p),
"kl_divergence": kl_divergence(t, p),
"js_divergence": js_divergence(t, p),
}
# 打印汇总结果
logger.info("\n" + "=" * 60)
logger.info("CV TEST EVALUATION RESULTS")
logger.info("=" * 60)
logger.info(f"\n[Summary Statistics (across {len(fold_results)} folds)]")
for task, stats in summary_stats.items():
if "rmse_mean" in stats:
logger.info(f" {task}: RMSE={stats['rmse_mean']:.4f}±{stats['rmse_std']:.4f}, R²={stats['r2_mean']:.4f}±{stats['r2_std']:.4f}")
elif "accuracy_mean" in stats:
logger.info(f" {task}: Accuracy={stats['accuracy_mean']:.4f}±{stats['accuracy_std']:.4f}, F1={stats['f1_mean']:.4f}±{stats['f1_std']:.4f}")
elif "kl_mean" in stats:
logger.info(f" {task}: KL={stats['kl_mean']:.4f}±{stats['kl_std']:.4f}, JS={stats['js_mean']:.4f}±{stats['js_std']:.4f}")
logger.info(f"\n[Overall (all samples pooled)]")
for task, metrics in overall.items():
if "rmse" in metrics:
logger.info(f" {task} (n={metrics['n_samples']}): RMSE={metrics['rmse']:.4f}, MAE={metrics['mae']:.4f}, R²={metrics['r2']:.4f}")
elif "accuracy" in metrics:
logger.info(f" {task} (n={metrics['n_samples']}): Accuracy={metrics['accuracy']:.4f}, Precision={metrics['precision']:.4f}, Recall={metrics['recall']:.4f}, F1={metrics['f1']:.4f}")
elif "kl_divergence" in metrics:
logger.info(f" {task} (n={metrics['n_samples']}): KL={metrics['kl_divergence']:.4f}, JS={metrics['js_divergence']:.4f}")
# 保存结果
results = {
"fold_results": fold_results,
"summary_stats": summary_stats,
"overall": overall,
}
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
json.dump(results, f, indent=2)
logger.success(f"\nSaved test results to {output_path}")
if __name__ == "__main__":
app()

View File

@ -1,16 +0,0 @@
{
"d_model": 256,
"num_heads": 8,
"n_attn_layers": 4,
"fusion_strategy": "attention",
"head_hidden_dim": 128,
"dropout": 0.1,
"use_mpnn": true,
"lr": 0.0001,
"weight_decay": 1e-05,
"batch_size": 32,
"epochs": 100,
"patience": 15,
"init_from_pretrain": null,
"freeze_backbone": false
}

View File

@ -1,54 +0,0 @@
{
"fold_results": [
{
"fold_idx": 0,
"best_val_loss": 5.7676777839660645,
"epochs_trained": 24,
"final_train_loss": 1.4942118644714355
},
{
"fold_idx": 1,
"best_val_loss": 8.418675899505615,
"epochs_trained": 20,
"final_train_loss": 1.4902493238449097
},
{
"fold_idx": 2,
"best_val_loss": 3.5122547830854143,
"epochs_trained": 25,
"final_train_loss": 1.7609570423762004
},
{
"fold_idx": 3,
"best_val_loss": 3.165306806564331,
"epochs_trained": 21,
"final_train_loss": 2.0073827385902403
},
{
"fold_idx": 4,
"best_val_loss": 2.996154228846232,
"epochs_trained": 18,
"final_train_loss": 1.9732873006300493
}
],
"summary": {
"val_loss_mean": 4.772013900393532,
"val_loss_std": 2.0790222989111475
},
"config": {
"d_model": 256,
"num_heads": 8,
"n_attn_layers": 4,
"fusion_strategy": "attention",
"head_hidden_dim": 128,
"dropout": 0.1,
"use_mpnn": true,
"lr": 0.0001,
"weight_decay": 1e-05,
"batch_size": 32,
"epochs": 100,
"patience": 15,
"init_from_pretrain": null,
"freeze_backbone": false
}
}

View File

@ -1,510 +0,0 @@
{
"train": [
{
"loss": 17.65872812271118,
"loss_size": 12.601411867141724,
"loss_pdi": 1.3666706204414367,
"loss_ee": 1.0830313920974732,
"loss_delivery": 0.5962779104709626,
"loss_biodist": 1.3918164849281311,
"loss_toxic": 0.6195200622081757
},
{
"loss": 5.925264883041382,
"loss_size": 1.8580878481268883,
"loss_pdi": 1.1011681258678436,
"loss_ee": 0.971046245098114,
"loss_delivery": 0.5075224950909615,
"loss_biodist": 1.1051940202713013,
"loss_toxic": 0.38224617540836336
},
{
"loss": 3.4781792640686033,
"loss_size": 0.23610344529151917,
"loss_pdi": 0.8137399554252625,
"loss_ee": 0.9135127127170563,
"loss_delivery": 0.4596045270562172,
"loss_biodist": 0.8695587992668152,
"loss_toxic": 0.18565986081957817
},
{
"loss": 2.9488561868667604,
"loss_size": 0.23130029290914536,
"loss_pdi": 0.644479614496231,
"loss_ee": 0.8721524059772492,
"loss_delivery": 0.4146773874759674,
"loss_biodist": 0.646893310546875,
"loss_toxic": 0.13935319259762763
},
{
"loss": 2.6432241678237913,
"loss_size": 0.16843259893357754,
"loss_pdi": 0.5857123643159866,
"loss_ee": 0.8315786123275757,
"loss_delivery": 0.4049036353826523,
"loss_biodist": 0.5410242855548859,
"loss_toxic": 0.11157271154224872
},
{
"loss": 2.461507487297058,
"loss_size": 0.18602822050452233,
"loss_pdi": 0.5872043997049332,
"loss_ee": 0.8179578661918641,
"loss_delivery": 0.32779163047671317,
"loss_biodist": 0.45097417533397677,
"loss_toxic": 0.09155115596950054
},
{
"loss": 2.3792370796203612,
"loss_size": 0.2090120367705822,
"loss_pdi": 0.5358257800340652,
"loss_ee": 0.8088949501514435,
"loss_delivery": 0.3434994474053383,
"loss_biodist": 0.40993946194648745,
"loss_toxic": 0.07206540685147048
},
{
"loss": 2.207099366188049,
"loss_size": 0.1589151345193386,
"loss_pdi": 0.5283154606819153,
"loss_ee": 0.7723551869392395,
"loss_delivery": 0.35645291954278946,
"loss_biodist": 0.3404483631253242,
"loss_toxic": 0.05061229532584548
},
{
"loss": 2.1428971529006957,
"loss_size": 0.19335013553500174,
"loss_pdi": 0.5021985083818435,
"loss_ee": 0.7642539083957672,
"loss_delivery": 0.31821031123399734,
"loss_biodist": 0.32588216066360476,
"loss_toxic": 0.03900211993604898
},
{
"loss": 1.9874909400939942,
"loss_size": 0.1736245721578598,
"loss_pdi": 0.46206980347633364,
"loss_ee": 0.7373365700244904,
"loss_delivery": 0.29703493416309357,
"loss_biodist": 0.2863417714834213,
"loss_toxic": 0.031083252932876348
},
{
"loss": 1.9297520160675048,
"loss_size": 0.1635374441742897,
"loss_pdi": 0.4737923800945282,
"loss_ee": 0.7171129584312439,
"loss_delivery": 0.28808903992176055,
"loss_biodist": 0.25874830335378646,
"loss_toxic": 0.028471904620528222
},
{
"loss": 1.8647576332092286,
"loss_size": 0.14790172204375268,
"loss_pdi": 0.4427785277366638,
"loss_ee": 0.7089932143688202,
"loss_delivery": 0.30143058970570563,
"loss_biodist": 0.24234647750854493,
"loss_toxic": 0.021307120053097605
},
{
"loss": 1.7996623039245605,
"loss_size": 0.1429538145661354,
"loss_pdi": 0.45114057660102846,
"loss_ee": 0.681770408153534,
"loss_delivery": 0.2735618159174919,
"loss_biodist": 0.2338838443160057,
"loss_toxic": 0.01635184111073613
},
{
"loss": 1.7303769707679748,
"loss_size": 0.13725369721651076,
"loss_pdi": 0.43492600619792937,
"loss_ee": 0.6648448914289474,
"loss_delivery": 0.2714417055249214,
"loss_biodist": 0.20898159295320512,
"loss_toxic": 0.012929048202931882
},
{
"loss": 1.702065145969391,
"loss_size": 0.1783118523657322,
"loss_pdi": 0.4118753671646118,
"loss_ee": 0.640222480893135,
"loss_delivery": 0.2610591858625412,
"loss_biodist": 0.20058825612068176,
"loss_toxic": 0.01000797227025032
},
{
"loss": 1.6243244886398316,
"loss_size": 0.1371393844485283,
"loss_pdi": 0.3978125751018524,
"loss_ee": 0.6315451622009277,
"loss_delivery": 0.2618463449180126,
"loss_biodist": 0.18574777096509934,
"loss_toxic": 0.010233237966895103
},
{
"loss": 1.645119547843933,
"loss_size": 0.13622624576091766,
"loss_pdi": 0.4013118803501129,
"loss_ee": 0.639850401878357,
"loss_delivery": 0.2615354858338833,
"loss_biodist": 0.19717498123645782,
"loss_toxic": 0.009020529384724797
},
{
"loss": 1.5792422771453858,
"loss_size": 0.12063037976622581,
"loss_pdi": 0.40477685928344725,
"loss_ee": 0.6168571084737777,
"loss_delivery": 0.23877703920006751,
"loss_biodist": 0.1887524366378784,
"loss_toxic": 0.009448455832898616
},
{
"loss": 1.5701380014419555,
"loss_size": 0.12370488420128822,
"loss_pdi": 0.3944096490740776,
"loss_ee": 0.6204680263996124,
"loss_delivery": 0.2499392546713352,
"loss_biodist": 0.1741167649626732,
"loss_toxic": 0.00749938020016998
},
{
"loss": 1.5445807576179504,
"loss_size": 0.12085893377661705,
"loss_pdi": 0.4022176057100296,
"loss_ee": 0.6029386401176453,
"loss_delivery": 0.2460342638194561,
"loss_biodist": 0.16601160615682603,
"loss_toxic": 0.006519717467017472
},
{
"loss": 1.4764926195144654,
"loss_size": 0.11393929794430732,
"loss_pdi": 0.3614879995584488,
"loss_ee": 0.5874974340200424,
"loss_delivery": 0.2382828861474991,
"loss_biodist": 0.168075630068779,
"loss_toxic": 0.00720936032012105
},
{
"loss": 1.4663256525993347,
"loss_size": 0.10480817258358002,
"loss_pdi": 0.3699364930391312,
"loss_ee": 0.591068571805954,
"loss_delivery": 0.23481545299291612,
"loss_biodist": 0.1582734301686287,
"loss_toxic": 0.007423530006781221
},
{
"loss": 1.4797919273376465,
"loss_size": 0.11906521767377853,
"loss_pdi": 0.3831163257360458,
"loss_ee": 0.5810098886489868,
"loss_delivery": 0.22465722858905793,
"loss_biodist": 0.16469249799847602,
"loss_toxic": 0.007250743336044252
},
{
"loss": 1.4942118644714355,
"loss_size": 0.11249525547027588,
"loss_pdi": 0.3718418627977371,
"loss_ee": 0.5973137259483338,
"loss_delivery": 0.23963096588850022,
"loss_biodist": 0.16598810032010078,
"loss_toxic": 0.006941930414177478
}
],
"val": [
{
"loss": 13.683866500854492,
"loss_size": 5.657964706420898,
"loss_pdi": 1.1590962409973145,
"loss_ee": 1.0155898332595825,
"loss_delivery": 4.1429033279418945,
"loss_biodist": 1.128843069076538,
"loss_toxic": 0.579468846321106,
"acc_pdi": 0.7407407407407407,
"acc_ee": 0.6296296296296297,
"acc_toxic": 1.0
},
{
"loss": 7.161172866821289,
"loss_size": 0.1799931526184082,
"loss_pdi": 0.8303115963935852,
"loss_ee": 0.942605197429657,
"loss_delivery": 3.986294984817505,
"loss_biodist": 1.022797703742981,
"loss_toxic": 0.19917015731334686,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.554836273193359,
"loss_size": 0.04609166830778122,
"loss_pdi": 0.4924769997596741,
"loss_ee": 0.965587317943573,
"loss_delivery": 3.978637933731079,
"loss_biodist": 1.0135102272033691,
"loss_toxic": 0.05853228643536568,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.843129634857178,
"loss_size": 0.07650057226419449,
"loss_pdi": 0.43551138043403625,
"loss_ee": 0.9353340864181519,
"loss_delivery": 4.557775974273682,
"loss_biodist": 0.7909315228462219,
"loss_toxic": 0.047075945883989334,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.711758613586426,
"loss_size": 0.04316325858235359,
"loss_pdi": 0.41873815655708313,
"loss_ee": 1.0096691846847534,
"loss_delivery": 4.517927169799805,
"loss_biodist": 0.6788683533668518,
"loss_toxic": 0.04339226707816124,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.905030250549316,
"loss_size": 0.045318666845560074,
"loss_pdi": 0.38593801856040955,
"loss_ee": 1.0019593238830566,
"loss_delivery": 4.807835578918457,
"loss_biodist": 0.6247215867042542,
"loss_toxic": 0.039257097989320755,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5185185185185185,
"acc_toxic": 1.0
},
{
"loss": 6.417820930480957,
"loss_size": 0.05034356936812401,
"loss_pdi": 0.4149726331233978,
"loss_ee": 0.9869357943534851,
"loss_delivery": 4.405001640319824,
"loss_biodist": 0.533240556716919,
"loss_toxic": 0.02732720412313938,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.608631610870361,
"loss_size": 0.05222579464316368,
"loss_pdi": 0.4375711679458618,
"loss_ee": 1.0041171312332153,
"loss_delivery": 4.578192234039307,
"loss_biodist": 0.5125234723091125,
"loss_toxic": 0.02400212176144123,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5185185185185185,
"acc_toxic": 1.0
},
{
"loss": 5.7676777839660645,
"loss_size": 0.09589201211929321,
"loss_pdi": 0.3261733949184418,
"loss_ee": 0.9482788443565369,
"loss_delivery": 3.856112003326416,
"loss_biodist": 0.5298716425895691,
"loss_toxic": 0.011350298300385475,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5555555555555556,
"acc_toxic": 1.0
},
{
"loss": 6.920990943908691,
"loss_size": 0.05388057231903076,
"loss_pdi": 0.39705148339271545,
"loss_ee": 0.990842878818512,
"loss_delivery": 5.025243282318115,
"loss_biodist": 0.4346938133239746,
"loss_toxic": 0.019278930500149727,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5555555555555556,
"acc_toxic": 1.0
},
{
"loss": 5.798760890960693,
"loss_size": 0.09857960045337677,
"loss_pdi": 0.33329641819000244,
"loss_ee": 0.9614524245262146,
"loss_delivery": 4.000489711761475,
"loss_biodist": 0.39874210953712463,
"loss_toxic": 0.006200834643095732,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5555555555555556,
"acc_toxic": 1.0
},
{
"loss": 6.575327396392822,
"loss_size": 0.054699063301086426,
"loss_pdi": 0.33702051639556885,
"loss_ee": 0.9436452388763428,
"loss_delivery": 4.817119121551514,
"loss_biodist": 0.41582298278808594,
"loss_toxic": 0.007020703982561827,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5555555555555556,
"acc_toxic": 1.0
},
{
"loss": 5.989306449890137,
"loss_size": 0.09009546041488647,
"loss_pdi": 0.3044246733188629,
"loss_ee": 1.0130207538604736,
"loss_delivery": 4.140576362609863,
"loss_biodist": 0.4378862977027893,
"loss_toxic": 0.0033026484306901693,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5555555555555556,
"acc_toxic": 1.0
},
{
"loss": 6.383339881896973,
"loss_size": 0.1530081033706665,
"loss_pdi": 0.29700207710266113,
"loss_ee": 0.9943283796310425,
"loss_delivery": 4.564785480499268,
"loss_biodist": 0.37085360288619995,
"loss_toxic": 0.003362649120390415,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5185185185185185,
"acc_toxic": 1.0
},
{
"loss": 6.233416557312012,
"loss_size": 0.1473817676305771,
"loss_pdi": 0.2754640281200409,
"loss_ee": 0.9803684949874878,
"loss_delivery": 4.443488597869873,
"loss_biodist": 0.38424524664878845,
"loss_toxic": 0.0024682653602212667,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.094257354736328,
"loss_size": 0.10127364844083786,
"loss_pdi": 0.2960923910140991,
"loss_ee": 1.0121080875396729,
"loss_delivery": 4.2689008712768555,
"loss_biodist": 0.4132467210292816,
"loss_toxic": 0.0026356647722423077,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5555555555555556,
"acc_toxic": 1.0
},
{
"loss": 6.0315470695495605,
"loss_size": 0.13236114382743835,
"loss_pdi": 0.29554903507232666,
"loss_ee": 0.9912998080253601,
"loss_delivery": 4.2240777015686035,
"loss_biodist": 0.3861069977283478,
"loss_toxic": 0.002152827335521579,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5925925925925926,
"acc_toxic": 1.0
},
{
"loss": 6.13291597366333,
"loss_size": 0.10603927820920944,
"loss_pdi": 0.30880627036094666,
"loss_ee": 1.0417256355285645,
"loss_delivery": 4.337818622589111,
"loss_biodist": 0.33598482608795166,
"loss_toxic": 0.002541647758334875,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5555555555555556,
"acc_toxic": 1.0
},
{
"loss": 5.918347358703613,
"loss_size": 0.11423231661319733,
"loss_pdi": 0.2779754102230072,
"loss_ee": 1.023812174797058,
"loss_delivery": 4.137387275695801,
"loss_biodist": 0.36283549666404724,
"loss_toxic": 0.002104171784594655,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5185185185185185,
"acc_toxic": 1.0
},
{
"loss": 6.354115962982178,
"loss_size": 0.1212025135755539,
"loss_pdi": 0.2848753333091736,
"loss_ee": 1.031553030014038,
"loss_delivery": 4.554471969604492,
"loss_biodist": 0.3598195016384125,
"loss_toxic": 0.0021939175203442574,
"acc_pdi": 0.8518518518518519,
"acc_ee": 0.5555555555555556,
"acc_toxic": 1.0
},
{
"loss": 5.881389141082764,
"loss_size": 0.1030399352312088,
"loss_pdi": 0.2791188657283783,
"loss_ee": 1.0205037593841553,
"loss_delivery": 4.111578464508057,
"loss_biodist": 0.3657107949256897,
"loss_toxic": 0.0014369667042046785,
"acc_pdi": 0.8518518518518519,
"acc_ee": 0.5555555555555556,
"acc_toxic": 1.0
},
{
"loss": 6.1028852462768555,
"loss_size": 0.10241233557462692,
"loss_pdi": 0.300007700920105,
"loss_ee": 1.0756882429122925,
"loss_delivery": 4.258440971374512,
"loss_biodist": 0.3646480441093445,
"loss_toxic": 0.0016881312476471066,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5185185185185185,
"acc_toxic": 1.0
},
{
"loss": 6.128824234008789,
"loss_size": 0.1437627077102661,
"loss_pdi": 0.29325851798057556,
"loss_ee": 1.0818182229995728,
"loss_delivery": 4.236568450927734,
"loss_biodist": 0.3719424605369568,
"loss_toxic": 0.001474093529395759,
"acc_pdi": 0.8888888888888888,
"acc_ee": 0.5185185185185185,
"acc_toxic": 1.0
},
{
"loss": 6.055476188659668,
"loss_size": 0.13312266767024994,
"loss_pdi": 0.28571468591690063,
"loss_ee": 1.066524624824524,
"loss_delivery": 4.214193820953369,
"loss_biodist": 0.35442692041397095,
"loss_toxic": 0.0014939504908397794,
"acc_pdi": 0.8518518518518519,
"acc_ee": 0.5185185185185185,
"acc_toxic": 1.0
}
]
}

Binary file not shown.

View File

@ -1,426 +0,0 @@
{
"train": [
{
"loss": 21.963233947753906,
"loss_size": 16.82633171081543,
"loss_pdi": 1.2230936765670777,
"loss_ee": 1.0703922033309936,
"loss_delivery": 1.0690569162368775,
"loss_biodist": 1.1534382343292235,
"loss_toxic": 0.6209211587905884
},
{
"loss": 13.145495796203614,
"loss_size": 8.676862716674805,
"loss_pdi": 1.0655134558677672,
"loss_ee": 0.8999906063079834,
"loss_delivery": 0.8303895950317383,
"loss_biodist": 1.122723388671875,
"loss_toxic": 0.5500160694122315
},
{
"loss": 7.351448345184326,
"loss_size": 3.415665292739868,
"loss_pdi": 0.8565655469894409,
"loss_ee": 0.7837236523628235,
"loss_delivery": 0.8804788589477539,
"loss_biodist": 1.011645209789276,
"loss_toxic": 0.40336963534355164
},
{
"loss": 4.39948205947876,
"loss_size": 0.9713698267936707,
"loss_pdi": 0.6989291191101075,
"loss_ee": 0.6805540442466735,
"loss_delivery": 0.7624839186668396,
"loss_biodist": 0.9798830866813659,
"loss_toxic": 0.3062621414661407
},
{
"loss": 3.375754451751709,
"loss_size": 0.24608666747808455,
"loss_pdi": 0.5557448148727417,
"loss_ee": 0.6684133768081665,
"loss_delivery": 0.7611681580543518,
"loss_biodist": 0.919653308391571,
"loss_toxic": 0.22468801140785216
},
{
"loss": 2.9307605743408205,
"loss_size": 0.1106911577284336,
"loss_pdi": 0.5004462003707886,
"loss_ee": 0.6227471172809601,
"loss_delivery": 0.6758030593395233,
"loss_biodist": 0.8190896153450012,
"loss_toxic": 0.20198351740837098
},
{
"loss": 2.731675052642822,
"loss_size": 0.13740637749433518,
"loss_pdi": 0.4836215674877167,
"loss_ee": 0.5896897256374359,
"loss_delivery": 0.5866121172904968,
"loss_biodist": 0.7556124567985535,
"loss_toxic": 0.17873288169503213
},
{
"loss": 2.4887039184570314,
"loss_size": 0.12009606957435608,
"loss_pdi": 0.4361336886882782,
"loss_ee": 0.597134268283844,
"loss_delivery": 0.5648026138544082,
"loss_biodist": 0.6326960444450378,
"loss_toxic": 0.13784122765064238
},
{
"loss": 2.1680586099624635,
"loss_size": 0.12401954531669616,
"loss_pdi": 0.40216060280799865,
"loss_ee": 0.5528951227664948,
"loss_delivery": 0.42899617552757263,
"loss_biodist": 0.5442585527896882,
"loss_toxic": 0.1157285787165165
},
{
"loss": 2.1059993267059327,
"loss_size": 0.13299092650413513,
"loss_pdi": 0.38143277168273926,
"loss_ee": 0.5274551689624787,
"loss_delivery": 0.47739412933588027,
"loss_biodist": 0.4953398108482361,
"loss_toxic": 0.0913865402340889
},
{
"loss": 1.9570286750793457,
"loss_size": 0.1426382303237915,
"loss_pdi": 0.38325140476226804,
"loss_ee": 0.49524895548820497,
"loss_delivery": 0.42715947031974794,
"loss_biodist": 0.4287752747535706,
"loss_toxic": 0.07995530962944031
},
{
"loss": 1.8469573497772216,
"loss_size": 0.14165955781936646,
"loss_pdi": 0.36685559153556824,
"loss_ee": 0.4988661766052246,
"loss_delivery": 0.36661114990711213,
"loss_biodist": 0.39747334718704225,
"loss_toxic": 0.07549156174063683
},
{
"loss": 1.6980855226516725,
"loss_size": 0.11332993358373641,
"loss_pdi": 0.350938493013382,
"loss_ee": 0.47553136944770813,
"loss_delivery": 0.30049399137496946,
"loss_biodist": 0.3953311860561371,
"loss_toxic": 0.062460555136203764
},
{
"loss": 1.743706512451172,
"loss_size": 0.12467859983444214,
"loss_pdi": 0.3706244468688965,
"loss_ee": 0.4802402436733246,
"loss_delivery": 0.36484516113996507,
"loss_biodist": 0.3557030588388443,
"loss_toxic": 0.04761496149003506
},
{
"loss": 1.7470735549926757,
"loss_size": 0.10215002745389938,
"loss_pdi": 0.3553147315979004,
"loss_ee": 0.4548905730247498,
"loss_delivery": 0.4480485826730728,
"loss_biodist": 0.3265932142734528,
"loss_toxic": 0.06007647253572941
},
{
"loss": 1.7687433004379272,
"loss_size": 0.10528398901224137,
"loss_pdi": 0.35497177839279176,
"loss_ee": 0.4946293234825134,
"loss_delivery": 0.44853600263595583,
"loss_biodist": 0.3113987982273102,
"loss_toxic": 0.053923492506146434
},
{
"loss": 1.573294997215271,
"loss_size": 0.11145550012588501,
"loss_pdi": 0.33941014409065245,
"loss_ee": 0.42823529839515684,
"loss_delivery": 0.34292849004268644,
"loss_biodist": 0.3095307767391205,
"loss_toxic": 0.04173475466668606
},
{
"loss": 1.482050108909607,
"loss_size": 0.13211917281150817,
"loss_pdi": 0.31831381320953367,
"loss_ee": 0.4258797198534012,
"loss_delivery": 0.26612227857112886,
"loss_biodist": 0.30344046354293824,
"loss_toxic": 0.03617466017603874
},
{
"loss": 1.5079625368118286,
"loss_size": 0.1129397764801979,
"loss_pdi": 0.3118207275867462,
"loss_ee": 0.4255594819784164,
"loss_delivery": 0.30544502288103104,
"loss_biodist": 0.31328115463256834,
"loss_toxic": 0.03891638442873955
},
{
"loss": 1.4902493238449097,
"loss_size": 0.09879767149686813,
"loss_pdi": 0.333440762758255,
"loss_ee": 0.430321592092514,
"loss_delivery": 0.3070627197623253,
"loss_biodist": 0.28984564244747163,
"loss_toxic": 0.030780918896198273
}
],
"val": [
{
"loss": 24.328961690266926,
"loss_size": 15.672358830769857,
"loss_pdi": 1.268057902654012,
"loss_ee": 1.0569811463356018,
"loss_delivery": 4.617272272706032,
"loss_biodist": 1.0806464751561482,
"loss_toxic": 0.633646289507548,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.5894736842105263,
"acc_toxic": 0.8939393939393939
},
{
"loss": 17.03301429748535,
"loss_size": 8.649629751841227,
"loss_pdi": 1.165820797284444,
"loss_ee": 0.9437925020853678,
"loss_delivery": 4.629274984200795,
"loss_biodist": 1.0683060089747112,
"loss_toxic": 0.5761909882227579,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 11.635572751363119,
"loss_size": 3.504341204961141,
"loss_pdi": 1.0855141083399455,
"loss_ee": 0.8674407601356506,
"loss_delivery": 4.705501407384872,
"loss_biodist": 1.0376905004183452,
"loss_toxic": 0.43508487939834595,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 9.058362166086832,
"loss_size": 1.0461570421854656,
"loss_pdi": 1.070031762123108,
"loss_ee": 0.8463932275772095,
"loss_delivery": 4.791346887747447,
"loss_biodist": 0.9781110286712646,
"loss_toxic": 0.32632239659627277,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.418675899505615,
"loss_size": 0.3764288102587064,
"loss_pdi": 1.0916812817255657,
"loss_ee": 0.8714254101117452,
"loss_delivery": 4.8696667949358625,
"loss_biodist": 0.9307892719904581,
"loss_toxic": 0.278684730331103,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.51748021443685,
"loss_size": 0.33909208327531815,
"loss_pdi": 1.103804111480713,
"loss_ee": 0.8707688599824905,
"loss_delivery": 5.0624091029167175,
"loss_biodist": 0.8743396997451782,
"loss_toxic": 0.2670666699608167,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.701509237289429,
"loss_size": 0.38883806640903157,
"loss_pdi": 1.0901564558347066,
"loss_ee": 0.8219001442193985,
"loss_delivery": 5.329233412941297,
"loss_biodist": 0.808117667833964,
"loss_toxic": 0.2632630293567975,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.602253516515097,
"loss_size": 0.399209912866354,
"loss_pdi": 1.035650501648585,
"loss_ee": 0.8119546920061111,
"loss_delivery": 5.297288862367471,
"loss_biodist": 0.8003136416276296,
"loss_toxic": 0.2578362462421258,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.610430796941122,
"loss_size": 0.3884888291358948,
"loss_pdi": 0.9680223266283671,
"loss_ee": 0.8063104202349981,
"loss_delivery": 5.504999443888664,
"loss_biodist": 0.7153328458468119,
"loss_toxic": 0.22727691816786924,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.894750118255615,
"loss_size": 0.38015256201227504,
"loss_pdi": 0.9849910040696462,
"loss_ee": 0.8192636320988337,
"loss_delivery": 5.8433875640233355,
"loss_biodist": 0.6525928874810537,
"loss_toxic": 0.21436312049627304,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.684672435124716,
"loss_size": 0.39142270882924396,
"loss_pdi": 0.9926454623540243,
"loss_ee": 0.8487897912661234,
"loss_delivery": 5.675399616360664,
"loss_biodist": 0.5763055086135864,
"loss_toxic": 0.2001086367915074,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.468807538350424,
"loss_size": 0.37035099665323895,
"loss_pdi": 0.9933059811592102,
"loss_ee": 0.8365495651960373,
"loss_delivery": 5.39086152613163,
"loss_biodist": 0.6555034021536509,
"loss_toxic": 0.22223659542699656,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.48760732014974,
"loss_size": 0.3547621878484885,
"loss_pdi": 1.008083571990331,
"loss_ee": 0.8507340376575788,
"loss_delivery": 5.329072058200836,
"loss_biodist": 0.7051869928836823,
"loss_toxic": 0.23976873668531576,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.534255782763163,
"loss_size": 0.35214799270033836,
"loss_pdi": 1.0083338419596355,
"loss_ee": 0.8703259030977885,
"loss_delivery": 5.4809657235940294,
"loss_biodist": 0.6066243648529053,
"loss_toxic": 0.21585797673712173,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6842105263157895,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.59092911084493,
"loss_size": 0.3520332872867584,
"loss_pdi": 0.9944024880727133,
"loss_ee": 0.8839219162861506,
"loss_delivery": 5.593439628680547,
"loss_biodist": 0.562449519832929,
"loss_toxic": 0.20468231476843357,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.581690510114035,
"loss_size": 0.3468632685641448,
"loss_pdi": 1.0153752664724986,
"loss_ee": 0.884696863591671,
"loss_delivery": 5.548932209610939,
"loss_biodist": 0.576594889163971,
"loss_toxic": 0.20922777770708004,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.60028068224589,
"loss_size": 0.34553587809205055,
"loss_pdi": 1.0314316948254902,
"loss_ee": 0.8696443388859431,
"loss_delivery": 5.513105024894078,
"loss_biodist": 0.6187789390484492,
"loss_toxic": 0.2217849005634586,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6947368421052632,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.721842130025228,
"loss_size": 0.3432792164385319,
"loss_pdi": 1.044082870086034,
"loss_ee": 0.8888355021675428,
"loss_delivery": 5.590167284011841,
"loss_biodist": 0.6300752113262812,
"loss_toxic": 0.22540184513976178,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6736842105263158,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.821967244148254,
"loss_size": 0.3423520748813947,
"loss_pdi": 1.0627215206623077,
"loss_ee": 0.9012102037668228,
"loss_delivery": 5.6443866689999895,
"loss_biodist": 0.6428664823373159,
"loss_toxic": 0.2284308553983768,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6631578947368421,
"acc_toxic": 0.8939393939393939
},
{
"loss": 8.798149506251017,
"loss_size": 0.3439513569076856,
"loss_pdi": 1.074403668443362,
"loss_ee": 0.9037297517061234,
"loss_delivery": 5.62445667386055,
"loss_biodist": 0.6279164751370748,
"loss_toxic": 0.22369086369872093,
"acc_pdi": 0.6105263157894737,
"acc_ee": 0.6631578947368421,
"acc_toxic": 0.8939393939393939
}
]
}

Binary file not shown.

View File

@ -1,531 +0,0 @@
{
"train": [
{
"loss": 22.972569465637207,
"loss_size": 17.22947899500529,
"loss_pdi": 1.3510672052701314,
"loss_ee": 1.0506827433904011,
"loss_delivery": 1.5721548050642014,
"loss_biodist": 1.1304322481155396,
"loss_toxic": 0.6387530366579691
},
{
"loss": 12.718681335449219,
"loss_size": 7.335077285766602,
"loss_pdi": 1.198062241077423,
"loss_ee": 0.97108127673467,
"loss_delivery": 1.639556477467219,
"loss_biodist": 1.0746847093105316,
"loss_toxic": 0.50021959344546
},
{
"loss": 6.867454210917155,
"loss_size": 2.3153140544891357,
"loss_pdi": 1.0159071584542592,
"loss_ee": 0.853508859872818,
"loss_delivery": 1.2862873176733653,
"loss_biodist": 1.0416639745235443,
"loss_toxic": 0.35477257271607715
},
{
"loss": 4.856432318687439,
"loss_size": 0.5409951706727346,
"loss_pdi": 0.8652523259321848,
"loss_ee": 0.7771940131982168,
"loss_delivery": 1.413562481602033,
"loss_biodist": 0.9977987806002299,
"loss_toxic": 0.26162934054931003
},
{
"loss": 4.253215591112773,
"loss_size": 0.2641367167234421,
"loss_pdi": 0.739859402179718,
"loss_ee": 0.7256686190764109,
"loss_delivery": 1.4241955528656642,
"loss_biodist": 0.8935903211434683,
"loss_toxic": 0.2057649294535319
},
{
"loss": 3.8961705764134726,
"loss_size": 0.2962125514944394,
"loss_pdi": 0.682400623957316,
"loss_ee": 0.6820215880870819,
"loss_delivery": 1.2787245536843936,
"loss_biodist": 0.7812575101852417,
"loss_toxic": 0.1755537080268065
},
{
"loss": 3.4790991942087808,
"loss_size": 0.3047281603018443,
"loss_pdi": 0.6409291823705038,
"loss_ee": 0.6178905169169108,
"loss_delivery": 1.1121559316913288,
"loss_biodist": 0.6434484819571177,
"loss_toxic": 0.15994682783881822
},
{
"loss": 3.2075613339742026,
"loss_size": 0.3421506683031718,
"loss_pdi": 0.5879766543706259,
"loss_ee": 0.5811398377021154,
"loss_delivery": 1.0462109719713528,
"loss_biodist": 0.520307645201683,
"loss_toxic": 0.12977550799647966
},
{
"loss": 2.861353278160095,
"loss_size": 0.2742840300003688,
"loss_pdi": 0.5437282969554266,
"loss_ee": 0.5531725088755289,
"loss_delivery": 0.9213679246604443,
"loss_biodist": 0.4499489863713582,
"loss_toxic": 0.11885150956610839
},
{
"loss": 2.6909215847651162,
"loss_size": 0.23881135260065398,
"loss_pdi": 0.5229279547929764,
"loss_ee": 0.5285524874925613,
"loss_delivery": 0.8911051253477732,
"loss_biodist": 0.4015616128842036,
"loss_toxic": 0.10796305599311988
},
{
"loss": 2.5927247206370034,
"loss_size": 0.27356760079662007,
"loss_pdi": 0.5166990955670675,
"loss_ee": 0.5059170673290888,
"loss_delivery": 0.8377179056406021,
"loss_biodist": 0.3519642899433772,
"loss_toxic": 0.10685871541500092
},
{
"loss": 2.3971973856290183,
"loss_size": 0.2688147674004237,
"loss_pdi": 0.4851151605447133,
"loss_ee": 0.47870688637097675,
"loss_delivery": 0.7584750155607859,
"loss_biodist": 0.3166690344611804,
"loss_toxic": 0.08941652067005634
},
{
"loss": 2.2271180947621665,
"loss_size": 0.2559296215573947,
"loss_pdi": 0.467803418636322,
"loss_ee": 0.4819647620121638,
"loss_delivery": 0.6487737223505974,
"loss_biodist": 0.2930952211221059,
"loss_toxic": 0.079551310899357
},
{
"loss": 2.1467134952545166,
"loss_size": 0.2658323546250661,
"loss_pdi": 0.47287177046140033,
"loss_ee": 0.4580538024504979,
"loss_delivery": 0.6110207016269366,
"loss_biodist": 0.26590356479088467,
"loss_toxic": 0.07303123424450557
},
{
"loss": 2.0699684421221414,
"loss_size": 0.23655260602633157,
"loss_pdi": 0.46446068088213605,
"loss_ee": 0.43884341915448505,
"loss_delivery": 0.5945644030968348,
"loss_biodist": 0.26856863250335056,
"loss_toxic": 0.06697871504972379
},
{
"loss": 2.012367367744446,
"loss_size": 0.20358355715870857,
"loss_pdi": 0.44864421089490253,
"loss_ee": 0.4260970900456111,
"loss_delivery": 0.6111055202782154,
"loss_biodist": 0.24829111248254776,
"loss_toxic": 0.07464585608492295
},
{
"loss": 1.9354575673739116,
"loss_size": 0.19155597686767578,
"loss_pdi": 0.43001438677310944,
"loss_ee": 0.4029633104801178,
"loss_delivery": 0.5866967861851057,
"loss_biodist": 0.26284457246462506,
"loss_toxic": 0.06138256782044967
},
{
"loss": 1.9248821139335632,
"loss_size": 0.19836385796467462,
"loss_pdi": 0.43165912727514905,
"loss_ee": 0.4223821411530177,
"loss_delivery": 0.5774712382505337,
"loss_biodist": 0.23008103668689728,
"loss_toxic": 0.06492467441906531
},
{
"loss": 1.7986130317052205,
"loss_size": 0.1977602814634641,
"loss_pdi": 0.4213625093301137,
"loss_ee": 0.3969506522019704,
"loss_delivery": 0.4972396679222584,
"loss_biodist": 0.22815552850564322,
"loss_toxic": 0.05714430411656698
},
{
"loss": 1.8008437156677246,
"loss_size": 0.20143492271502814,
"loss_pdi": 0.4257240394751231,
"loss_ee": 0.3939937750498454,
"loss_delivery": 0.4996156108876069,
"loss_biodist": 0.22945881386597952,
"loss_toxic": 0.05061656702309847
},
{
"loss": 1.8123606244723003,
"loss_size": 0.23175274084011713,
"loss_pdi": 0.41065867245197296,
"loss_ee": 0.38645289838314056,
"loss_delivery": 0.5105274474869171,
"loss_biodist": 0.22759289046128592,
"loss_toxic": 0.04537593169758717
},
{
"loss": 1.85766206185023,
"loss_size": 0.2237167110045751,
"loss_pdi": 0.4198872745037079,
"loss_ee": 0.39036936064561206,
"loss_delivery": 0.5356200908621153,
"loss_biodist": 0.2327907457947731,
"loss_toxic": 0.055277835965777435
},
{
"loss": 1.7299150824546814,
"loss_size": 0.16866947089632353,
"loss_pdi": 0.40182044357061386,
"loss_ee": 0.37123599648475647,
"loss_delivery": 0.5183743331581354,
"loss_biodist": 0.22459317495425543,
"loss_toxic": 0.04522162117063999
},
{
"loss": 1.8115381598472595,
"loss_size": 0.21021889025966325,
"loss_pdi": 0.3938516428073247,
"loss_ee": 0.3856282929579417,
"loss_delivery": 0.5463737193495035,
"loss_biodist": 0.22467835744222006,
"loss_toxic": 0.05078731415172418
},
{
"loss": 1.7609570423762004,
"loss_size": 0.20672637100021043,
"loss_pdi": 0.3868243644634883,
"loss_ee": 0.3775654385487239,
"loss_delivery": 0.5160359914104143,
"loss_biodist": 0.22730938345193863,
"loss_toxic": 0.04649555139864484
}
],
"val": [
{
"loss": 19.896042142595565,
"loss_size": 14.947636876787458,
"loss_pdi": 1.3514722074781145,
"loss_ee": 1.0372784308024816,
"loss_delivery": 0.5157596128327506,
"loss_biodist": 1.3665738276072912,
"loss_toxic": 0.6773212381771633,
"acc_pdi": 0.22564102564102564,
"acc_ee": 0.4512820512820513,
"acc_toxic": 0.7073170731707317
},
{
"loss": 10.277108192443848,
"loss_size": 5.728530270712716,
"loss_pdi": 1.2047701733452933,
"loss_ee": 1.013599353177207,
"loss_delivery": 0.5329383058207375,
"loss_biodist": 1.3288453817367554,
"loss_toxic": 0.4684244394302368,
"acc_pdi": 0.358974358974359,
"acc_ee": 0.4461538461538462,
"acc_toxic": 1.0
},
{
"loss": 4.947442190987723,
"loss_size": 0.9055690084184919,
"loss_pdi": 0.9325166344642639,
"loss_ee": 1.071527932371412,
"loss_delivery": 0.5525430504764829,
"loss_biodist": 1.2514750446592058,
"loss_toxic": 0.23381062703473227,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 4.072668756757464,
"loss_size": 0.18621658267719404,
"loss_pdi": 0.7640595691544669,
"loss_ee": 1.2155327456338065,
"loss_delivery": 0.5523517067943301,
"loss_biodist": 1.214196733066014,
"loss_toxic": 0.14031140506267548,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 4.053883859089443,
"loss_size": 0.24764748875583922,
"loss_pdi": 0.6859285916600909,
"loss_ee": 1.3019903557641166,
"loss_delivery": 0.5384655041354043,
"loss_biodist": 1.191110406603132,
"loss_toxic": 0.08874151536396571,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 3.9850601128169467,
"loss_size": 0.18243093735405377,
"loss_pdi": 0.6606386282614299,
"loss_ee": 1.2955879313605172,
"loss_delivery": 0.6222422846726009,
"loss_biodist": 1.1586603011403764,
"loss_toxic": 0.06550001353025436,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 3.6691461631229947,
"loss_size": 0.18682856378810747,
"loss_pdi": 0.6505168399640492,
"loss_ee": 1.2279302733285087,
"loss_delivery": 0.5776888344969068,
"loss_biodist": 0.963392470564161,
"loss_toxic": 0.06278917672378677,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 3.6420267650059293,
"loss_size": 0.1952320017984935,
"loss_pdi": 0.6510439600263324,
"loss_ee": 1.1954282522201538,
"loss_delivery": 0.7644538623946053,
"loss_biodist": 0.7812093198299408,
"loss_toxic": 0.05465935756053243,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.35384615384615387,
"acc_toxic": 1.0
},
{
"loss": 3.707563672746931,
"loss_size": 0.2168926394411496,
"loss_pdi": 0.6468359615121569,
"loss_ee": 1.2361225570951189,
"loss_delivery": 0.8388645563806806,
"loss_biodist": 0.7232611009052822,
"loss_toxic": 0.04558686592749187,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.37435897435897436,
"acc_toxic": 1.0
},
{
"loss": 3.5122547830854143,
"loss_size": 0.2614529473440988,
"loss_pdi": 0.6344352598701205,
"loss_ee": 1.2337199449539185,
"loss_delivery": 0.7557642417294639,
"loss_biodist": 0.5974612619195666,
"loss_toxic": 0.029421218537858555,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.38461538461538464,
"acc_toxic": 1.0
},
{
"loss": 4.085699932915824,
"loss_size": 0.2200212436062949,
"loss_pdi": 0.6225023801837649,
"loss_ee": 1.1934180770601546,
"loss_delivery": 1.4556497505732946,
"loss_biodist": 0.5685178296906608,
"loss_toxic": 0.02559054403432778,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 3.8555148329053606,
"loss_size": 0.25200862703578814,
"loss_pdi": 0.6161004496472222,
"loss_ee": 1.2155306509562902,
"loss_delivery": 1.1912458751882826,
"loss_biodist": 0.5578728743961879,
"loss_toxic": 0.022756420208939483,
"acc_pdi": 0.7128205128205128,
"acc_ee": 0.4153846153846154,
"acc_toxic": 1.0
},
{
"loss": 4.016897848674229,
"loss_size": 0.24823468178510666,
"loss_pdi": 0.6098369508981705,
"loss_ee": 1.2048260143824987,
"loss_delivery": 1.3509162408964974,
"loss_biodist": 0.5822246244975499,
"loss_toxic": 0.020859350051198686,
"acc_pdi": 0.717948717948718,
"acc_ee": 0.4153846153846154,
"acc_toxic": 1.0
},
{
"loss": 3.9899745328085765,
"loss_size": 0.27859273659331457,
"loss_pdi": 0.6051683936800275,
"loss_ee": 1.1875721216201782,
"loss_delivery": 1.369072552238192,
"loss_biodist": 0.5321473862443652,
"loss_toxic": 0.017421354805784568,
"acc_pdi": 0.717948717948718,
"acc_ee": 0.40512820512820513,
"acc_toxic": 1.0
},
{
"loss": 4.112551552908761,
"loss_size": 0.23502862347023828,
"loss_pdi": 0.6127274334430695,
"loss_ee": 1.2102909428732735,
"loss_delivery": 1.5001615626471383,
"loss_biodist": 0.5411617543016162,
"loss_toxic": 0.013181165459432773,
"acc_pdi": 0.7025641025641025,
"acc_ee": 0.4205128205128205,
"acc_toxic": 1.0
},
{
"loss": 4.031796966280256,
"loss_size": 0.27996559121779035,
"loss_pdi": 0.6218498295971325,
"loss_ee": 1.264663577079773,
"loss_delivery": 1.225053642477308,
"loss_biodist": 0.6304309921605247,
"loss_toxic": 0.009833223053387232,
"acc_pdi": 0.7128205128205128,
"acc_ee": 0.4153846153846154,
"acc_toxic": 1.0
},
{
"loss": 4.108100175857544,
"loss_size": 0.2646343271647181,
"loss_pdi": 0.6244613996573857,
"loss_ee": 1.2785721336092268,
"loss_delivery": 1.2817653375012534,
"loss_biodist": 0.6491539776325226,
"loss_toxic": 0.009512946475297213,
"acc_pdi": 0.7025641025641025,
"acc_ee": 0.4,
"acc_toxic": 1.0
},
{
"loss": 4.237571648188999,
"loss_size": 0.24608339369297028,
"loss_pdi": 0.6387959729347911,
"loss_ee": 1.263997495174408,
"loss_delivery": 1.3677963316440582,
"loss_biodist": 0.7113959235804421,
"loss_toxic": 0.00950257752888969,
"acc_pdi": 0.7025641025641025,
"acc_ee": 0.40512820512820513,
"acc_toxic": 1.0
},
{
"loss": 4.677373443331037,
"loss_size": 0.25485377439430784,
"loss_pdi": 0.6646602579525539,
"loss_ee": 1.2957249539239066,
"loss_delivery": 1.609152581010546,
"loss_biodist": 0.8443670613425118,
"loss_toxic": 0.008614770535911833,
"acc_pdi": 0.7025641025641025,
"acc_ee": 0.4,
"acc_toxic": 1.0
},
{
"loss": 4.661478791918073,
"loss_size": 0.22112409876925604,
"loss_pdi": 0.6840873499001775,
"loss_ee": 1.345719371523176,
"loss_delivery": 1.5750049437795366,
"loss_biodist": 0.8278610365731376,
"loss_toxic": 0.0076820029810603175,
"acc_pdi": 0.6974358974358974,
"acc_ee": 0.39487179487179486,
"acc_toxic": 1.0
},
{
"loss": 4.7805344717843195,
"loss_size": 0.23743291412081038,
"loss_pdi": 0.6911627639617238,
"loss_ee": 1.3796877009528024,
"loss_delivery": 1.6191245743206568,
"loss_biodist": 0.8459444258894239,
"loss_toxic": 0.00718201809961881,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.4,
"acc_toxic": 1.0
},
{
"loss": 5.050536905016218,
"loss_size": 0.2340430138366563,
"loss_pdi": 0.7009050281984466,
"loss_ee": 1.3757387569972448,
"loss_delivery": 1.760018629687173,
"loss_biodist": 0.9723425933292934,
"loss_toxic": 0.007489000846232686,
"acc_pdi": 0.7128205128205128,
"acc_ee": 0.39487179487179486,
"acc_toxic": 1.0
},
{
"loss": 5.172625916344779,
"loss_size": 0.23549717558281763,
"loss_pdi": 0.6980976568801063,
"loss_ee": 1.357451047216143,
"loss_delivery": 1.893591480595725,
"loss_biodist": 0.9802024279321943,
"loss_toxic": 0.0077861944612647805,
"acc_pdi": 0.717948717948718,
"acc_ee": 0.38974358974358975,
"acc_toxic": 1.0
},
{
"loss": 5.048826490129743,
"loss_size": 0.2420537450483867,
"loss_pdi": 0.7013733184763363,
"loss_ee": 1.353548560823713,
"loss_delivery": 1.7931698901312692,
"loss_biodist": 0.9509889696325574,
"loss_toxic": 0.007692053714501006,
"acc_pdi": 0.7128205128205128,
"acc_ee": 0.38974358974358975,
"acc_toxic": 1.0
},
{
"loss": 4.951304980686733,
"loss_size": 0.24394649054322923,
"loss_pdi": 0.7197230202811105,
"loss_ee": 1.3789095027106149,
"loss_delivery": 1.6561105762209212,
"loss_biodist": 0.9460095167160034,
"loss_toxic": 0.006605847106714334,
"acc_pdi": 0.7076923076923077,
"acc_ee": 0.38974358974358975,
"acc_toxic": 1.0
}
]
}

Binary file not shown.

View File

@ -1,447 +0,0 @@
{
"train": [
{
"loss": 19.317236709594727,
"loss_size": 14.108779287338256,
"loss_pdi": 1.2223044037818909,
"loss_ee": 1.1201724767684937,
"loss_delivery": 1.0989456713199615,
"loss_biodist": 1.2243955612182618,
"loss_toxic": 0.5426396489143371
},
{
"loss": 7.058736562728882,
"loss_size": 2.407193088531494,
"loss_pdi": 1.0396445631980895,
"loss_ee": 1.043432891368866,
"loss_delivery": 1.183875671029091,
"loss_biodist": 1.0435532987117768,
"loss_toxic": 0.3410369783639908
},
{
"loss": 4.30960967540741,
"loss_size": 0.3954987198114395,
"loss_pdi": 0.8135693371295929,
"loss_ee": 0.9796469569206238,
"loss_delivery": 1.0894740536808967,
"loss_biodist": 0.8407172799110413,
"loss_toxic": 0.1907033085823059
},
{
"loss": 3.814995551109314,
"loss_size": 0.3033123552799225,
"loss_pdi": 0.7126355469226837,
"loss_ee": 0.9527663111686706,
"loss_delivery": 1.0963637247681617,
"loss_biodist": 0.613499540090561,
"loss_toxic": 0.13641793020069598
},
{
"loss": 3.4925455331802366,
"loss_size": 0.31617238447070123,
"loss_pdi": 0.6569374144077301,
"loss_ee": 0.9212668299674988,
"loss_delivery": 1.0250366538763047,
"loss_biodist": 0.4527473896741867,
"loss_toxic": 0.12038486637175083
},
{
"loss": 3.255272912979126,
"loss_size": 0.27022836059331895,
"loss_pdi": 0.6289663434028625,
"loss_ee": 0.9047561466693879,
"loss_delivery": 0.9742588266730309,
"loss_biodist": 0.3838836058974266,
"loss_toxic": 0.09317965283989907
},
{
"loss": 3.281973719596863,
"loss_size": 0.26578598394989966,
"loss_pdi": 0.5881433010101318,
"loss_ee": 0.8700660288333892,
"loss_delivery": 1.1519657298922539,
"loss_biodist": 0.3240764126181602,
"loss_toxic": 0.08193630240857601
},
{
"loss": 2.7810576915740968,
"loss_size": 0.2505718767642975,
"loss_pdi": 0.5593925356864929,
"loss_ee": 0.8196510195732116,
"loss_delivery": 0.7932070523500443,
"loss_biodist": 0.27653754949569703,
"loss_toxic": 0.08169759791344404
},
{
"loss": 2.644732141494751,
"loss_size": 0.27979295402765275,
"loss_pdi": 0.5457461476325989,
"loss_ee": 0.7845215618610382,
"loss_delivery": 0.722122372686863,
"loss_biodist": 0.2437703028321266,
"loss_toxic": 0.06877880096435547
},
{
"loss": 2.5743841886520387,
"loss_size": 0.21236803606152535,
"loss_pdi": 0.5281321376562118,
"loss_ee": 0.7772053182125092,
"loss_delivery": 0.7842913195490837,
"loss_biodist": 0.20931598618626596,
"loss_toxic": 0.0630713876336813
},
{
"loss": 2.493379771709442,
"loss_size": 0.2545281477272511,
"loss_pdi": 0.514763566851616,
"loss_ee": 0.7416582465171814,
"loss_delivery": 0.7315813854336739,
"loss_biodist": 0.18844463676214218,
"loss_toxic": 0.062403830140829085
},
{
"loss": 2.3714203119277952,
"loss_size": 0.21288565024733544,
"loss_pdi": 0.5149440914392471,
"loss_ee": 0.7432775914669036,
"loss_delivery": 0.6615208894014358,
"loss_biodist": 0.17799324095249175,
"loss_toxic": 0.06079882858321071
},
{
"loss": 2.3138927936553957,
"loss_size": 0.22406778559088708,
"loss_pdi": 0.5060430943965912,
"loss_ee": 0.7270951688289642,
"loss_delivery": 0.6268678307533264,
"loss_biodist": 0.17946239709854125,
"loss_toxic": 0.050356499617919326
},
{
"loss": 2.2404407501220702,
"loss_size": 0.23460092321038245,
"loss_pdi": 0.4892877459526062,
"loss_ee": 0.6908941030502319,
"loss_delivery": 0.6124202072620392,
"loss_biodist": 0.16842604279518128,
"loss_toxic": 0.044811736792325974
},
{
"loss": 2.2448294520378114,
"loss_size": 0.21119624376296997,
"loss_pdi": 0.479864901304245,
"loss_ee": 0.6906192302703857,
"loss_delivery": 0.6555144399404526,
"loss_biodist": 0.16310803219676018,
"loss_toxic": 0.044526621932163835
},
{
"loss": 2.1580574989318846,
"loss_size": 0.18697498068213464,
"loss_pdi": 0.48660930395126345,
"loss_ee": 0.6810935467481614,
"loss_delivery": 0.6051739566028118,
"loss_biodist": 0.15406969040632248,
"loss_toxic": 0.04413598729297519
},
{
"loss": 2.114891529083252,
"loss_size": 0.17799586579203605,
"loss_pdi": 0.4589719235897064,
"loss_ee": 0.6686563313007354,
"loss_delivery": 0.6179293170571327,
"loss_biodist": 0.1526280902326107,
"loss_toxic": 0.03870999766513705
},
{
"loss": 2.1680126667022703,
"loss_size": 0.18272313922643663,
"loss_pdi": 0.47693236321210863,
"loss_ee": 0.6723115026950837,
"loss_delivery": 0.6574018053710461,
"loss_biodist": 0.14306045994162558,
"loss_toxic": 0.03558342705946416
},
{
"loss": 2.0243090748786927,
"loss_size": 0.19451010078191758,
"loss_pdi": 0.46296934187412264,
"loss_ee": 0.6654580652713775,
"loss_delivery": 0.5195972554385662,
"loss_biodist": 0.1416195034980774,
"loss_toxic": 0.04015482016839087
},
{
"loss": 1.980038857460022,
"loss_size": 0.19992023780941964,
"loss_pdi": 0.4373833805322647,
"loss_ee": 0.6562270969152451,
"loss_delivery": 0.5170416861772538,
"loss_biodist": 0.13248837292194365,
"loss_toxic": 0.03697808152064681
},
{
"loss": 2.0073827385902403,
"loss_size": 0.17545675858855247,
"loss_pdi": 0.43559625148773196,
"loss_ee": 0.6394164443016053,
"loss_delivery": 0.5809337809681893,
"loss_biodist": 0.13504885137081146,
"loss_toxic": 0.040930699557065964
}
],
"val": [
{
"loss": 10.945204257965088,
"loss_size": 6.681218147277832,
"loss_pdi": 1.0216107964515686,
"loss_ee": 1.0486068725585938,
"loss_delivery": 0.4687899202108383,
"loss_biodist": 1.215298354625702,
"loss_toxic": 0.5096809715032578,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.6470588235294118,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.0456085205078125,
"loss_size": 0.3470493406057358,
"loss_pdi": 0.7843169867992401,
"loss_ee": 0.8336820006370544,
"loss_delivery": 0.42519159615039825,
"loss_biodist": 1.1528617143630981,
"loss_toxic": 0.5025068372488022,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.553251266479492,
"loss_size": 0.07565776817500591,
"loss_pdi": 0.5630811750888824,
"loss_ee": 0.6947644650936127,
"loss_delivery": 0.4020952582359314,
"loss_biodist": 1.1898333430290222,
"loss_toxic": 0.627819336950779,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.2876256704330444,
"loss_size": 0.056399866938591,
"loss_pdi": 0.579200953245163,
"loss_ee": 0.5947848558425903,
"loss_delivery": 0.4561047703027725,
"loss_biodist": 1.0357274413108826,
"loss_toxic": 0.5654077678918839,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.2625988721847534,
"loss_size": 0.11296019703149796,
"loss_pdi": 0.5352367609739304,
"loss_ee": 0.6021667718887329,
"loss_delivery": 0.5046610683202744,
"loss_biodist": 1.0080225467681885,
"loss_toxic": 0.4995514266192913,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.165306806564331,
"loss_size": 0.09038551151752472,
"loss_pdi": 0.5058617442846298,
"loss_ee": 0.6476156711578369,
"loss_delivery": 0.43027013540267944,
"loss_biodist": 1.0445645153522491,
"loss_toxic": 0.4466092698276043,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.2408690452575684,
"loss_size": 0.22243183851242065,
"loss_pdi": 0.4985402673482895,
"loss_ee": 0.571824312210083,
"loss_delivery": 0.43825456500053406,
"loss_biodist": 0.9937507510185242,
"loss_toxic": 0.5160673335194588,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.5194804668426514,
"loss_size": 0.36968255043029785,
"loss_pdi": 0.4991031885147095,
"loss_ee": 0.5797468274831772,
"loss_delivery": 0.5644859671592712,
"loss_biodist": 0.9723091125488281,
"loss_toxic": 0.5341527052223682,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.539685606956482,
"loss_size": 0.38313066959381104,
"loss_pdi": 0.528433233499527,
"loss_ee": 0.5810057669878006,
"loss_delivery": 0.44039086997509,
"loss_biodist": 1.0017918348312378,
"loss_toxic": 0.6049331650137901,
"acc_pdi": 0.8823529411764706,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.9403724670410156,
"loss_size": 0.6225972771644592,
"loss_pdi": 0.5688649713993073,
"loss_ee": 0.6205386221408844,
"loss_delivery": 0.6095166206359863,
"loss_biodist": 0.9419751763343811,
"loss_toxic": 0.576879795640707,
"acc_pdi": 0.803921568627451,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.8653494119644165,
"loss_size": 0.6294703483581543,
"loss_pdi": 0.5615053772926331,
"loss_ee": 0.6072992980480194,
"loss_delivery": 0.47824281454086304,
"loss_biodist": 0.964938759803772,
"loss_toxic": 0.6238927394151688,
"acc_pdi": 0.8431372549019608,
"acc_ee": 0.8431372549019608,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.109289169311523,
"loss_size": 0.753549188375473,
"loss_pdi": 0.6232334971427917,
"loss_ee": 0.6752453744411469,
"loss_delivery": 0.4541686922311783,
"loss_biodist": 1.0001222491264343,
"loss_toxic": 0.6029700562357903,
"acc_pdi": 0.7450980392156863,
"acc_ee": 0.7254901960784313,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.084217548370361,
"loss_size": 0.6689053475856781,
"loss_pdi": 0.5947604179382324,
"loss_ee": 0.6819752305746078,
"loss_delivery": 0.5174736380577087,
"loss_biodist": 0.9870622158050537,
"loss_toxic": 0.6340407878160477,
"acc_pdi": 0.7843137254901961,
"acc_ee": 0.7647058823529411,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.762814164161682,
"loss_size": 0.5682831406593323,
"loss_pdi": 0.5777421444654465,
"loss_ee": 0.7156199663877487,
"loss_delivery": 0.44971026480197906,
"loss_biodist": 0.9156049191951752,
"loss_toxic": 0.5358536541461945,
"acc_pdi": 0.7450980392156863,
"acc_ee": 0.7058823529411765,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.147223711013794,
"loss_size": 0.6644828915596008,
"loss_pdi": 0.5911359935998917,
"loss_ee": 0.713784396648407,
"loss_delivery": 0.500703439116478,
"loss_biodist": 1.0310384333133698,
"loss_toxic": 0.6460786163806915,
"acc_pdi": 0.7647058823529411,
"acc_ee": 0.7058823529411765,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.157853841781616,
"loss_size": 0.7414849102497101,
"loss_pdi": 0.630668580532074,
"loss_ee": 0.6938402056694031,
"loss_delivery": 0.50765261054039,
"loss_biodist": 0.9891600012779236,
"loss_toxic": 0.5950475558638573,
"acc_pdi": 0.7450980392156863,
"acc_ee": 0.7450980392156863,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.2473719120025635,
"loss_size": 0.8058429956436157,
"loss_pdi": 0.5982940196990967,
"loss_ee": 0.7209844589233398,
"loss_delivery": 0.5253763496875763,
"loss_biodist": 0.973820835351944,
"loss_toxic": 0.6230533868074417,
"acc_pdi": 0.803921568627451,
"acc_ee": 0.6862745098039216,
"acc_toxic": 0.851063829787234
},
{
"loss": 3.90485680103302,
"loss_size": 0.5400884747505188,
"loss_pdi": 0.5671974420547485,
"loss_ee": 0.7085212767124176,
"loss_delivery": 0.5078988373279572,
"loss_biodist": 0.9940473735332489,
"loss_toxic": 0.5871035009622574,
"acc_pdi": 0.803921568627451,
"acc_ee": 0.7450980392156863,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.193094968795776,
"loss_size": 0.764777421951294,
"loss_pdi": 0.5734306275844574,
"loss_ee": 0.7070393562316895,
"loss_delivery": 0.5335722267627716,
"loss_biodist": 1.0060182809829712,
"loss_toxic": 0.6082571670413017,
"acc_pdi": 0.8235294117647058,
"acc_ee": 0.7450980392156863,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.225732326507568,
"loss_size": 0.7807798981666565,
"loss_pdi": 0.57969930768013,
"loss_ee": 0.7046914398670197,
"loss_delivery": 0.5619150400161743,
"loss_biodist": 1.0033797025680542,
"loss_toxic": 0.5952669233083725,
"acc_pdi": 0.7843137254901961,
"acc_ee": 0.7450980392156863,
"acc_toxic": 0.851063829787234
},
{
"loss": 4.122166633605957,
"loss_size": 0.727344423532486,
"loss_pdi": 0.5855642706155777,
"loss_ee": 0.7065156400203705,
"loss_delivery": 0.5201583206653595,
"loss_biodist": 0.9953437149524689,
"loss_toxic": 0.5872401967644691,
"acc_pdi": 0.803921568627451,
"acc_ee": 0.7450980392156863,
"acc_toxic": 0.851063829787234
}
]
}

Binary file not shown.

View File

@ -1,384 +0,0 @@
{
"train": [
{
"loss": 19.51876787705855,
"loss_size": 14.430041096427225,
"loss_pdi": 1.3163191405209629,
"loss_ee": 1.080450177192688,
"loss_delivery": 0.8366337662393396,
"loss_biodist": 1.2393495603041216,
"loss_toxic": 0.6159739873626016
},
{
"loss": 7.450059110468084,
"loss_size": 3.0973785898902197,
"loss_pdi": 1.0574429847977378,
"loss_ee": 0.972686382857236,
"loss_delivery": 0.8131051341241057,
"loss_biodist": 1.0609121160073713,
"loss_toxic": 0.4485339197245511
},
{
"loss": 4.450224074450406,
"loss_size": 0.4227097576314753,
"loss_pdi": 0.8444447679953142,
"loss_ee": 0.9470934163440358,
"loss_delivery": 1.0848771997473456,
"loss_biodist": 0.8853748061440208,
"loss_toxic": 0.26572414352135226
},
{
"loss": 3.697209119796753,
"loss_size": 0.27599087357521057,
"loss_pdi": 0.7228363969109275,
"loss_ee": 0.9080322655764493,
"loss_delivery": 0.8549216186458414,
"loss_biodist": 0.7190906730565158,
"loss_toxic": 0.21633729677308688
},
{
"loss": 3.5547448938543145,
"loss_size": 0.29429094425656577,
"loss_pdi": 0.6620089682665738,
"loss_ee": 0.8595339439131997,
"loss_delivery": 0.9978644847869873,
"loss_biodist": 0.5649206800894304,
"loss_toxic": 0.17612605541944504
},
{
"loss": 3.069189115004106,
"loss_size": 0.30328830602494156,
"loss_pdi": 0.6169473230838776,
"loss_ee": 0.8132463910362937,
"loss_delivery": 0.72551099143245,
"loss_biodist": 0.46287189017642627,
"loss_toxic": 0.14732415906407617
},
{
"loss": 3.1349260156804863,
"loss_size": 0.27535233172503387,
"loss_pdi": 0.5893999202684923,
"loss_ee": 0.807813747362657,
"loss_delivery": 0.9538742283528502,
"loss_biodist": 0.4018077091737227,
"loss_toxic": 0.1066781035201116
},
{
"loss": 2.6963415037501943,
"loss_size": 0.276088840582154,
"loss_pdi": 0.5490301495248621,
"loss_ee": 0.757920276034962,
"loss_delivery": 0.6706068068742752,
"loss_biodist": 0.34504769336093555,
"loss_toxic": 0.09764780781485817
},
{
"loss": 2.418043158271096,
"loss_size": 0.25799565559083765,
"loss_pdi": 0.5286295684901151,
"loss_ee": 0.73835120417855,
"loss_delivery": 0.5078353543173183,
"loss_biodist": 0.2939920425415039,
"loss_toxic": 0.09123929352922873
},
{
"loss": 2.294130650433627,
"loss_size": 0.20914554325017062,
"loss_pdi": 0.5159178945151243,
"loss_ee": 0.7331724437800321,
"loss_delivery": 0.50414734875614,
"loss_biodist": 0.2559647980061444,
"loss_toxic": 0.07578264227644964
},
{
"loss": 2.260723189874129,
"loss_size": 0.21194299920038742,
"loss_pdi": 0.51129734787074,
"loss_ee": 0.71018939668482,
"loss_delivery": 0.506003974513574,
"loss_biodist": 0.24507361785932022,
"loss_toxic": 0.07621594924818385
},
{
"loss": 2.180013732476668,
"loss_size": 0.21782933243296362,
"loss_pdi": 0.5094848789952018,
"loss_ee": 0.6991291533816945,
"loss_delivery": 0.4610004154118625,
"loss_biodist": 0.22516351396387274,
"loss_toxic": 0.06740644057704644
},
{
"loss": 2.131091995672746,
"loss_size": 0.22081551971760663,
"loss_pdi": 0.4984923790801655,
"loss_ee": 0.6744041009382769,
"loss_delivery": 0.44286114688624034,
"loss_biodist": 0.2276345125653527,
"loss_toxic": 0.0668843225999312
},
{
"loss": 2.075855114243247,
"loss_size": 0.1967273937030272,
"loss_pdi": 0.4761518023230813,
"loss_ee": 0.6580501876094125,
"loss_delivery": 0.4651151258837093,
"loss_biodist": 0.22098628905686465,
"loss_toxic": 0.05882429365407337
},
{
"loss": 2.070832209153609,
"loss_size": 0.22619221427223898,
"loss_pdi": 0.46735330332409253,
"loss_ee": 0.658088050105355,
"loss_delivery": 0.46364551173015073,
"loss_biodist": 0.1992648494514552,
"loss_toxic": 0.0562882690097798
},
{
"loss": 2.0497749502008613,
"loss_size": 0.2018005665053021,
"loss_pdi": 0.44933846592903137,
"loss_ee": 0.6409396637569774,
"loss_delivery": 0.4817944710904902,
"loss_biodist": 0.21351950141516599,
"loss_toxic": 0.0623823038556359
},
{
"loss": 1.998817953196439,
"loss_size": 0.19020642136985605,
"loss_pdi": 0.4579390991817821,
"loss_ee": 0.6345709101720289,
"loss_delivery": 0.45773128284649417,
"loss_biodist": 0.20633393864740024,
"loss_toxic": 0.052036324177276
},
{
"loss": 1.9732873006300493,
"loss_size": 0.18214904246005145,
"loss_pdi": 0.46314583312381397,
"loss_ee": 0.6480948545716025,
"loss_delivery": 0.43798652359030465,
"loss_biodist": 0.19162344119765543,
"loss_toxic": 0.05028762956234542
}
],
"val": [
{
"loss": 11.670351346333822,
"loss_size": 7.447449843088786,
"loss_pdi": 1.082938313484192,
"loss_ee": 0.9422469735145569,
"loss_delivery": 0.7185012102127075,
"loss_biodist": 0.945093979438146,
"loss_toxic": 0.5341211756070455,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.8333333333333334,
"acc_toxic": 1.0
},
{
"loss": 3.7281323273976645,
"loss_size": 0.4031377931435903,
"loss_pdi": 0.7136062383651733,
"loss_ee": 0.7757165829340616,
"loss_delivery": 0.6889261901378632,
"loss_biodist": 0.8803721169630686,
"loss_toxic": 0.2663734555244446,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.8333333333333334,
"acc_toxic": 1.0
},
{
"loss": 2.996154228846232,
"loss_size": 0.08968868106603622,
"loss_pdi": 0.5251734455426534,
"loss_ee": 0.705430785814921,
"loss_delivery": 0.7259814739227295,
"loss_biodist": 0.8580058316389719,
"loss_toxic": 0.09187404563029607,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.8333333333333334,
"acc_toxic": 1.0
},
{
"loss": 3.0241867701212564,
"loss_size": 0.06373066206773122,
"loss_pdi": 0.5075281461079916,
"loss_ee": 0.6890556613604227,
"loss_delivery": 0.8740459084510803,
"loss_biodist": 0.8290574749310812,
"loss_toxic": 0.060768917202949524,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.8333333333333334,
"acc_toxic": 1.0
},
{
"loss": 3.2901877562204995,
"loss_size": 0.20215384662151337,
"loss_pdi": 0.5433482428391775,
"loss_ee": 0.7523069183031718,
"loss_delivery": 0.9290379285812378,
"loss_biodist": 0.8056556979815165,
"loss_toxic": 0.057684975365797676,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.8333333333333334,
"acc_toxic": 1.0
},
{
"loss": 3.11362091700236,
"loss_size": 0.11133595556020737,
"loss_pdi": 0.5679469505945841,
"loss_ee": 0.8252793351809183,
"loss_delivery": 0.8343843619028727,
"loss_biodist": 0.7218613227208456,
"loss_toxic": 0.052812947581211724,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.8181818181818182,
"acc_toxic": 1.0
},
{
"loss": 3.247321446736654,
"loss_size": 0.09277657171090443,
"loss_pdi": 0.6567125717798868,
"loss_ee": 1.0444208979606628,
"loss_delivery": 0.7062844236691793,
"loss_biodist": 0.6986102362473806,
"loss_toxic": 0.04851680745681127,
"acc_pdi": 0.8181818181818182,
"acc_ee": 0.45454545454545453,
"acc_toxic": 1.0
},
{
"loss": 3.168424208958944,
"loss_size": 0.05177713930606842,
"loss_pdi": 0.5932339330514272,
"loss_ee": 0.968136191368103,
"loss_delivery": 0.7594618300596873,
"loss_biodist": 0.7631273567676544,
"loss_toxic": 0.032687741021315254,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.45454545454545453,
"acc_toxic": 1.0
},
{
"loss": 3.1226733525594077,
"loss_size": 0.16442706187566122,
"loss_pdi": 0.4861932198206584,
"loss_ee": 0.8927785356839498,
"loss_delivery": 0.809323231379191,
"loss_biodist": 0.7391951779524485,
"loss_toxic": 0.030756143853068352,
"acc_pdi": 0.8484848484848485,
"acc_ee": 0.4696969696969697,
"acc_toxic": 1.0
},
{
"loss": 3.4750285943349204,
"loss_size": 0.07188746457298596,
"loss_pdi": 0.635328451792399,
"loss_ee": 1.0510863463083904,
"loss_delivery": 0.9019280473391215,
"loss_biodist": 0.7850770453612009,
"loss_toxic": 0.02972123461465041,
"acc_pdi": 0.8181818181818182,
"acc_ee": 0.3787878787878788,
"acc_toxic": 1.0
},
{
"loss": 3.3856334686279297,
"loss_size": 0.0857744167248408,
"loss_pdi": 0.5498887598514557,
"loss_ee": 0.9292206565539042,
"loss_delivery": 1.003569980462392,
"loss_biodist": 0.7962689697742462,
"loss_toxic": 0.020910644593338173,
"acc_pdi": 0.8333333333333334,
"acc_ee": 0.3787878787878788,
"acc_toxic": 1.0
},
{
"loss": 3.594546397527059,
"loss_size": 0.058230139315128326,
"loss_pdi": 0.7094775040944418,
"loss_ee": 1.134681224822998,
"loss_delivery": 0.8755488395690918,
"loss_biodist": 0.7928893665472666,
"loss_toxic": 0.023719362293680508,
"acc_pdi": 0.6363636363636364,
"acc_ee": 0.21212121212121213,
"acc_toxic": 1.0
},
{
"loss": 3.34331480662028,
"loss_size": 0.09952588627735774,
"loss_pdi": 0.5444782872994741,
"loss_ee": 0.9434934655825297,
"loss_delivery": 0.9477877616882324,
"loss_biodist": 0.7897172272205353,
"loss_toxic": 0.018312191901107628,
"acc_pdi": 0.7878787878787878,
"acc_ee": 0.36363636363636365,
"acc_toxic": 1.0
},
{
"loss": 3.4212222894032798,
"loss_size": 0.08121616393327713,
"loss_pdi": 0.5517565310001373,
"loss_ee": 1.0685155193010967,
"loss_delivery": 0.874200721581777,
"loss_biodist": 0.828464408715566,
"loss_toxic": 0.017068898615737755,
"acc_pdi": 0.8181818181818182,
"acc_ee": 0.25757575757575757,
"acc_toxic": 1.0
},
{
"loss": 3.6395487785339355,
"loss_size": 0.07888962080081303,
"loss_pdi": 0.5913220842679342,
"loss_ee": 1.0437468489011128,
"loss_delivery": 1.0660852392514546,
"loss_biodist": 0.8421931266784668,
"loss_toxic": 0.01731194742023945,
"acc_pdi": 0.7878787878787878,
"acc_ee": 0.2727272727272727,
"acc_toxic": 1.0
},
{
"loss": 3.5140305360158286,
"loss_size": 0.0599971575041612,
"loss_pdi": 0.5561938285827637,
"loss_ee": 1.0674984057744343,
"loss_delivery": 0.9739653070767721,
"loss_biodist": 0.8400343159834543,
"loss_toxic": 0.01634151643762986,
"acc_pdi": 0.7424242424242424,
"acc_ee": 0.22727272727272727,
"acc_toxic": 1.0
},
{
"loss": 3.628753344217936,
"loss_size": 0.08887146785855293,
"loss_pdi": 0.582503984371821,
"loss_ee": 1.114095131556193,
"loss_delivery": 0.9745903412501017,
"loss_biodist": 0.8516232868035635,
"loss_toxic": 0.017069284183283646,
"acc_pdi": 0.7121212121212122,
"acc_ee": 0.19696969696969696,
"acc_toxic": 1.0
},
{
"loss": 3.6391177972157798,
"loss_size": 0.08282352735598882,
"loss_pdi": 0.581497848033905,
"loss_ee": 1.141685386498769,
"loss_delivery": 0.9548555612564087,
"loss_biodist": 0.8643284440040588,
"loss_toxic": 0.013927079737186432,
"acc_pdi": 0.7424242424242424,
"acc_ee": 0.16666666666666666,
"acc_toxic": 1.0
}
]
}

Binary file not shown.

View File

@ -1,294 +0,0 @@
{
"fold_results": [
{
"fold_idx": 0,
"n_samples": 95,
"size": {
"n": 95,
"rmse": 0.5909209144067168,
"mae": 0.376253614927593,
"r2": 0.005712927161997228
},
"delivery": {
"n": 66,
"rmse": 1.3280577883458438,
"mae": 0.5195405159964028,
"r2": 0.03195999739694366
},
"pdi": {
"n": 95,
"accuracy": 0.6105263157894737,
"precision": 0.20350877192982456,
"recall": 0.3333333333333333,
"f1": 0.25272331154684097
},
"ee": {
"n": 95,
"accuracy": 0.6736842105263158,
"precision": 0.22456140350877193,
"recall": 0.3333333333333333,
"f1": 0.26834381551362685
},
"toxic": {
"n": 66,
"accuracy": 0.8939393939393939,
"precision": 0.44696969696969696,
"recall": 0.5,
"f1": 0.472
},
"biodist": {
"n": 66,
"kl_divergence": 0.851655784204727,
"js_divergence": 0.21404831573756974
}
},
{
"fold_idx": 1,
"n_samples": 195,
"size": {
"n": 193,
"rmse": 0.4425801645813746,
"mae": 0.26432527161632796,
"r2": -0.026225211870033682
},
"delivery": {
"n": 123,
"rmse": 0.7771322048436382,
"mae": 0.6133777339870822,
"r2": -0.128644776760948
},
"pdi": {
"n": 195,
"accuracy": 0.7076923076923077,
"precision": 0.35384615384615387,
"recall": 0.5,
"f1": 0.4144144144144144
},
"ee": {
"n": 195,
"accuracy": 0.4205128205128205,
"precision": 0.14017094017094017,
"recall": 0.3333333333333333,
"f1": 0.19735258724428398
},
"toxic": {
"n": 123,
"accuracy": 1.0,
"precision": 1.0,
"recall": 1.0,
"f1": 1.0
},
"biodist": {
"n": 123,
"kl_divergence": 0.9336461102028436,
"js_divergence": 0.24870266224462317
}
},
{
"fold_idx": 2,
"n_samples": 51,
"size": {
"n": 51,
"rmse": 0.6473513298834871,
"mae": 0.5600235602434944,
"r2": -9.27515642706235
},
"delivery": {
"n": 44,
"rmse": 0.7721077356414991,
"mae": 0.6167582499593581,
"r2": -0.4822886602727561
},
"pdi": {
"n": 51,
"accuracy": 0.8823529411764706,
"precision": 0.29411764705882354,
"recall": 0.3333333333333333,
"f1": 0.3125
},
"ee": {
"n": 51,
"accuracy": 0.8431372549019608,
"precision": 0.28104575163398693,
"recall": 0.3333333333333333,
"f1": 0.3049645390070922
},
"toxic": {
"n": 47,
"accuracy": 0.851063829787234,
"precision": 0.425531914893617,
"recall": 0.5,
"f1": 0.4597701149425288
},
"biodist": {
"n": 45,
"kl_divergence": 1.1049896129018548,
"js_divergence": 0.25485248115851133
}
},
{
"fold_idx": 3,
"n_samples": 66,
"size": {
"n": 66,
"rmse": 0.2407212117920812,
"mae": 0.19363613562150436,
"r2": -0.11204941379936861
},
"delivery": {
"n": 62,
"rmse": 1.0041711455927012,
"mae": 0.7132550483914993,
"r2": -0.63265374674746
},
"pdi": {
"n": 66,
"accuracy": 0.8484848484848485,
"precision": 0.42424242424242425,
"recall": 0.5,
"f1": 0.4590163934426229
},
"ee": {
"n": 66,
"accuracy": 0.8181818181818182,
"precision": 0.27692307692307694,
"recall": 0.32727272727272727,
"f1": 0.3
},
"toxic": {
"n": 62,
"accuracy": 1.0,
"precision": 1.0,
"recall": 1.0,
"f1": 1.0
},
"biodist": {
"n": 62,
"kl_divergence": 0.9677978984139058,
"js_divergence": 0.2020309307244639
}
},
{
"fold_idx": 4,
"n_samples": 27,
"size": {
"n": 27,
"rmse": 0.23392834445509142,
"mae": 0.19066280788845485,
"r2": -0.2667651950955112
},
"delivery": {
"n": 15,
"rmse": 1.9603892288630869,
"mae": 1.3892907698949177,
"r2": -0.29760739742916287
},
"pdi": {
"n": 27,
"accuracy": 0.8888888888888888,
"precision": 0.4444444444444444,
"recall": 0.5,
"f1": 0.47058823529411764
},
"ee": {
"n": 27,
"accuracy": 0.5925925925925926,
"precision": 0.19753086419753085,
"recall": 0.3333333333333333,
"f1": 0.24806201550387597
},
"toxic": {
"n": 15,
"accuracy": 1.0,
"precision": 1.0,
"recall": 1.0,
"f1": 1.0
},
"biodist": {
"n": 15,
"kl_divergence": 0.9389607012315264,
"js_divergence": 0.2470218476598176
}
}
],
"summary_stats": {
"size": {
"rmse_mean": 0.43110039302375025,
"rmse_std": 0.17179051271013462,
"r2_mean": -1.9348966641330534,
"r2_std": 3.6713441784129
},
"delivery": {
"rmse_mean": 1.1683716206573538,
"rmse_std": 0.4449374578352648,
"r2_mean": -0.30184691676267666,
"r2_std": 0.23809090378746706
},
"pdi": {
"accuracy_mean": 0.7875890604063979,
"accuracy_std": 0.11016791908756088,
"f1_mean": 0.3818484709395992,
"f1_std": 0.08529090446864619
},
"ee": {
"accuracy_mean": 0.6696217393431015,
"accuracy_std": 0.15503740047242787,
"f1_mean": 0.2637445914537758,
"f1_std": 0.039213602228007696
},
"toxic": {
"accuracy_mean": 0.9490006447453256,
"accuracy_std": 0.06391582554207781,
"f1_mean": 0.7863540229885058,
"f1_std": 0.26169039387919035
},
"biodist": {
"kl_mean": 0.9594100213909715,
"kl_std": 0.08240959093662605,
"js_mean": 0.23333124750499712,
"js_std": 0.021158533549255752
}
},
"overall": {
"size": {
"n_samples": 432,
"mse": 0.22604480336185886,
"rmse": 0.47544169291497657,
"mae": 0.3084443360567093,
"r2": -0.2313078534105617
},
"delivery": {
"n_samples": 310,
"mse": 1.0873755440675295,
"rmse": 1.0427730069710903,
"mae": 0.6513989447841361,
"r2": -0.09443640807387799
},
"pdi": {
"n_samples": 434,
"accuracy": 0.7396313364055299,
"precision": 0.18490783410138248,
"recall": 0.25,
"f1": 0.21258278145695364
},
"ee": {
"n_samples": 434,
"accuracy": 0.5967741935483871,
"precision": 0.1993841416474211,
"recall": 0.33205128205128204,
"f1": 0.24915824915824913
},
"toxic": {
"n_samples": 313,
"accuracy": 0.9552715654952076,
"precision": 0.4776357827476038,
"recall": 0.5,
"f1": 0.48856209150326796
},
"biodist": {
"n_samples": 311,
"kl_divergence": 0.9481034280166569,
"js_divergence": 0.23285280825310384
}
}
}

Binary file not shown.

View File

@ -1,206 +1,310 @@
{
"train": [
{
"loss": 0.7730368412685099,
"n_samples": 6783
"loss": 0.8244676398801744,
"n_samples": 8721
},
{
"loss": 0.658895703010919,
"n_samples": 6783
"loss": 0.6991508170533461,
"n_samples": 8721
},
{
"loss": 0.6059015260392299,
"n_samples": 6783
"loss": 0.6388374940987616,
"n_samples": 8721
},
{
"loss": 0.5744731174349416,
"n_samples": 6783
"loss": 0.6008581508669937,
"n_samples": 8721
},
{
"loss": 0.5452056020458733,
"n_samples": 6783
"loss": 0.584832567446085,
"n_samples": 8721
},
{
"loss": 0.5138543470936083,
"n_samples": 6783
"loss": 0.5481657371815157,
"n_samples": 8721
},
{
"loss": 0.4885380559178135,
"n_samples": 6783
"loss": 0.5368926340308079,
"n_samples": 8721
},
{
"loss": 0.47587182296687974,
"n_samples": 6783
"loss": 0.5210388793613561,
"n_samples": 8721
},
{
"loss": 0.4671051038255316,
"n_samples": 6783
"loss": 0.49758357966374045,
"n_samples": 8721
},
{
"loss": 0.46794115915756107,
"n_samples": 6783
"loss": 0.49256294099457043,
"n_samples": 8721
},
{
"loss": 0.4293930456997915,
"n_samples": 6783
"loss": 0.4697267088016886,
"n_samples": 8721
},
{
"loss": 0.42624105651716415,
"n_samples": 6783
"loss": 0.45763822707571084,
"n_samples": 8721
},
{
"loss": 0.4131358770446828,
"n_samples": 6783
"loss": 0.4495221330627172,
"n_samples": 8721
},
{
"loss": 0.3946074267790835,
"n_samples": 6783
"loss": 0.446159594079631,
"n_samples": 8721
},
{
"loss": 0.3898155013755344,
"n_samples": 6783
"loss": 0.4327090857889029,
"n_samples": 8721
},
{
"loss": 0.37861797005733383,
"n_samples": 6783
"loss": 0.4249273364101852,
"n_samples": 8721
},
{
"loss": 0.3775682858392304,
"n_samples": 6783
"loss": 0.4216959138704459,
"n_samples": 8721
},
{
"loss": 0.3800349080262064,
"n_samples": 6783
"loss": 0.416526201182502,
"n_samples": 8721
},
{
"loss": 0.36302345173031675,
"n_samples": 6783
"loss": 0.40368679039741573,
"n_samples": 8721
},
{
"loss": 0.3429561740842766,
"n_samples": 6783
"loss": 0.4051084730032182,
"n_samples": 8721
},
{
"loss": 0.3445638883004898,
"n_samples": 6783
"loss": 0.38971701020385785,
"n_samples": 8721
},
{
"loss": 0.318970229203733,
"n_samples": 6783
"loss": 0.39155546386038786,
"n_samples": 8721
},
{
"loss": 0.30179278279904437,
"n_samples": 6783
"loss": 0.37976963541784114,
"n_samples": 8721
},
{
"loss": 0.2887343142006437,
"n_samples": 6783
"loss": 0.36484339719805037,
"n_samples": 8721
},
{
"loss": 0.29240367556855545,
"n_samples": 6783
"loss": 0.36232607571196496,
"n_samples": 8721
},
{
"loss": 0.3345973272380199,
"n_samples": 8721
},
{
"loss": 0.31767916518768957,
"n_samples": 8721
},
{
"loss": 0.32065429246052457,
"n_samples": 8721
},
{
"loss": 0.3171297926146043,
"n_samples": 8721
},
{
"loss": 0.3122120894173009,
"n_samples": 8721
},
{
"loss": 0.3135035038404461,
"n_samples": 8721
},
{
"loss": 0.2987745178222875,
"n_samples": 8721
},
{
"loss": 0.2914867957853393,
"n_samples": 8721
},
{
"loss": 0.2983839795507705,
"n_samples": 8721
},
{
"loss": 0.2826709597875678,
"n_samples": 8721
},
{
"loss": 0.2731766632569382,
"n_samples": 8721
},
{
"loss": 0.27726896305742266,
"n_samples": 8721
},
{
"loss": 0.27864557847067956,
"n_samples": 8721
}
],
"val": [
{
"loss": 0.7350345371841441,
"n_samples": 2907
"loss": 0.7601077516012517,
"n_samples": 969
},
{
"loss": 0.7165568811318536,
"n_samples": 2907
"loss": 0.7119935319611901,
"n_samples": 969
},
{
"loss": 0.7251406249862214,
"n_samples": 2907
"loss": 0.6461842978148269,
"n_samples": 969
},
{
"loss": 0.6836505264587159,
"n_samples": 2907
"loss": 0.7006978391063226,
"n_samples": 969
},
{
"loss": 0.6747132955771933,
"n_samples": 2907
"loss": 0.6533874032943979,
"n_samples": 969
},
{
"loss": 0.6691136244936912,
"n_samples": 2907
"loss": 0.6413641451743611,
"n_samples": 969
},
{
"loss": 0.6337480902323249,
"n_samples": 2907
"loss": 0.6168395132979742,
"n_samples": 969
},
{
"loss": 0.6600317959527934,
"n_samples": 2907
"loss": 0.6095251602162025,
"n_samples": 969
},
{
"loss": 0.6439923948855346,
"n_samples": 2907
"loss": 0.5887809592626905,
"n_samples": 969
},
{
"loss": 0.643800035575267,
"n_samples": 2907
"loss": 0.5655298325376368,
"n_samples": 969
},
{
"loss": 0.6181512585221839,
"n_samples": 2907
"loss": 0.5809201743872788,
"n_samples": 969
},
{
"loss": 0.6442458634939151,
"n_samples": 2907
"loss": 0.5897585974912033,
"n_samples": 969
},
{
"loss": 0.6344759362359862,
"n_samples": 2907
"loss": 0.5732012489662573,
"n_samples": 969
},
{
"loss": 0.6501405371457472,
"n_samples": 2907
"loss": 0.5607388911786094,
"n_samples": 969
},
{
"loss": 0.6098835162990152,
"n_samples": 2907
"loss": 0.5717580675371414,
"n_samples": 969
},
{
"loss": 0.6366627322138894,
"n_samples": 2907
"loss": 0.5553950037657291,
"n_samples": 969
},
{
"loss": 0.6171610150646417,
"n_samples": 2907
"loss": 0.5778171792857049,
"n_samples": 969
},
{
"loss": 0.6358801012273748,
"n_samples": 2907
"loss": 0.5602665468127734,
"n_samples": 969
},
{
"loss": 0.6239976831059871,
"n_samples": 2907
"loss": 0.5475307451359259,
"n_samples": 969
},
{
"loss": 0.6683828232827201,
"n_samples": 2907
"loss": 0.551515599314827,
"n_samples": 969
},
{
"loss": 0.6655785786478143,
"n_samples": 2907
"loss": 0.5755438121541243,
"n_samples": 969
},
{
"loss": 0.6152775046503088,
"n_samples": 2907
"loss": 0.5798238261811381,
"n_samples": 969
},
{
"loss": 0.6202247662153858,
"n_samples": 2907
"loss": 0.5739961433828923,
"n_samples": 969
},
{
"loss": 0.648199727435189,
"n_samples": 2907
"loss": 0.5742599932540312,
"n_samples": 969
},
{
"loss": 0.6473217075085124,
"n_samples": 2907
"loss": 0.5834948123885382,
"n_samples": 969
},
{
"loss": 0.554078846570139,
"n_samples": 969
},
{
"loss": 0.5714933996322354,
"n_samples": 969
},
{
"loss": 0.5384107524350331,
"n_samples": 969
},
{
"loss": 0.570854394451568,
"n_samples": 969
},
{
"loss": 0.5767292551642478,
"n_samples": 969
},
{
"loss": 0.5660079547556808,
"n_samples": 969
},
{
"loss": 0.5608972411514312,
"n_samples": 969
},
{
"loss": 0.5620947442987263,
"n_samples": 969
},
{
"loss": 0.5706970894361305,
"n_samples": 969
},
{
"loss": 0.5702376298690974,
"n_samples": 969
},
{
"loss": 0.5758474825259579,
"n_samples": 969
},
{
"loss": 0.5673816067284844,
"n_samples": 969
},
{
"loss": 0.5671441179879925,
"n_samples": 969
}
]
}

View File

@ -1,224 +1,232 @@
"""内部数据 Cross-Validation 划分脚本:基于 Amine 的分组划分"""
"""处理 cross-validation 数据脚本:将 CV splits 转换为模型所需的 parquet 格式"""
from pathlib import Path
from typing import List
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
import typer
from loguru import logger
from lnp_ml.config import INTERIM_DATA_DIR, PROCESSED_DATA_DIR
from lnp_ml.config import EXTERNAL_DATA_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import (
process_dataframe,
SMILES_COL,
LNPDatasetConfig,
COMP_COLS,
HELP_COLS,
TARGET_REGRESSION,
TARGET_CLASSIFICATION_PDI,
TARGET_CLASSIFICATION_EE,
TARGET_TOXIC,
TARGET_BIODIST,
get_phys_cols,
get_exp_cols,
EXP_ONEHOT_SPECS,
PHYS_ONEHOT_SPECS,
)
app = typer.Typer()
def amine_based_cv_split(
df: pd.DataFrame,
n_folds: int = 5,
seed: int = 42,
amine_col: str = "Amine",
) -> List[dict]:
# CV extra_x 列名到模型列名的映射
CV_COL_MAPPING = {
# Batch_or_individual_or_barcoded -> Sample_organization_type (for Value_name related)
"Batch_or_individual_or_barcoded_Barcoded": "Batch_or_individual_or_barcoded_Barcoded",
"Batch_or_individual_or_barcoded_Individual": "Batch_or_individual_or_barcoded_Individual",
# Helper_lipid_ID_None 不在模型中使用,忽略
}
def load_cv_split(cv_dir: Path) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""
基于 Amine 列进行 Cross-Validation 划分
步骤
1. amine_col 分组
2. 打乱分组顺序
3. 将分组 round-robin 分配到 n_folds 个容器
4. 对于每个 fold i
- validation = container[i]
- test = container[(i+1) % n_folds]
- train = 其余所有
加载单个 CV split 的数据
Args:
df: 输入 DataFrame
n_folds: 折数
seed: 随机种子
amine_col: 用于分组的列名
cv_dir: CV split 目录 cv_0/
Returns:
List of dicts每个 dict 包含 train_df, val_df, test_df
(train_df, valid_df, test_df) 合并后的 DataFrame
"""
# 获取唯一的 amine 并打乱
unique_amines = df[amine_col].unique()
rng = np.random.RandomState(seed)
rng.shuffle(unique_amines)
logger.info(f"Found {len(unique_amines)} unique amines")
# Round-robin 分配到 n_folds 个容器
containers = [[] for _ in range(n_folds)]
for i, amine in enumerate(unique_amines):
containers[i % n_folds].append(amine)
# 打印每个容器的大小
for i, container in enumerate(containers):
container_samples = df[df[amine_col].isin(container)]
logger.info(f" Container {i}: {len(container)} amines, {len(container_samples)} samples")
# 生成每个 fold 的数据
fold_splits = []
for i in range(n_folds):
val_amines = set(containers[i])
test_amines = set(containers[(i + 1) % n_folds])
train_amines = set()
for j in range(n_folds):
if j != i and j != (i + 1) % n_folds:
train_amines.update(containers[j])
splits = {}
for split_name in ["train", "valid", "test"]:
# 加载主数据smiles, quantified_delivery
main_path = cv_dir / f"{split_name}.csv"
extra_x_path = cv_dir / f"{split_name}_extra_x.csv"
metadata_path = cv_dir / f"{split_name}_metadata.csv"
train_df = df[df[amine_col].isin(train_amines)].reset_index(drop=True)
val_df = df[df[amine_col].isin(val_amines)].reset_index(drop=True)
test_df = df[df[amine_col].isin(test_amines)].reset_index(drop=True)
if not main_path.exists():
raise FileNotFoundError(f"Missing {main_path}")
fold_splits.append({
"train": train_df,
"val": val_df,
"test": test_df,
})
main_df = pd.read_csv(main_path)
logger.info(
f"Fold {i}: train={len(train_df)} ({len(train_amines)} amines), "
f"val={len(val_df)} ({len(val_amines)} amines), "
f"test={len(test_df)} ({len(test_amines)} amines)"
)
# 加载 extra_x已 one-hot 编码的特征)
if extra_x_path.exists():
extra_x_df = pd.read_csv(extra_x_path)
# 确保行数一致
assert len(main_df) == len(extra_x_df), f"Row count mismatch: {len(main_df)} vs {len(extra_x_df)}"
# 合并(按行索引)
df = pd.concat([main_df, extra_x_df], axis=1)
else:
df = main_df
logger.warning(f" {split_name}_extra_x.csv not found, using main data only")
# 可选:从 metadata 获取额外信息
if metadata_path.exists():
metadata_df = pd.read_csv(metadata_path)
# 提取需要的列(如 Purity, Mix_type, Value_name 等)
for col in ["Purity", "Mix_type", "Value_name", "Target_or_delivered_gene"]:
if col in metadata_df.columns and col not in df.columns:
df[col] = metadata_df[col]
splits[split_name] = df
return fold_splits
return splits["train"], splits["valid"], splits["test"]
def process_cv_dataframe(df: pd.DataFrame) -> pd.DataFrame:
"""
处理 CV 数据的 DataFrame对齐到模型所需的列格式
CV 数据的 extra_x 已经包含大部分 one-hot 编码但需要
1. 添加缺失的 one-hot 设为 0
2. metadata 中生成 phys token one-hot Purity, Mix_type, Cargo_type, Target_or_delivered_gene
3. 生成 Value_name one-hot
"""
df = df.copy()
# 1. 处理 comp 列
for col in COMP_COLS:
if col in df.columns:
df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0.0)
else:
df[col] = 0.0
# 2. 处理 help 列
for col in HELP_COLS:
if col not in df.columns:
df[col] = 0.0
else:
df[col] = df[col].fillna(0.0).astype(float)
# 3. 处理 phys token 的 one-hot 列
for col, values in PHYS_ONEHOT_SPECS.items():
for v in values:
onehot_col = f"{col}_{v}"
if onehot_col not in df.columns:
# 尝试从原始列生成
if col in df.columns:
df[onehot_col] = (df[col] == v).astype(float)
else:
df[onehot_col] = 0.0
else:
df[onehot_col] = df[onehot_col].fillna(0.0).astype(float)
# 4. 处理 exp token 的 one-hot 列
for col, values in EXP_ONEHOT_SPECS.items():
for v in values:
onehot_col = f"{col}_{v}"
if onehot_col not in df.columns:
# 尝试从原始列生成
if col in df.columns:
df[onehot_col] = (df[col] == v).astype(float)
else:
df[onehot_col] = 0.0
else:
df[onehot_col] = df[onehot_col].fillna(0.0).astype(float)
# 5. 处理 quantified_delivery
if "quantified_delivery" in df.columns:
df["quantified_delivery"] = pd.to_numeric(df["quantified_delivery"], errors="coerce")
return df
def get_feature_columns() -> List[str]:
"""获取所有特征列名"""
config = LNPDatasetConfig()
return (
["smiles"]
+ config.comp_cols
+ config.phys_cols
+ config.help_cols
+ config.exp_cols
+ ["quantified_delivery"]
)
@app.command()
def main(
input_path: Path = INTERIM_DATA_DIR / "internal_corrected.csv",
data_dir: Path = EXTERNAL_DATA_DIR / "all_amine_split_for_LiON",
output_dir: Path = PROCESSED_DATA_DIR / "cv",
n_folds: int = 5,
seed: int = 42,
amine_col: str = "Amine",
):
"""
基于 Amine 分组进行 Cross-Validation 数据划分
采用类似 scaffold splitting 的思路将相同 Amine 的数据放在同一组
确保训练集和测试集之间没有 Amine 泄露
划分比例约为 train:val:test 3:1:1
处理 cross-validation 数据生成模型所需的 parquet 文件
输出结构:
- processed/cv/fold_0/train.parquet
- processed/cv/fold_0/val.parquet
- processed/cv/fold_0/valid.parquet
- processed/cv/fold_0/test.parquet
- processed/cv/fold_1/...
- processed/cv/feature_columns.txt
"""
logger.info(f"Loading data from {input_path}")
df = pd.read_csv(input_path)
logger.info(f"Loaded {len(df)} samples")
logger.info(f"Processing CV data from {data_dir}")
# 检查 amine 列是否存在
if amine_col not in df.columns:
logger.error(f"Column '{amine_col}' not found in data. Available columns: {list(df.columns)}")
# 获取所有 cv_* 目录
cv_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("cv_")])
if len(cv_dirs) == 0:
logger.error(f"No cv_* directories found in {data_dir}")
raise typer.Exit(1)
# 处理数据列对齐、one-hot 生成等)
logger.info("Processing dataframe...")
df = process_dataframe(df)
if len(cv_dirs) != n_folds:
logger.warning(f"Expected {n_folds} folds, found {len(cv_dirs)}")
# 确保 Amine 列仍然存在process_dataframe 可能不会保留它)
# 重新加载原始数据获取 Amine 列
original_df = pd.read_csv(input_path)
if amine_col in original_df.columns and amine_col not in df.columns:
df[amine_col] = original_df[amine_col].values
logger.info(f"Found {len(cv_dirs)} folds: {[d.name for d in cv_dirs]}")
# 定义要保留的列
phys_cols = get_phys_cols()
exp_cols = get_exp_cols()
keep_cols = (
[SMILES_COL]
+ COMP_COLS
+ phys_cols
+ HELP_COLS
+ exp_cols
+ TARGET_REGRESSION
+ TARGET_CLASSIFICATION_PDI
+ TARGET_CLASSIFICATION_EE
+ [TARGET_TOXIC]
+ TARGET_BIODIST
)
# 只保留存在的列
keep_cols = [c for c in keep_cols if c in df.columns]
# 进行 CV 划分
logger.info(f"\nPerforming {n_folds}-fold amine-based CV split (seed={seed})...")
fold_splits = amine_based_cv_split(df, n_folds=n_folds, seed=seed, amine_col=amine_col)
# 保存每个 fold
feature_cols = get_feature_columns()
output_dir.mkdir(parents=True, exist_ok=True)
for i, split in enumerate(fold_splits):
fold_dir = output_dir / f"fold_{i}"
fold_dir.mkdir(parents=True, exist_ok=True)
for i, cv_dir in enumerate(cv_dirs):
fold_name = f"fold_{i}"
fold_output_dir = output_dir / fold_name
fold_output_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Processing {cv_dir.name} -> {fold_name}")
# 加载数据
train_df, valid_df, test_df = load_cv_split(cv_dir)
logger.info(f" Loaded: train={len(train_df)}, valid={len(valid_df)}, test={len(test_df)}")
# 处理数据
train_df = process_cv_dataframe(train_df)
valid_df = process_cv_dataframe(valid_df)
test_df = process_cv_dataframe(test_df)
# 确保所有列存在
for col in feature_cols:
for df in [train_df, valid_df, test_df]:
if col not in df.columns:
df[col] = 0.0 if col != "smiles" else ""
# 只保留需要的列
train_df = split["train"][keep_cols].reset_index(drop=True)
val_df = split["val"][keep_cols].reset_index(drop=True)
test_df = split["test"][keep_cols].reset_index(drop=True)
train_df = train_df[feature_cols]
valid_df = valid_df[feature_cols]
test_df = test_df[feature_cols]
# 保存
train_df.to_parquet(fold_dir / "train.parquet", index=False)
val_df.to_parquet(fold_dir / "val.parquet", index=False)
test_df.to_parquet(fold_dir / "test.parquet", index=False)
train_df.to_parquet(fold_output_dir / "train.parquet", index=False)
valid_df.to_parquet(fold_output_dir / "valid.parquet", index=False)
test_df.to_parquet(fold_output_dir / "test.parquet", index=False)
logger.success(f"Saved fold {i} to {fold_dir}")
logger.success(f" Saved to {fold_output_dir}")
# 保存列名配置
config_path = output_dir / "feature_columns.txt"
with open(config_path, "w") as f:
f.write("# Feature columns configuration\n\n")
f.write(f"# SMILES\n{SMILES_COL}\n\n")
f.write(f"# comp token [{len(COMP_COLS)}]\n")
f.write("\n".join(COMP_COLS) + "\n\n")
f.write(f"# phys token [{len(phys_cols)}]\n")
f.write("\n".join(phys_cols) + "\n\n")
f.write(f"# help token [{len(HELP_COLS)}]\n")
f.write("\n".join(HELP_COLS) + "\n\n")
f.write(f"# exp token [{len(exp_cols)}]\n")
f.write("\n".join(exp_cols) + "\n\n")
f.write("# Targets\n")
f.write("## Regression\n")
f.write("\n".join(TARGET_REGRESSION) + "\n")
f.write("## PDI classification\n")
f.write("\n".join(TARGET_CLASSIFICATION_PDI) + "\n")
f.write("## EE classification\n")
f.write("\n".join(TARGET_CLASSIFICATION_EE) + "\n")
f.write("## Toxic\n")
f.write(f"{TARGET_TOXIC}\n")
f.write("## Biodistribution\n")
f.write("\n".join(TARGET_BIODIST) + "\n")
# 保存特征列配置
cols_path = output_dir / "feature_columns.txt"
with open(cols_path, "w") as f:
f.write("\n".join(feature_cols))
logger.success(f"Saved feature columns to {cols_path}")
logger.success(f"Saved feature config to {config_path}")
# 打印汇总
logger.info("\n" + "=" * 60)
logger.info("CV DATA PROCESSING COMPLETE")
logger.info("=" * 60)
logger.info(f"Output directory: {output_dir}")
logger.info(f"Number of folds: {n_folds}")
logger.info(f"Splitting method: Amine-based (column: {amine_col})")
logger.info(f"Random seed: {seed}")
logger.info(f"Number of folds: {len(cv_dirs)}")
if __name__ == "__main__":

View File

@ -1,234 +0,0 @@
"""处理 cross-validation 数据脚本:将 CV splits 转换为模型所需的 parquet 格式"""
from pathlib import Path
from typing import Dict, List, Tuple
import pandas as pd
import typer
from loguru import logger
from lnp_ml.config import EXTERNAL_DATA_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import (
LNPDatasetConfig,
COMP_COLS,
HELP_COLS,
get_phys_cols,
get_exp_cols,
EXP_ONEHOT_SPECS,
PHYS_ONEHOT_SPECS,
)
app = typer.Typer()

# Mapping from CV extra_x column names to model column names.
CV_COL_MAPPING = {
    # Batch_or_individual_or_barcoded -> Sample_organization_type (for Value_name related)
    "Batch_or_individual_or_barcoded_Barcoded": "Batch_or_individual_or_barcoded_Barcoded",
    "Batch_or_individual_or_barcoded_Individual": "Batch_or_individual_or_barcoded_Individual",
    # Helper_lipid_ID_None is not used by the model; ignored.
}
def load_cv_split(cv_dir: Path) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Load the data of a single CV split directory (e.g. cv_0/).

    Each of the train/valid/test splits is read from ``{split}.csv`` (smiles,
    quantified_delivery), widened with the already one-hot-encoded
    ``{split}_extra_x.csv`` when present (positional, row-aligned concat), and
    enriched with a few columns from ``{split}_metadata.csv`` when available.

    Args:
        cv_dir: CV split directory, e.g. cv_0/.

    Returns:
        Tuple of (train_df, valid_df, test_df) merged DataFrames.

    Raises:
        FileNotFoundError: If a required ``{split}.csv`` is missing.
    """
    metadata_cols = ["Purity", "Mix_type", "Value_name", "Target_or_delivered_gene"]

    def _read_split(name: str) -> pd.DataFrame:
        # Required main table (smiles, quantified_delivery).
        base_path = cv_dir / f"{name}.csv"
        if not base_path.exists():
            raise FileNotFoundError(f"Missing {base_path}")
        frame = pd.read_csv(base_path)

        # Optional pre-encoded feature table, joined by row position.
        features_path = cv_dir / f"{name}_extra_x.csv"
        if features_path.exists():
            features = pd.read_csv(features_path)
            # Row counts must agree for a positional concat to be valid.
            assert len(frame) == len(features), f"Row count mismatch: {len(frame)} vs {len(features)}"
            frame = pd.concat([frame, features], axis=1)
        else:
            logger.warning(f"  {name}_extra_x.csv not found, using main data only")

        # Optional metadata: copy selected columns not already present.
        meta_path = cv_dir / f"{name}_metadata.csv"
        if meta_path.exists():
            meta = pd.read_csv(meta_path)
            for col in metadata_cols:
                if col in meta.columns and col not in frame.columns:
                    frame[col] = meta[col]
        return frame

    loaded = {name: _read_split(name) for name in ("train", "valid", "test")}
    return loaded["train"], loaded["valid"], loaded["test"]
def process_cv_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Align a CV-split DataFrame to the column layout the model expects.

    The CV ``extra_x`` files already contain most one-hot encodings, so this
    only has to:
      1. coerce/fill the composition (comp) columns to numeric with NaN -> 0,
      2. fill the helper-lipid (help) columns with NaN -> 0 as float,
      3. ensure every phys/exp one-hot column exists (deriving it from the raw
         categorical column when available, otherwise filling with 0),
      4. coerce the ``quantified_delivery`` target to numeric (NaN kept for
         missing values).

    Args:
        df: Merged split DataFrame as produced by ``load_cv_split``.

    Returns:
        A new DataFrame (the input is not mutated) with all expected columns.
    """
    df = df.copy()

    # 1. Composition columns: numeric, NaN -> 0.
    for col in COMP_COLS:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0.0)
        else:
            df[col] = 0.0

    # 2. Helper-lipid columns: float, NaN -> 0.
    for col in HELP_COLS:
        if col not in df.columns:
            df[col] = 0.0
        else:
            df[col] = df[col].fillna(0.0).astype(float)

    def _ensure_onehot(specs) -> None:
        # Guarantee every `{col}_{value}` one-hot column exists as float.
        # A missing column is derived from the raw categorical column when it
        # is present, otherwise filled with 0; an existing column only has its
        # NaNs filled.
        for col, values in specs.items():
            for v in values:
                onehot_col = f"{col}_{v}"
                if onehot_col not in df.columns:
                    if col in df.columns:
                        df[onehot_col] = (df[col] == v).astype(float)
                    else:
                        df[onehot_col] = 0.0
                else:
                    df[onehot_col] = df[onehot_col].fillna(0.0).astype(float)

    # 3./4. phys-token and exp-token one-hots share identical handling, so a
    # single helper replaces the previously duplicated loops.
    _ensure_onehot(PHYS_ONEHOT_SPECS)
    _ensure_onehot(EXP_ONEHOT_SPECS)

    # 5. Regression target: keep NaN where the value is missing/unparseable.
    if "quantified_delivery" in df.columns:
        df["quantified_delivery"] = pd.to_numeric(df["quantified_delivery"], errors="coerce")

    return df
def get_feature_columns() -> List[str]:
    """Return the ordered feature-column list every fold parquet must contain.

    Order: smiles, comp columns, phys columns, help columns, exp columns,
    then the quantified_delivery target.
    """
    config = LNPDatasetConfig()
    return [
        "smiles",
        *config.comp_cols,
        *config.phys_cols,
        *config.help_cols,
        *config.exp_cols,
        "quantified_delivery",
    ]
@app.command()
def main(
    data_dir: Path = EXTERNAL_DATA_DIR / "all_amine_split_for_LiON",
    output_dir: Path = PROCESSED_DATA_DIR / "pretrain_cv",
    n_folds: int = 5,
):
    """
    Process the cross-validation data into the parquet files the model needs.

    Output layout:
    - processed/pretrain_cv/fold_0/train.parquet
    - processed/pretrain_cv/fold_0/valid.parquet
    - processed/pretrain_cv/fold_0/test.parquet
    - processed/pretrain_cv/fold_1/...
    - processed/pretrain_cv/feature_columns.txt
    """
    logger.info(f"Processing CV data from {data_dir}")
    # Discover all cv_* subdirectories (one per fold), in sorted order.
    cv_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("cv_")])
    if len(cv_dirs) == 0:
        logger.error(f"No cv_* directories found in {data_dir}")
        raise typer.Exit(1)
    # A mismatch with n_folds is only a warning: all found folds are processed.
    if len(cv_dirs) != n_folds:
        logger.warning(f"Expected {n_folds} folds, found {len(cv_dirs)}")
    logger.info(f"Found {len(cv_dirs)} folds: {[d.name for d in cv_dirs]}")
    feature_cols = get_feature_columns()
    output_dir.mkdir(parents=True, exist_ok=True)
    for i, cv_dir in enumerate(cv_dirs):
        fold_name = f"fold_{i}"
        fold_output_dir = output_dir / fold_name
        fold_output_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Processing {cv_dir.name} -> {fold_name}")
        # Load the raw split CSVs (main + extra_x + metadata).
        train_df, valid_df, test_df = load_cv_split(cv_dir)
        logger.info(f"  Loaded: train={len(train_df)}, valid={len(valid_df)}, test={len(test_df)}")
        # Align columns / one-hot encodings to the model's expected layout.
        train_df = process_cv_dataframe(train_df)
        valid_df = process_cv_dataframe(valid_df)
        test_df = process_cv_dataframe(test_df)
        # Guarantee every expected feature column exists (0.0 for numeric
        # features, empty string for smiles).
        for col in feature_cols:
            for df in [train_df, valid_df, test_df]:
                if col not in df.columns:
                    df[col] = 0.0 if col != "smiles" else ""
        # Keep only the model's feature columns, in canonical order.
        train_df = train_df[feature_cols]
        valid_df = valid_df[feature_cols]
        test_df = test_df[feature_cols]
        # Persist one parquet per split.
        train_df.to_parquet(fold_output_dir / "train.parquet", index=False)
        valid_df.to_parquet(fold_output_dir / "valid.parquet", index=False)
        test_df.to_parquet(fold_output_dir / "test.parquet", index=False)
        logger.success(f"  Saved to {fold_output_dir}")
    # Record the feature-column order next to the fold directories so
    # downstream consumers can reconstruct the exact input layout.
    cols_path = output_dir / "feature_columns.txt"
    with open(cols_path, "w") as f:
        f.write("\n".join(feature_cols))
    logger.success(f"Saved feature columns to {cols_path}")
    logger.info("\n" + "=" * 60)
    logger.info("CV DATA PROCESSING COMPLETE")
    logger.info("=" * 60)
    logger.info(f"Output directory: {output_dir}")
    logger.info(f"Number of folds: {len(cv_dirs)}")
# CLI entry point: dispatch to the Typer app.
if __name__ == "__main__":
    app()