mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 09:36:32 +08:00
Compare commits
4 Commits
e123fc8f3e
...
ac4246c2b7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ac4246c2b7 | ||
|
|
47bbb64c66 | ||
|
|
039be54c5a | ||
|
|
e6a5e5495a |
25
Makefile
25
Makefile
@ -74,6 +74,11 @@ data_pretrain: requirements
|
||||
$(PYTHON_INTERPRETER) scripts/process_external.py
|
||||
|
||||
## Process CV data for cross-validation pretrain (external/all_amine_split_for_LiON -> processed/cv)
|
||||
.PHONY: data_pretrain_cv
|
||||
data_pretrain_cv: requirements
|
||||
$(PYTHON_INTERPRETER) scripts/process_external_cv.py
|
||||
|
||||
## Process internal data with amine-based CV splitting (interim -> processed/cv)
|
||||
.PHONY: data_cv
|
||||
data_cv: requirements
|
||||
$(PYTHON_INTERPRETER) scripts/process_data_cv.py
|
||||
@ -106,8 +111,8 @@ pretrain_cv: requirements
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv main $(MPNN_FLAG) $(DEVICE_FLAG)
|
||||
|
||||
## Evaluate CV pretrain models on test sets (auto-detects MPNN from checkpoint)
|
||||
.PHONY: test_cv
|
||||
test_cv: requirements
|
||||
.PHONY: test_pretrain_cv
|
||||
test_pretrain_cv: requirements
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv test $(DEVICE_FLAG)
|
||||
|
||||
## Train model (multi-task, from scratch)
|
||||
@ -120,6 +125,22 @@ train: requirements
|
||||
finetune: requirements
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --init-from-pretrain models/pretrain_delivery.pt $(FREEZE_FLAG) $(MPNN_FLAG) $(DEVICE_FLAG)
|
||||
|
||||
## Finetune with cross-validation on internal data (5-fold, amine-based split) with pretrained weights
|
||||
.PHONY: finetune_cv
|
||||
finetune_cv: requirements
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train_cv main --init-from-pretrain models/pretrain_delivery.pt $(FREEZE_FLAG) $(MPNN_FLAG) $(DEVICE_FLAG)
|
||||
|
||||
## Train with cross-validation on internal data only (5-fold, amine-based split)
|
||||
.PHONY: train_cv
|
||||
train_cv: requirements
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train_cv main $(FREEZE_FLAG) $(MPNN_FLAG) $(DEVICE_FLAG)
|
||||
|
||||
|
||||
## Evaluate CV finetuned models on test sets (auto-detects MPNN from checkpoint)
|
||||
.PHONY: test_cv
|
||||
test_cv: requirements
|
||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.train_cv test $(DEVICE_FLAG)
|
||||
|
||||
## Train with hyperparameter tuning
|
||||
.PHONY: tune
|
||||
tune: requirements
|
||||
|
||||
@ -1,9 +1,16 @@
|
||||
# Feature columns configuration
|
||||
|
||||
# SMILES
|
||||
smiles
|
||||
|
||||
# comp token [5]
|
||||
Cationic_Lipid_to_mRNA_weight_ratio
|
||||
Cationic_Lipid_Mol_Ratio
|
||||
Phospholipid_Mol_Ratio
|
||||
Cholesterol_Mol_Ratio
|
||||
PEG_Lipid_Mol_Ratio
|
||||
|
||||
# phys token [12]
|
||||
Purity_Pure
|
||||
Purity_Crude
|
||||
Mix_type_Microfluidic
|
||||
@ -16,10 +23,14 @@ Target_or_delivered_gene_Peptide_barcode
|
||||
Target_or_delivered_gene_hEPO
|
||||
Target_or_delivered_gene_FVII
|
||||
Target_or_delivered_gene_GFP
|
||||
|
||||
# help token [4]
|
||||
Helper_lipid_ID_DOPE
|
||||
Helper_lipid_ID_DOTAP
|
||||
Helper_lipid_ID_DSPC
|
||||
Helper_lipid_ID_MDOA
|
||||
|
||||
# exp token [32]
|
||||
Model_type_A549
|
||||
Model_type_BDMC
|
||||
Model_type_BMDM
|
||||
@ -52,4 +63,27 @@ Value_name_hEPO
|
||||
Value_name_FVII_silencing
|
||||
Value_name_GFP_delivery
|
||||
Value_name_Discretized_luminescence
|
||||
quantified_delivery
|
||||
|
||||
# Targets
|
||||
## Regression
|
||||
size
|
||||
quantified_delivery
|
||||
## PDI classification
|
||||
PDI_0_0to0_2
|
||||
PDI_0_2to0_3
|
||||
PDI_0_3to0_4
|
||||
PDI_0_4to0_5
|
||||
## EE classification
|
||||
Encapsulation_Efficiency_EE<50
|
||||
Encapsulation_Efficiency_50<=EE<80
|
||||
Encapsulation_Efficiency_80<EE<=100
|
||||
## Toxic
|
||||
toxic
|
||||
## Biodistribution
|
||||
Biodistribution_lymph_nodes
|
||||
Biodistribution_heart
|
||||
Biodistribution_liver
|
||||
Biodistribution_spleen
|
||||
Biodistribution_lung
|
||||
Biodistribution_kidney
|
||||
Biodistribution_muscle
|
||||
|
||||
Binary file not shown.
Binary file not shown.
BIN
data/processed/cv/fold_0/val.parquet
Normal file
BIN
data/processed/cv/fold_0/val.parquet
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
data/processed/cv/fold_1/val.parquet
Normal file
BIN
data/processed/cv/fold_1/val.parquet
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
data/processed/cv/fold_2/val.parquet
Normal file
BIN
data/processed/cv/fold_2/val.parquet
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
data/processed/cv/fold_3/val.parquet
Normal file
BIN
data/processed/cv/fold_3/val.parquet
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
data/processed/cv/fold_4/val.parquet
Normal file
BIN
data/processed/cv/fold_4/val.parquet
Normal file
Binary file not shown.
55
data/processed/pretrain_cv/feature_columns.txt
Normal file
55
data/processed/pretrain_cv/feature_columns.txt
Normal file
@ -0,0 +1,55 @@
|
||||
smiles
|
||||
Cationic_Lipid_to_mRNA_weight_ratio
|
||||
Cationic_Lipid_Mol_Ratio
|
||||
Phospholipid_Mol_Ratio
|
||||
Cholesterol_Mol_Ratio
|
||||
PEG_Lipid_Mol_Ratio
|
||||
Purity_Pure
|
||||
Purity_Crude
|
||||
Mix_type_Microfluidic
|
||||
Mix_type_Pipetting
|
||||
Cargo_type_mRNA
|
||||
Cargo_type_pDNA
|
||||
Cargo_type_siRNA
|
||||
Target_or_delivered_gene_FFL
|
||||
Target_or_delivered_gene_Peptide_barcode
|
||||
Target_or_delivered_gene_hEPO
|
||||
Target_or_delivered_gene_FVII
|
||||
Target_or_delivered_gene_GFP
|
||||
Helper_lipid_ID_DOPE
|
||||
Helper_lipid_ID_DOTAP
|
||||
Helper_lipid_ID_DSPC
|
||||
Helper_lipid_ID_MDOA
|
||||
Model_type_A549
|
||||
Model_type_BDMC
|
||||
Model_type_BMDM
|
||||
Model_type_HBEC_ALI
|
||||
Model_type_HEK293T
|
||||
Model_type_HeLa
|
||||
Model_type_IGROV1
|
||||
Model_type_Mouse
|
||||
Model_type_RAW264p7
|
||||
Delivery_target_body
|
||||
Delivery_target_dendritic_cell
|
||||
Delivery_target_generic_cell
|
||||
Delivery_target_liver
|
||||
Delivery_target_lung
|
||||
Delivery_target_lung_epithelium
|
||||
Delivery_target_macrophage
|
||||
Delivery_target_muscle
|
||||
Delivery_target_spleen
|
||||
Route_of_administration_in_vitro
|
||||
Route_of_administration_intramuscular
|
||||
Route_of_administration_intratracheal
|
||||
Route_of_administration_intravenous
|
||||
Batch_or_individual_or_barcoded_Barcoded
|
||||
Batch_or_individual_or_barcoded_Individual
|
||||
Value_name_log_luminescence
|
||||
Value_name_luminescence
|
||||
Value_name_FFL_silencing
|
||||
Value_name_Peptide_abundance
|
||||
Value_name_hEPO
|
||||
Value_name_FVII_silencing
|
||||
Value_name_GFP_delivery
|
||||
Value_name_Discretized_luminescence
|
||||
quantified_delivery
|
||||
BIN
data/processed/pretrain_cv/fold_0/train.parquet
Normal file
BIN
data/processed/pretrain_cv/fold_0/train.parquet
Normal file
Binary file not shown.
BIN
data/processed/pretrain_cv/fold_1/test.parquet
Normal file
BIN
data/processed/pretrain_cv/fold_1/test.parquet
Normal file
Binary file not shown.
BIN
data/processed/pretrain_cv/fold_1/train.parquet
Normal file
BIN
data/processed/pretrain_cv/fold_1/train.parquet
Normal file
Binary file not shown.
BIN
data/processed/pretrain_cv/fold_2/test.parquet
Normal file
BIN
data/processed/pretrain_cv/fold_2/test.parquet
Normal file
Binary file not shown.
BIN
data/processed/pretrain_cv/fold_2/train.parquet
Normal file
BIN
data/processed/pretrain_cv/fold_2/train.parquet
Normal file
Binary file not shown.
BIN
data/processed/pretrain_cv/fold_3/test.parquet
Normal file
BIN
data/processed/pretrain_cv/fold_3/test.parquet
Normal file
Binary file not shown.
BIN
data/processed/pretrain_cv/fold_3/train.parquet
Normal file
BIN
data/processed/pretrain_cv/fold_3/train.parquet
Normal file
Binary file not shown.
BIN
data/processed/pretrain_cv/fold_4/test.parquet
Normal file
BIN
data/processed/pretrain_cv/fold_4/test.parquet
Normal file
Binary file not shown.
BIN
data/processed/pretrain_cv/fold_4/train.parquet
Normal file
BIN
data/processed/pretrain_cv/fold_4/train.parquet
Normal file
Binary file not shown.
BIN
data/processed/pretrain_cv/fold_4/valid.parquet
Normal file
BIN
data/processed/pretrain_cv/fold_4/valid.parquet
Normal file
Binary file not shown.
@ -1,37 +1,51 @@
|
||||
{
|
||||
"loss_metrics": {
|
||||
"loss": 2.8661977450052896,
|
||||
"loss_size": 0.44916408757368725,
|
||||
"loss_pdi": 0.5041926403840383,
|
||||
"loss_ee": 0.9021427234013876,
|
||||
"loss_delivery": 0.5761533578236898,
|
||||
"loss_biodist": 0.4019051690896352,
|
||||
"loss_toxic": 0.03263980595511384,
|
||||
"acc_pdi": 0.7633587786259542,
|
||||
"acc_ee": 0.6641221374045801,
|
||||
"acc_toxic": 0.9702970297029703
|
||||
"loss": 2.5374555587768555,
|
||||
"loss_size": 0.1886825958887736,
|
||||
"loss_pdi": 0.45798932512601215,
|
||||
"loss_ee": 0.829658567905426,
|
||||
"loss_delivery": 0.4857304096221924,
|
||||
"loss_biodist": 0.5346279243628184,
|
||||
"loss_toxic": 0.04076674363265435,
|
||||
"acc_pdi": 0.7862595419847328,
|
||||
"acc_ee": 0.6793893129770993,
|
||||
"acc_toxic": 0.9801980198019802
|
||||
},
|
||||
"detailed_metrics": {
|
||||
"size": {
|
||||
"mse": 0.41126506251447736,
|
||||
"rmse": 0.6412995107704959,
|
||||
"mae": 0.41415552388095633,
|
||||
"r2": -0.9333718010891026
|
||||
"mse": 0.1669999969286325,
|
||||
"rmse": 0.4086563310761654,
|
||||
"mae": 0.26111859684375066,
|
||||
"r2": 0.2149270281561566
|
||||
},
|
||||
"delivery": {
|
||||
"mse": 0.6277965050686476,
|
||||
"rmse": 0.7923361061245711,
|
||||
"mae": 0.5387302115022443,
|
||||
"r2": 0.24206702565575944
|
||||
"mse": 0.5193460523366603,
|
||||
"rmse": 0.7206566813238189,
|
||||
"mae": 0.4828052782115008,
|
||||
"r2": 0.37299826459145
|
||||
},
|
||||
"pdi": {
|
||||
"accuracy": 0.7633587786259542
|
||||
"accuracy": 0.7862595419847328,
|
||||
"precision": 0.7282763532763532,
|
||||
"recall": 0.6907738095238095,
|
||||
"f1": 0.7041935483870968
|
||||
},
|
||||
"ee": {
|
||||
"accuracy": 0.6641221374045801
|
||||
"accuracy": 0.6793893129770993,
|
||||
"precision": 0.612247574088644,
|
||||
"recall": 0.6062951496388029,
|
||||
"f1": 0.6069449904342585
|
||||
},
|
||||
"toxic": {
|
||||
"accuracy": 0.9702970297029703
|
||||
"accuracy": 0.9801980198019802,
|
||||
"precision": 0.5,
|
||||
"recall": 0.4900990099009901,
|
||||
"f1": 0.495
|
||||
},
|
||||
"biodist": {
|
||||
"n_samples": 101,
|
||||
"kl_divergence": 0.2931957937514963,
|
||||
"js_divergence": 0.07706768601895059
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -217,15 +217,31 @@ def test(
|
||||
"""
|
||||
import json
|
||||
import numpy as np
|
||||
from scipy.special import rel_entr
|
||||
from sklearn.metrics import (
|
||||
mean_squared_error,
|
||||
mean_absolute_error,
|
||||
r2_score,
|
||||
accuracy_score,
|
||||
classification_report,
|
||||
precision_score,
|
||||
recall_score,
|
||||
f1_score,
|
||||
)
|
||||
from lnp_ml.modeling.trainer import validate
|
||||
|
||||
def kl_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-10) -> float:
|
||||
"""计算 KL 散度 KL(p || q)"""
|
||||
p = np.clip(p, eps, 1.0)
|
||||
q = np.clip(q, eps, 1.0)
|
||||
return float(np.sum(rel_entr(p, q), axis=-1).mean())
|
||||
|
||||
def js_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-10) -> float:
|
||||
"""计算 JS 散度"""
|
||||
p = np.clip(p, eps, 1.0)
|
||||
q = np.clip(q, eps, 1.0)
|
||||
m = 0.5 * (p + q)
|
||||
return float(0.5 * (np.sum(rel_entr(p, m), axis=-1) + np.sum(rel_entr(q, m), axis=-1)).mean())
|
||||
|
||||
logger.info(f"Using device: {device}")
|
||||
device_obj = torch.device(device)
|
||||
|
||||
@ -287,6 +303,9 @@ def test(
|
||||
y_pred = np.array(predictions["pdi"])[mask]
|
||||
results["detailed_metrics"]["pdi"] = {
|
||||
"accuracy": float(accuracy_score(y_true, y_pred)),
|
||||
"precision": float(precision_score(y_true, y_pred, average="macro", zero_division=0)),
|
||||
"recall": float(recall_score(y_true, y_pred, average="macro", zero_division=0)),
|
||||
"f1": float(f1_score(y_true, y_pred, average="macro", zero_division=0)),
|
||||
}
|
||||
|
||||
# 分类指标:EE
|
||||
@ -299,6 +318,9 @@ def test(
|
||||
y_pred = np.array(predictions["ee"])[mask]
|
||||
results["detailed_metrics"]["ee"] = {
|
||||
"accuracy": float(accuracy_score(y_true, y_pred)),
|
||||
"precision": float(precision_score(y_true, y_pred, average="macro", zero_division=0)),
|
||||
"recall": float(recall_score(y_true, y_pred, average="macro", zero_division=0)),
|
||||
"f1": float(f1_score(y_true, y_pred, average="macro", zero_division=0)),
|
||||
}
|
||||
|
||||
# 分类指标:toxic
|
||||
@ -309,6 +331,28 @@ def test(
|
||||
y_pred = np.array(predictions["toxic"])[mask.values]
|
||||
results["detailed_metrics"]["toxic"] = {
|
||||
"accuracy": float(accuracy_score(y_true, y_pred)),
|
||||
"precision": float(precision_score(y_true, y_pred, average="macro", zero_division=0)),
|
||||
"recall": float(recall_score(y_true, y_pred, average="macro", zero_division=0)),
|
||||
"f1": float(f1_score(y_true, y_pred, average="macro", zero_division=0)),
|
||||
}
|
||||
|
||||
# 分布指标:biodist
|
||||
biodist_cols = [
|
||||
"Biodistribution_lymph_nodes", "Biodistribution_heart", "Biodistribution_liver",
|
||||
"Biodistribution_spleen", "Biodistribution_lung", "Biodistribution_kidney", "Biodistribution_muscle"
|
||||
]
|
||||
if all(c in test_df.columns for c in biodist_cols):
|
||||
biodist_true = test_df[biodist_cols].values
|
||||
biodist_pred = np.array(predictions["biodist"])
|
||||
# mask: 有效样本是 sum > 0 且无 NaN
|
||||
mask = (biodist_true.sum(axis=1) > 0) & (~np.isnan(biodist_true).any(axis=1))
|
||||
if mask.any():
|
||||
y_true = biodist_true[mask]
|
||||
y_pred = biodist_pred[mask]
|
||||
results["detailed_metrics"]["biodist"] = {
|
||||
"n_samples": int(mask.sum()),
|
||||
"kl_divergence": kl_divergence(y_true, y_pred),
|
||||
"js_divergence": js_divergence(y_true, y_pred),
|
||||
}
|
||||
|
||||
# 打印结果
|
||||
|
||||
@ -271,7 +271,7 @@ def create_model(
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
data_dir: Path = PROCESSED_DATA_DIR / "cv",
|
||||
data_dir: Path = PROCESSED_DATA_DIR / "pretrain_cv",
|
||||
output_dir: Path = MODELS_DIR / "pretrain_cv",
|
||||
# 模型参数
|
||||
d_model: int = 256,
|
||||
@ -322,7 +322,7 @@ def main(
|
||||
|
||||
if not fold_dirs:
|
||||
logger.error(f"No fold_* directories found in {data_dir}")
|
||||
logger.info("Please run 'make data_cv' first to process CV data.")
|
||||
logger.info("Please run 'make data_pretrain_cv' first to process CV data.")
|
||||
raise typer.Exit(1)
|
||||
|
||||
logger.info(f"Found {len(fold_dirs)} folds: {[d.name for d in fold_dirs]}")
|
||||
@ -464,7 +464,7 @@ def main(
|
||||
|
||||
@app.command()
|
||||
def test(
|
||||
data_dir: Path = PROCESSED_DATA_DIR / "cv",
|
||||
data_dir: Path = PROCESSED_DATA_DIR / "pretrain_cv",
|
||||
model_dir: Path = MODELS_DIR / "pretrain_cv",
|
||||
output_path: Path = MODELS_DIR / "pretrain_cv" / "test_results.json",
|
||||
batch_size: int = 64,
|
||||
|
||||
720
lnp_ml/modeling/train_cv.py
Normal file
720
lnp_ml/modeling/train_cv.py
Normal file
@ -0,0 +1,720 @@
|
||||
"""Cross-Validation 训练脚本:在 5-fold 内部数据上进行多任务训练"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from loguru import logger
|
||||
from tqdm import tqdm
|
||||
import typer
|
||||
|
||||
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
|
||||
from lnp_ml.dataset import LNPDataset, collate_fn
|
||||
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
|
||||
from lnp_ml.modeling.trainer import (
|
||||
train_epoch,
|
||||
validate,
|
||||
EarlyStopping,
|
||||
LossWeights,
|
||||
)
|
||||
|
||||
|
||||
# MPNN ensemble 默认路径
|
||||
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
|
||||
|
||||
|
||||
def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
|
||||
"""自动查找 MPNN ensemble 的 model.pt 文件。"""
|
||||
model_paths = sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt"))
|
||||
if not model_paths:
|
||||
raise FileNotFoundError(f"No model.pt files found in {base_dir}")
|
||||
return [str(p) for p in model_paths]
|
||||
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
def create_model(
|
||||
d_model: int = 256,
|
||||
num_heads: int = 8,
|
||||
n_attn_layers: int = 4,
|
||||
fusion_strategy: str = "attention",
|
||||
head_hidden_dim: int = 128,
|
||||
dropout: float = 0.1,
|
||||
mpnn_checkpoint: Optional[str] = None,
|
||||
mpnn_ensemble_paths: Optional[List[str]] = None,
|
||||
mpnn_device: str = "cpu",
|
||||
) -> Union[LNPModel, LNPModelWithoutMPNN]:
|
||||
"""创建模型(支持可选的 MPNN encoder)"""
|
||||
use_mpnn = mpnn_checkpoint is not None or mpnn_ensemble_paths is not None
|
||||
|
||||
if use_mpnn:
|
||||
return LNPModel(
|
||||
d_model=d_model,
|
||||
num_heads=num_heads,
|
||||
n_attn_layers=n_attn_layers,
|
||||
fusion_strategy=fusion_strategy,
|
||||
head_hidden_dim=head_hidden_dim,
|
||||
dropout=dropout,
|
||||
mpnn_checkpoint=mpnn_checkpoint,
|
||||
mpnn_ensemble_paths=mpnn_ensemble_paths,
|
||||
mpnn_device=mpnn_device,
|
||||
)
|
||||
else:
|
||||
return LNPModelWithoutMPNN(
|
||||
d_model=d_model,
|
||||
num_heads=num_heads,
|
||||
n_attn_layers=n_attn_layers,
|
||||
fusion_strategy=fusion_strategy,
|
||||
head_hidden_dim=head_hidden_dim,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
|
||||
def train_fold(
|
||||
fold_idx: int,
|
||||
train_loader: DataLoader,
|
||||
val_loader: DataLoader,
|
||||
model: nn.Module,
|
||||
device: torch.device,
|
||||
output_dir: Path,
|
||||
lr: float = 1e-4,
|
||||
weight_decay: float = 1e-5,
|
||||
epochs: int = 100,
|
||||
patience: int = 15,
|
||||
loss_weights: Optional[LossWeights] = None,
|
||||
config: Optional[Dict] = None,
|
||||
) -> Dict:
|
||||
"""训练单个 fold"""
|
||||
logger.info(f"\n{'='*60}")
|
||||
logger.info(f"Training Fold {fold_idx}")
|
||||
logger.info(f"{'='*60}")
|
||||
|
||||
model = model.to(device)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
optimizer, mode="min", factor=0.5, patience=5
|
||||
)
|
||||
early_stopping = EarlyStopping(patience=patience)
|
||||
|
||||
history = {"train": [], "val": []}
|
||||
best_val_loss = float("inf")
|
||||
best_state = None
|
||||
|
||||
for epoch in range(epochs):
|
||||
# Train
|
||||
train_metrics = train_epoch(model, train_loader, optimizer, device, loss_weights)
|
||||
|
||||
# Validate
|
||||
val_metrics = validate(model, val_loader, device, loss_weights)
|
||||
|
||||
current_lr = optimizer.param_groups[0]["lr"]
|
||||
|
||||
# Log
|
||||
logger.info(
|
||||
f"Fold {fold_idx} Epoch {epoch+1}/{epochs} | "
|
||||
f"Train Loss: {train_metrics['loss']:.4f} | "
|
||||
f"Val Loss: {val_metrics['loss']:.4f} | "
|
||||
f"LR: {current_lr:.2e}"
|
||||
)
|
||||
|
||||
history["train"].append(train_metrics)
|
||||
history["val"].append(val_metrics)
|
||||
|
||||
# Learning rate scheduling
|
||||
scheduler.step(val_metrics["loss"])
|
||||
|
||||
# Save best model
|
||||
if val_metrics["loss"] < best_val_loss:
|
||||
best_val_loss = val_metrics["loss"]
|
||||
best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
|
||||
logger.info(f" -> New best model (val_loss={best_val_loss:.4f})")
|
||||
|
||||
# Early stopping
|
||||
if early_stopping(val_metrics["loss"]):
|
||||
logger.info(f"Early stopping at epoch {epoch+1}")
|
||||
break
|
||||
|
||||
# 保存最佳模型
|
||||
fold_output_dir = output_dir / f"fold_{fold_idx}"
|
||||
fold_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
checkpoint_path = fold_output_dir / "model.pt"
|
||||
torch.save({
|
||||
"model_state_dict": best_state,
|
||||
"config": config,
|
||||
"best_val_loss": best_val_loss,
|
||||
"fold_idx": fold_idx,
|
||||
}, checkpoint_path)
|
||||
logger.success(f"Saved fold {fold_idx} model to {checkpoint_path}")
|
||||
|
||||
# 保存训练历史
|
||||
history_path = fold_output_dir / "history.json"
|
||||
with open(history_path, "w") as f:
|
||||
json.dump(history, f, indent=2)
|
||||
|
||||
return {
|
||||
"fold_idx": fold_idx,
|
||||
"best_val_loss": best_val_loss,
|
||||
"epochs_trained": len(history["train"]),
|
||||
"final_train_loss": history["train"][-1]["loss"] if history["train"] else 0,
|
||||
}
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
data_dir: Path = PROCESSED_DATA_DIR / "cv",
|
||||
output_dir: Path = MODELS_DIR / "finetune_cv",
|
||||
# 模型参数
|
||||
d_model: int = 256,
|
||||
num_heads: int = 8,
|
||||
n_attn_layers: int = 4,
|
||||
fusion_strategy: str = "attention",
|
||||
head_hidden_dim: int = 128,
|
||||
dropout: float = 0.1,
|
||||
# MPNN 参数(可选)
|
||||
use_mpnn: bool = False,
|
||||
mpnn_checkpoint: Optional[str] = None,
|
||||
mpnn_ensemble_paths: Optional[str] = None,
|
||||
mpnn_device: str = "cpu",
|
||||
# 训练参数
|
||||
batch_size: int = 32,
|
||||
lr: float = 1e-4,
|
||||
weight_decay: float = 1e-5,
|
||||
epochs: int = 100,
|
||||
patience: int = 15,
|
||||
# 预训练权重加载
|
||||
init_from_pretrain: Optional[Path] = None,
|
||||
load_delivery_head: bool = True,
|
||||
freeze_backbone: bool = False,
|
||||
# 设备
|
||||
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
||||
):
|
||||
"""
|
||||
基于 Cross-Validation 训练 LNP 模型(多任务)。
|
||||
|
||||
在 5-fold 内部数据上训练 5 个模型。
|
||||
|
||||
使用 --use-mpnn 启用 MPNN encoder。
|
||||
使用 --init-from-pretrain 从预训练 checkpoint 初始化。
|
||||
使用 --freeze-backbone 冻结 backbone,只训练 heads。
|
||||
"""
|
||||
logger.info(f"Using device: {device}")
|
||||
device = torch.device(device)
|
||||
|
||||
# 查找所有 fold 目录
|
||||
fold_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("fold_")])
|
||||
|
||||
if not fold_dirs:
|
||||
logger.error(f"No fold_* directories found in {data_dir}")
|
||||
logger.info("Please run 'make data_cv' first to process CV data.")
|
||||
raise typer.Exit(1)
|
||||
|
||||
logger.info(f"Found {len(fold_dirs)} folds: {[d.name for d in fold_dirs]}")
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 解析 MPNN 配置
|
||||
ensemble_paths_list = None
|
||||
if mpnn_ensemble_paths:
|
||||
ensemble_paths_list = mpnn_ensemble_paths.split(",")
|
||||
elif use_mpnn and mpnn_checkpoint is None:
|
||||
logger.info(f"Auto-detecting MPNN ensemble from {DEFAULT_MPNN_ENSEMBLE_DIR}")
|
||||
ensemble_paths_list = find_mpnn_ensemble_paths()
|
||||
logger.info(f"Found {len(ensemble_paths_list)} MPNN models")
|
||||
|
||||
enable_mpnn = mpnn_checkpoint is not None or ensemble_paths_list is not None
|
||||
|
||||
# 模型配置
|
||||
config = {
|
||||
"d_model": d_model,
|
||||
"num_heads": num_heads,
|
||||
"n_attn_layers": n_attn_layers,
|
||||
"fusion_strategy": fusion_strategy,
|
||||
"head_hidden_dim": head_hidden_dim,
|
||||
"dropout": dropout,
|
||||
"use_mpnn": enable_mpnn,
|
||||
"lr": lr,
|
||||
"weight_decay": weight_decay,
|
||||
"batch_size": batch_size,
|
||||
"epochs": epochs,
|
||||
"patience": patience,
|
||||
"init_from_pretrain": str(init_from_pretrain) if init_from_pretrain else None,
|
||||
"freeze_backbone": freeze_backbone,
|
||||
}
|
||||
|
||||
# 保存配置
|
||||
config_path = output_dir / "config.json"
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
logger.info(f"Saved config to {config_path}")
|
||||
|
||||
# 加载预训练权重(如果指定)
|
||||
pretrain_state = None
|
||||
if init_from_pretrain is not None:
|
||||
logger.info(f"Loading pretrain weights from {init_from_pretrain}")
|
||||
checkpoint = torch.load(init_from_pretrain, map_location="cpu")
|
||||
pretrain_config = checkpoint.get("config", {})
|
||||
if pretrain_config.get("d_model") != d_model:
|
||||
logger.warning(
|
||||
f"d_model mismatch: pretrain={pretrain_config.get('d_model')}, "
|
||||
f"current={d_model}. Skipping pretrain loading."
|
||||
)
|
||||
else:
|
||||
pretrain_state = checkpoint["model_state_dict"]
|
||||
|
||||
# 训练每个 fold
|
||||
fold_results = []
|
||||
|
||||
for fold_dir in tqdm(fold_dirs, desc="Training folds"):
|
||||
fold_idx = int(fold_dir.name.split("_")[1])
|
||||
|
||||
# 加载数据
|
||||
train_df = pd.read_parquet(fold_dir / "train.parquet")
|
||||
val_df = pd.read_parquet(fold_dir / "val.parquet")
|
||||
|
||||
logger.info(f"\nFold {fold_idx}: train={len(train_df)}, val={len(val_df)}")
|
||||
|
||||
# 创建 Dataset 和 DataLoader
|
||||
train_dataset = LNPDataset(train_df)
|
||||
val_dataset = LNPDataset(val_df)
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
|
||||
)
|
||||
|
||||
# 创建新模型(每个 fold 独立初始化)
|
||||
model = create_model(
|
||||
d_model=d_model,
|
||||
num_heads=num_heads,
|
||||
n_attn_layers=n_attn_layers,
|
||||
fusion_strategy=fusion_strategy,
|
||||
head_hidden_dim=head_hidden_dim,
|
||||
dropout=dropout,
|
||||
mpnn_checkpoint=mpnn_checkpoint,
|
||||
mpnn_ensemble_paths=ensemble_paths_list,
|
||||
mpnn_device=device.type,
|
||||
)
|
||||
|
||||
# 加载预训练权重
|
||||
if pretrain_state is not None:
|
||||
model.load_pretrain_weights(
|
||||
pretrain_state_dict=pretrain_state,
|
||||
load_delivery_head=load_delivery_head,
|
||||
strict=False,
|
||||
)
|
||||
logger.info(f"Loaded pretrain weights (backbone + delivery_head={load_delivery_head})")
|
||||
|
||||
# 冻结 backbone(如果指定)
|
||||
if freeze_backbone:
|
||||
frozen_count = 0
|
||||
for name, param in model.named_parameters():
|
||||
if name.startswith(("token_projector.", "cross_attention.", "fusion.")):
|
||||
param.requires_grad = False
|
||||
frozen_count += 1
|
||||
logger.info(f"Frozen {frozen_count} parameter tensors")
|
||||
|
||||
# 打印模型信息(仅第一个 fold)
|
||||
if fold_idx == 0:
|
||||
n_params_total = sum(p.numel() for p in model.parameters())
|
||||
n_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
logger.info(f"Model parameters: {n_params_total:,} total, {n_params_trainable:,} trainable")
|
||||
|
||||
# 训练
|
||||
result = train_fold(
|
||||
fold_idx=fold_idx,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
model=model,
|
||||
device=device,
|
||||
output_dir=output_dir,
|
||||
lr=lr,
|
||||
weight_decay=weight_decay,
|
||||
epochs=epochs,
|
||||
patience=patience,
|
||||
config=config,
|
||||
)
|
||||
fold_results.append(result)
|
||||
|
||||
# 汇总结果
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("CROSS-VALIDATION TRAINING COMPLETE")
|
||||
logger.info("=" * 60)
|
||||
|
||||
val_losses = [r["best_val_loss"] for r in fold_results]
|
||||
|
||||
logger.info(f"\n[Per-Fold Results]")
|
||||
for r in fold_results:
|
||||
logger.info(
|
||||
f" Fold {r['fold_idx']}: "
|
||||
f"Val Loss={r['best_val_loss']:.4f}, "
|
||||
f"Epochs={r['epochs_trained']}"
|
||||
)
|
||||
|
||||
logger.info(f"\n[Summary Statistics]")
|
||||
logger.info(f" Val Loss: {np.mean(val_losses):.4f} ± {np.std(val_losses):.4f}")
|
||||
|
||||
# 保存 CV 结果
|
||||
cv_results = {
|
||||
"fold_results": fold_results,
|
||||
"summary": {
|
||||
"val_loss_mean": float(np.mean(val_losses)),
|
||||
"val_loss_std": float(np.std(val_losses)),
|
||||
},
|
||||
"config": config,
|
||||
}
|
||||
|
||||
results_path = output_dir / "cv_results.json"
|
||||
with open(results_path, "w") as f:
|
||||
json.dump(cv_results, f, indent=2)
|
||||
logger.success(f"Saved CV results to {results_path}")
|
||||
|
||||
|
||||
@app.command()
def test(
    data_dir: Path = PROCESSED_DATA_DIR / "cv",
    model_dir: Path = MODELS_DIR / "finetune_cv",
    output_path: Path = MODELS_DIR / "finetune_cv" / "test_results.json",
    batch_size: int = 64,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Evaluate CV-trained models on their held-out test sets.

    Each fold's checkpoint (``model_dir/fold_<i>/model.pt``) is evaluated on the
    matching test split (``data_dir/fold_<i>/test.parquet``). Three levels of
    results are produced: per-fold metrics, mean ± std summary statistics across
    folds, and "pooled" metrics computed over all folds' samples combined.
    Everything is logged and written as JSON to ``output_path``.

    Args:
        data_dir: Directory containing ``fold_*`` subdirectories with test data.
        model_dir: Directory containing ``fold_*`` subdirectories with checkpoints.
        output_path: Destination file for the aggregated JSON results.
        batch_size: Batch size for test-time inference.
        device: Torch device string; defaults to CUDA when available.
    """
    # Heavy metric imports kept function-local so the CLI module loads fast.
    from scipy.special import rel_entr
    from sklearn.metrics import (
        mean_squared_error,
        mean_absolute_error,
        r2_score,
        accuracy_score,
        precision_score,
        recall_score,
        f1_score,
    )

    def kl_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-10) -> float:
        """Mean KL divergence KL(p || q) over the last axis.

        NOTE(review): values are clipped to [eps, 1] but not renormalized, so
        rows with near-zero entries yield an approximate divergence.
        """
        p = np.clip(p, eps, 1.0)
        q = np.clip(q, eps, 1.0)
        return float(np.sum(rel_entr(p, q), axis=-1).mean())

    def js_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-10) -> float:
        """Mean Jensen-Shannon divergence over the last axis (same clipping caveat as KL)."""
        p = np.clip(p, eps, 1.0)
        q = np.clip(q, eps, 1.0)
        m = 0.5 * (p + q)
        return float(0.5 * (np.sum(rel_entr(p, m), axis=-1) + np.sum(rel_entr(q, m), axis=-1)).mean())

    logger.info(f"Using device: {device}")
    device = torch.device(device)

    # Discover all fold directories produced by the CV data split.
    fold_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("fold_")])

    if not fold_dirs:
        logger.error(f"No fold_* directories found in {data_dir}")
        raise typer.Exit(1)

    logger.info(f"Found {len(fold_dirs)} folds")

    fold_results = []
    # Accumulators pooling predictions/targets across every fold, keyed by task.
    all_preds = {
        "size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [], "biodist": []
    }
    all_targets = {
        "size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [], "biodist": []
    }

    for fold_dir in tqdm(fold_dirs, desc="Evaluating folds"):
        fold_idx = int(fold_dir.name.split("_")[1])
        model_path = model_dir / f"fold_{fold_idx}" / "model.pt"
        test_path = fold_dir / "test.parquet"

        # Missing artifacts skip the fold rather than abort the whole run.
        if not model_path.exists():
            logger.warning(f"Fold {fold_idx}: model not found at {model_path}, skipping")
            continue

        if not test_path.exists():
            logger.warning(f"Fold {fold_idx}: test data not found at {test_path}, skipping")
            continue

        # Load the model; the checkpoint stores the training config alongside weights.
        # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which
        # would reject the pickled config dict — confirm the pinned torch version.
        checkpoint = torch.load(model_path, map_location=device)
        config = checkpoint["config"]

        use_mpnn = config.get("use_mpnn", False)

        # Always re-resolve MPNN ensemble paths instead of trusting stored ones
        # (paths saved at train time may not exist on this machine).
        if use_mpnn:
            mpnn_paths = find_mpnn_ensemble_paths()
        else:
            mpnn_paths = None

        # Rebuild the architecture from the checkpoint's config, then load weights.
        model = create_model(
            d_model=config["d_model"],
            num_heads=config["num_heads"],
            n_attn_layers=config["n_attn_layers"],
            fusion_strategy=config["fusion_strategy"],
            head_hidden_dim=config["head_hidden_dim"],
            dropout=config["dropout"],
            mpnn_ensemble_paths=mpnn_paths,
            mpnn_device=device.type,
        )
        model.load_state_dict(checkpoint["model_state_dict"])
        model = model.to(device)
        model.eval()

        # Load this fold's test data.
        test_df = pd.read_parquet(test_path)
        test_dataset = LNPDataset(test_df)
        test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
        )

        # Per-task prediction/target buffers for the current fold only.
        fold_preds = {k: [] for k in all_preds.keys()}
        fold_targets = {k: [] for k in all_targets.keys()}

        with torch.no_grad():
            pbar = tqdm(test_loader, desc=f"Fold {fold_idx} [Test]", leave=False)
            for batch in pbar:
                smiles = batch["smiles"]
                tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
                targets = batch["targets"]
                masks = batch["mask"]

                outputs = model(smiles, tabular)

                # Size (regression): mask selects samples that have a size label.
                if "size" in masks and masks["size"].any():
                    mask = masks["size"]
                    fold_preds["size"].extend(
                        outputs["size"].squeeze(-1)[mask].cpu().numpy().tolist()
                    )
                    fold_targets["size"].extend(
                        targets["size"][mask].cpu().numpy().tolist()
                    )

                # Delivery (regression)
                if "delivery" in masks and masks["delivery"].any():
                    mask = masks["delivery"]
                    fold_preds["delivery"].extend(
                        outputs["delivery"].squeeze(-1)[mask].cpu().numpy().tolist()
                    )
                    fold_targets["delivery"].extend(
                        targets["delivery"][mask].cpu().numpy().tolist()
                    )

                # PDI (classification): argmax over class logits.
                if "pdi" in masks and masks["pdi"].any():
                    mask = masks["pdi"]
                    pdi_preds = outputs["pdi"][mask].argmax(dim=-1).cpu().numpy()
                    pdi_targets = targets["pdi"][mask].cpu().numpy()
                    fold_preds["pdi"].extend(pdi_preds.tolist())
                    fold_targets["pdi"].extend(pdi_targets.tolist())

                # EE (classification)
                if "ee" in masks and masks["ee"].any():
                    mask = masks["ee"]
                    ee_preds = outputs["ee"][mask].argmax(dim=-1).cpu().numpy()
                    ee_targets = targets["ee"][mask].cpu().numpy()
                    fold_preds["ee"].extend(ee_preds.tolist())
                    fold_targets["ee"].extend(ee_targets.tolist())

                # Toxic (classification): targets cast to int for sklearn metrics.
                if "toxic" in masks and masks["toxic"].any():
                    mask = masks["toxic"]
                    toxic_preds = outputs["toxic"][mask].argmax(dim=-1).cpu().numpy()
                    toxic_targets = targets["toxic"][mask].cpu().numpy().astype(int)
                    fold_preds["toxic"].extend(toxic_preds.tolist())
                    fold_targets["toxic"].extend(toxic_targets.tolist())

                # Biodist (distribution over compartments; compared via KL/JS).
                if "biodist" in masks and masks["biodist"].any():
                    mask = masks["biodist"]
                    biodist_preds = outputs["biodist"][mask].cpu().numpy()
                    biodist_targets = targets["biodist"][mask].cpu().numpy()
                    fold_preds["biodist"].extend(biodist_preds.tolist())
                    fold_targets["biodist"].extend(biodist_targets.tolist())

        # Compute metrics for the current fold.
        fold_metrics = {"fold_idx": fold_idx, "n_samples": len(test_df)}

        # Regression task metrics (only when this fold had labeled samples).
        for task in ["size", "delivery"]:
            if fold_preds[task]:
                p = np.array(fold_preds[task])
                t = np.array(fold_targets[task])
                fold_metrics[task] = {
                    "n": len(p),
                    "rmse": float(np.sqrt(mean_squared_error(t, p))),
                    "mae": float(mean_absolute_error(t, p)),
                    "r2": float(r2_score(t, p)),
                }

        # Classification task metrics (macro-averaged; zero_division=0 silences
        # warnings for classes absent from a small fold).
        for task in ["pdi", "ee", "toxic"]:
            if fold_preds[task]:
                p = np.array(fold_preds[task])
                t = np.array(fold_targets[task])
                fold_metrics[task] = {
                    "n": len(p),
                    "accuracy": float(accuracy_score(t, p)),
                    "precision": float(precision_score(t, p, average="macro", zero_division=0)),
                    "recall": float(recall_score(t, p, average="macro", zero_division=0)),
                    "f1": float(f1_score(t, p, average="macro", zero_division=0)),
                }

        # Distribution task metrics (KL/JS between target and predicted distributions).
        if fold_preds["biodist"]:
            p = np.array(fold_preds["biodist"])
            t = np.array(fold_targets["biodist"])
            fold_metrics["biodist"] = {
                "n": len(p),
                "kl_divergence": kl_divergence(t, p),
                "js_divergence": js_divergence(t, p),
            }

        fold_results.append(fold_metrics)

        # Append this fold's predictions/targets into the global pools.
        for task in all_preds.keys():
            all_preds[task].extend(fold_preds[task])
            all_targets[task].extend(fold_targets[task])

        # Log a one-line summary for the current fold.
        log_parts = [f"Fold {fold_idx}: n={len(test_df)}"]
        for task in ["delivery", "size"]:
            if task in fold_metrics and isinstance(fold_metrics[task], dict):
                log_parts.append(f"{task}_RMSE={fold_metrics[task]['rmse']:.4f}")
                log_parts.append(f"{task}_R²={fold_metrics[task]['r2']:.4f}")
        for task in ["pdi", "ee", "toxic"]:
            if task in fold_metrics and isinstance(fold_metrics[task], dict):
                log_parts.append(f"{task}_acc={fold_metrics[task]['accuracy']:.4f}")
                log_parts.append(f"{task}_f1={fold_metrics[task]['f1']:.4f}")
        if "biodist" in fold_metrics and isinstance(fold_metrics["biodist"], dict):
            log_parts.append(f"biodist_KL={fold_metrics['biodist']['kl_divergence']:.4f}")
            log_parts.append(f"biodist_JS={fold_metrics['biodist']['js_divergence']:.4f}")
        logger.info(", ".join(log_parts))

    # Cross-fold summary statistics (mean ± std over the per-fold metrics).
    summary_stats = {}
    for task in ["size", "delivery"]:
        rmses = [r[task]["rmse"] for r in fold_results if task in r and isinstance(r[task], dict)]
        r2s = [r[task]["r2"] for r in fold_results if task in r and isinstance(r[task], dict)]
        if rmses:
            summary_stats[task] = {
                "rmse_mean": float(np.mean(rmses)),
                "rmse_std": float(np.std(rmses)),
                "r2_mean": float(np.mean(r2s)),
                "r2_std": float(np.std(r2s)),
            }

    for task in ["pdi", "ee", "toxic"]:
        accs = [r[task]["accuracy"] for r in fold_results if task in r and isinstance(r[task], dict)]
        f1s = [r[task]["f1"] for r in fold_results if task in r and isinstance(r[task], dict)]
        if accs:
            summary_stats[task] = {
                "accuracy_mean": float(np.mean(accs)),
                "accuracy_std": float(np.std(accs)),
                "f1_mean": float(np.mean(f1s)),
                "f1_std": float(np.std(f1s)),
            }

    # Distribution-task summary across folds.
    kls = [r["biodist"]["kl_divergence"] for r in fold_results if "biodist" in r and isinstance(r["biodist"], dict)]
    jss = [r["biodist"]["js_divergence"] for r in fold_results if "biodist" in r and isinstance(r["biodist"], dict)]
    if kls:
        summary_stats["biodist"] = {
            "kl_mean": float(np.mean(kls)),
            "kl_std": float(np.std(kls)),
            "js_mean": float(np.mean(jss)),
            "js_std": float(np.std(jss)),
        }

    # Overall pooled metrics: recomputed over all samples from all folds at once
    # (not a mean of per-fold values).
    overall = {}
    for task in ["size", "delivery"]:
        if all_preds[task]:
            p = np.array(all_preds[task])
            t = np.array(all_targets[task])
            overall[task] = {
                "n_samples": len(p),
                "mse": float(mean_squared_error(t, p)),
                "rmse": float(np.sqrt(mean_squared_error(t, p))),
                "mae": float(mean_absolute_error(t, p)),
                "r2": float(r2_score(t, p)),
            }

    for task in ["pdi", "ee", "toxic"]:
        if all_preds[task]:
            p = np.array(all_preds[task])
            t = np.array(all_targets[task])
            overall[task] = {
                "n_samples": len(p),
                "accuracy": float(accuracy_score(t, p)),
                "precision": float(precision_score(t, p, average="macro", zero_division=0)),
                "recall": float(recall_score(t, p, average="macro", zero_division=0)),
                "f1": float(f1_score(t, p, average="macro", zero_division=0)),
            }

    # Distribution task (pooled).
    if all_preds["biodist"]:
        p = np.array(all_preds["biodist"])
        t = np.array(all_targets["biodist"])
        overall["biodist"] = {
            "n_samples": len(p),
            "kl_divergence": kl_divergence(t, p),
            "js_divergence": js_divergence(t, p),
        }

    # Log the aggregated results.
    logger.info("\n" + "=" * 60)
    logger.info("CV TEST EVALUATION RESULTS")
    logger.info("=" * 60)

    logger.info(f"\n[Summary Statistics (across {len(fold_results)} folds)]")
    for task, stats in summary_stats.items():
        if "rmse_mean" in stats:
            logger.info(f"  {task}: RMSE={stats['rmse_mean']:.4f}±{stats['rmse_std']:.4f}, R²={stats['r2_mean']:.4f}±{stats['r2_std']:.4f}")
        elif "accuracy_mean" in stats:
            logger.info(f"  {task}: Accuracy={stats['accuracy_mean']:.4f}±{stats['accuracy_std']:.4f}, F1={stats['f1_mean']:.4f}±{stats['f1_std']:.4f}")
        elif "kl_mean" in stats:
            logger.info(f"  {task}: KL={stats['kl_mean']:.4f}±{stats['kl_std']:.4f}, JS={stats['js_mean']:.4f}±{stats['js_std']:.4f}")

    logger.info(f"\n[Overall (all samples pooled)]")
    for task, metrics in overall.items():
        if "rmse" in metrics:
            logger.info(f"  {task} (n={metrics['n_samples']}): RMSE={metrics['rmse']:.4f}, MAE={metrics['mae']:.4f}, R²={metrics['r2']:.4f}")
        elif "accuracy" in metrics:
            logger.info(f"  {task} (n={metrics['n_samples']}): Accuracy={metrics['accuracy']:.4f}, Precision={metrics['precision']:.4f}, Recall={metrics['recall']:.4f}, F1={metrics['f1']:.4f}")
        elif "kl_divergence" in metrics:
            logger.info(f"  {task} (n={metrics['n_samples']}): KL={metrics['kl_divergence']:.4f}, JS={metrics['js_divergence']:.4f}")

    # Persist the aggregated results as JSON.
    results = {
        "fold_results": fold_results,
        "summary_stats": summary_stats,
        "overall": overall,
    }

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(results, f, indent=2)
    logger.success(f"\nSaved test results to {output_path}")
|
||||
|
||||
|
||||
# CLI entry point: dispatch to the Typer app when run as a script.
if __name__ == "__main__":
    app()
|
||||
|
||||
16
models/finetune_cv/config.json
Normal file
16
models/finetune_cv/config.json
Normal file
@ -0,0 +1,16 @@
|
||||
{
|
||||
"d_model": 256,
|
||||
"num_heads": 8,
|
||||
"n_attn_layers": 4,
|
||||
"fusion_strategy": "attention",
|
||||
"head_hidden_dim": 128,
|
||||
"dropout": 0.1,
|
||||
"use_mpnn": true,
|
||||
"lr": 0.0001,
|
||||
"weight_decay": 1e-05,
|
||||
"batch_size": 32,
|
||||
"epochs": 100,
|
||||
"patience": 15,
|
||||
"init_from_pretrain": null,
|
||||
"freeze_backbone": false
|
||||
}
|
||||
54
models/finetune_cv/cv_results.json
Normal file
54
models/finetune_cv/cv_results.json
Normal file
@ -0,0 +1,54 @@
|
||||
{
|
||||
"fold_results": [
|
||||
{
|
||||
"fold_idx": 0,
|
||||
"best_val_loss": 5.7676777839660645,
|
||||
"epochs_trained": 24,
|
||||
"final_train_loss": 1.4942118644714355
|
||||
},
|
||||
{
|
||||
"fold_idx": 1,
|
||||
"best_val_loss": 8.418675899505615,
|
||||
"epochs_trained": 20,
|
||||
"final_train_loss": 1.4902493238449097
|
||||
},
|
||||
{
|
||||
"fold_idx": 2,
|
||||
"best_val_loss": 3.5122547830854143,
|
||||
"epochs_trained": 25,
|
||||
"final_train_loss": 1.7609570423762004
|
||||
},
|
||||
{
|
||||
"fold_idx": 3,
|
||||
"best_val_loss": 3.165306806564331,
|
||||
"epochs_trained": 21,
|
||||
"final_train_loss": 2.0073827385902403
|
||||
},
|
||||
{
|
||||
"fold_idx": 4,
|
||||
"best_val_loss": 2.996154228846232,
|
||||
"epochs_trained": 18,
|
||||
"final_train_loss": 1.9732873006300493
|
||||
}
|
||||
],
|
||||
"summary": {
|
||||
"val_loss_mean": 4.772013900393532,
|
||||
"val_loss_std": 2.0790222989111475
|
||||
},
|
||||
"config": {
|
||||
"d_model": 256,
|
||||
"num_heads": 8,
|
||||
"n_attn_layers": 4,
|
||||
"fusion_strategy": "attention",
|
||||
"head_hidden_dim": 128,
|
||||
"dropout": 0.1,
|
||||
"use_mpnn": true,
|
||||
"lr": 0.0001,
|
||||
"weight_decay": 1e-05,
|
||||
"batch_size": 32,
|
||||
"epochs": 100,
|
||||
"patience": 15,
|
||||
"init_from_pretrain": null,
|
||||
"freeze_backbone": false
|
||||
}
|
||||
}
|
||||
510
models/finetune_cv/fold_0/history.json
Normal file
510
models/finetune_cv/fold_0/history.json
Normal file
@ -0,0 +1,510 @@
|
||||
{
|
||||
"train": [
|
||||
{
|
||||
"loss": 17.65872812271118,
|
||||
"loss_size": 12.601411867141724,
|
||||
"loss_pdi": 1.3666706204414367,
|
||||
"loss_ee": 1.0830313920974732,
|
||||
"loss_delivery": 0.5962779104709626,
|
||||
"loss_biodist": 1.3918164849281311,
|
||||
"loss_toxic": 0.6195200622081757
|
||||
},
|
||||
{
|
||||
"loss": 5.925264883041382,
|
||||
"loss_size": 1.8580878481268883,
|
||||
"loss_pdi": 1.1011681258678436,
|
||||
"loss_ee": 0.971046245098114,
|
||||
"loss_delivery": 0.5075224950909615,
|
||||
"loss_biodist": 1.1051940202713013,
|
||||
"loss_toxic": 0.38224617540836336
|
||||
},
|
||||
{
|
||||
"loss": 3.4781792640686033,
|
||||
"loss_size": 0.23610344529151917,
|
||||
"loss_pdi": 0.8137399554252625,
|
||||
"loss_ee": 0.9135127127170563,
|
||||
"loss_delivery": 0.4596045270562172,
|
||||
"loss_biodist": 0.8695587992668152,
|
||||
"loss_toxic": 0.18565986081957817
|
||||
},
|
||||
{
|
||||
"loss": 2.9488561868667604,
|
||||
"loss_size": 0.23130029290914536,
|
||||
"loss_pdi": 0.644479614496231,
|
||||
"loss_ee": 0.8721524059772492,
|
||||
"loss_delivery": 0.4146773874759674,
|
||||
"loss_biodist": 0.646893310546875,
|
||||
"loss_toxic": 0.13935319259762763
|
||||
},
|
||||
{
|
||||
"loss": 2.6432241678237913,
|
||||
"loss_size": 0.16843259893357754,
|
||||
"loss_pdi": 0.5857123643159866,
|
||||
"loss_ee": 0.8315786123275757,
|
||||
"loss_delivery": 0.4049036353826523,
|
||||
"loss_biodist": 0.5410242855548859,
|
||||
"loss_toxic": 0.11157271154224872
|
||||
},
|
||||
{
|
||||
"loss": 2.461507487297058,
|
||||
"loss_size": 0.18602822050452233,
|
||||
"loss_pdi": 0.5872043997049332,
|
||||
"loss_ee": 0.8179578661918641,
|
||||
"loss_delivery": 0.32779163047671317,
|
||||
"loss_biodist": 0.45097417533397677,
|
||||
"loss_toxic": 0.09155115596950054
|
||||
},
|
||||
{
|
||||
"loss": 2.3792370796203612,
|
||||
"loss_size": 0.2090120367705822,
|
||||
"loss_pdi": 0.5358257800340652,
|
||||
"loss_ee": 0.8088949501514435,
|
||||
"loss_delivery": 0.3434994474053383,
|
||||
"loss_biodist": 0.40993946194648745,
|
||||
"loss_toxic": 0.07206540685147048
|
||||
},
|
||||
{
|
||||
"loss": 2.207099366188049,
|
||||
"loss_size": 0.1589151345193386,
|
||||
"loss_pdi": 0.5283154606819153,
|
||||
"loss_ee": 0.7723551869392395,
|
||||
"loss_delivery": 0.35645291954278946,
|
||||
"loss_biodist": 0.3404483631253242,
|
||||
"loss_toxic": 0.05061229532584548
|
||||
},
|
||||
{
|
||||
"loss": 2.1428971529006957,
|
||||
"loss_size": 0.19335013553500174,
|
||||
"loss_pdi": 0.5021985083818435,
|
||||
"loss_ee": 0.7642539083957672,
|
||||
"loss_delivery": 0.31821031123399734,
|
||||
"loss_biodist": 0.32588216066360476,
|
||||
"loss_toxic": 0.03900211993604898
|
||||
},
|
||||
{
|
||||
"loss": 1.9874909400939942,
|
||||
"loss_size": 0.1736245721578598,
|
||||
"loss_pdi": 0.46206980347633364,
|
||||
"loss_ee": 0.7373365700244904,
|
||||
"loss_delivery": 0.29703493416309357,
|
||||
"loss_biodist": 0.2863417714834213,
|
||||
"loss_toxic": 0.031083252932876348
|
||||
},
|
||||
{
|
||||
"loss": 1.9297520160675048,
|
||||
"loss_size": 0.1635374441742897,
|
||||
"loss_pdi": 0.4737923800945282,
|
||||
"loss_ee": 0.7171129584312439,
|
||||
"loss_delivery": 0.28808903992176055,
|
||||
"loss_biodist": 0.25874830335378646,
|
||||
"loss_toxic": 0.028471904620528222
|
||||
},
|
||||
{
|
||||
"loss": 1.8647576332092286,
|
||||
"loss_size": 0.14790172204375268,
|
||||
"loss_pdi": 0.4427785277366638,
|
||||
"loss_ee": 0.7089932143688202,
|
||||
"loss_delivery": 0.30143058970570563,
|
||||
"loss_biodist": 0.24234647750854493,
|
||||
"loss_toxic": 0.021307120053097605
|
||||
},
|
||||
{
|
||||
"loss": 1.7996623039245605,
|
||||
"loss_size": 0.1429538145661354,
|
||||
"loss_pdi": 0.45114057660102846,
|
||||
"loss_ee": 0.681770408153534,
|
||||
"loss_delivery": 0.2735618159174919,
|
||||
"loss_biodist": 0.2338838443160057,
|
||||
"loss_toxic": 0.01635184111073613
|
||||
},
|
||||
{
|
||||
"loss": 1.7303769707679748,
|
||||
"loss_size": 0.13725369721651076,
|
||||
"loss_pdi": 0.43492600619792937,
|
||||
"loss_ee": 0.6648448914289474,
|
||||
"loss_delivery": 0.2714417055249214,
|
||||
"loss_biodist": 0.20898159295320512,
|
||||
"loss_toxic": 0.012929048202931882
|
||||
},
|
||||
{
|
||||
"loss": 1.702065145969391,
|
||||
"loss_size": 0.1783118523657322,
|
||||
"loss_pdi": 0.4118753671646118,
|
||||
"loss_ee": 0.640222480893135,
|
||||
"loss_delivery": 0.2610591858625412,
|
||||
"loss_biodist": 0.20058825612068176,
|
||||
"loss_toxic": 0.01000797227025032
|
||||
},
|
||||
{
|
||||
"loss": 1.6243244886398316,
|
||||
"loss_size": 0.1371393844485283,
|
||||
"loss_pdi": 0.3978125751018524,
|
||||
"loss_ee": 0.6315451622009277,
|
||||
"loss_delivery": 0.2618463449180126,
|
||||
"loss_biodist": 0.18574777096509934,
|
||||
"loss_toxic": 0.010233237966895103
|
||||
},
|
||||
{
|
||||
"loss": 1.645119547843933,
|
||||
"loss_size": 0.13622624576091766,
|
||||
"loss_pdi": 0.4013118803501129,
|
||||
"loss_ee": 0.639850401878357,
|
||||
"loss_delivery": 0.2615354858338833,
|
||||
"loss_biodist": 0.19717498123645782,
|
||||
"loss_toxic": 0.009020529384724797
|
||||
},
|
||||
{
|
||||
"loss": 1.5792422771453858,
|
||||
"loss_size": 0.12063037976622581,
|
||||
"loss_pdi": 0.40477685928344725,
|
||||
"loss_ee": 0.6168571084737777,
|
||||
"loss_delivery": 0.23877703920006751,
|
||||
"loss_biodist": 0.1887524366378784,
|
||||
"loss_toxic": 0.009448455832898616
|
||||
},
|
||||
{
|
||||
"loss": 1.5701380014419555,
|
||||
"loss_size": 0.12370488420128822,
|
||||
"loss_pdi": 0.3944096490740776,
|
||||
"loss_ee": 0.6204680263996124,
|
||||
"loss_delivery": 0.2499392546713352,
|
||||
"loss_biodist": 0.1741167649626732,
|
||||
"loss_toxic": 0.00749938020016998
|
||||
},
|
||||
{
|
||||
"loss": 1.5445807576179504,
|
||||
"loss_size": 0.12085893377661705,
|
||||
"loss_pdi": 0.4022176057100296,
|
||||
"loss_ee": 0.6029386401176453,
|
||||
"loss_delivery": 0.2460342638194561,
|
||||
"loss_biodist": 0.16601160615682603,
|
||||
"loss_toxic": 0.006519717467017472
|
||||
},
|
||||
{
|
||||
"loss": 1.4764926195144654,
|
||||
"loss_size": 0.11393929794430732,
|
||||
"loss_pdi": 0.3614879995584488,
|
||||
"loss_ee": 0.5874974340200424,
|
||||
"loss_delivery": 0.2382828861474991,
|
||||
"loss_biodist": 0.168075630068779,
|
||||
"loss_toxic": 0.00720936032012105
|
||||
},
|
||||
{
|
||||
"loss": 1.4663256525993347,
|
||||
"loss_size": 0.10480817258358002,
|
||||
"loss_pdi": 0.3699364930391312,
|
||||
"loss_ee": 0.591068571805954,
|
||||
"loss_delivery": 0.23481545299291612,
|
||||
"loss_biodist": 0.1582734301686287,
|
||||
"loss_toxic": 0.007423530006781221
|
||||
},
|
||||
{
|
||||
"loss": 1.4797919273376465,
|
||||
"loss_size": 0.11906521767377853,
|
||||
"loss_pdi": 0.3831163257360458,
|
||||
"loss_ee": 0.5810098886489868,
|
||||
"loss_delivery": 0.22465722858905793,
|
||||
"loss_biodist": 0.16469249799847602,
|
||||
"loss_toxic": 0.007250743336044252
|
||||
},
|
||||
{
|
||||
"loss": 1.4942118644714355,
|
||||
"loss_size": 0.11249525547027588,
|
||||
"loss_pdi": 0.3718418627977371,
|
||||
"loss_ee": 0.5973137259483338,
|
||||
"loss_delivery": 0.23963096588850022,
|
||||
"loss_biodist": 0.16598810032010078,
|
||||
"loss_toxic": 0.006941930414177478
|
||||
}
|
||||
],
|
||||
"val": [
|
||||
{
|
||||
"loss": 13.683866500854492,
|
||||
"loss_size": 5.657964706420898,
|
||||
"loss_pdi": 1.1590962409973145,
|
||||
"loss_ee": 1.0155898332595825,
|
||||
"loss_delivery": 4.1429033279418945,
|
||||
"loss_biodist": 1.128843069076538,
|
||||
"loss_toxic": 0.579468846321106,
|
||||
"acc_pdi": 0.7407407407407407,
|
||||
"acc_ee": 0.6296296296296297,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 7.161172866821289,
|
||||
"loss_size": 0.1799931526184082,
|
||||
"loss_pdi": 0.8303115963935852,
|
||||
"loss_ee": 0.942605197429657,
|
||||
"loss_delivery": 3.986294984817505,
|
||||
"loss_biodist": 1.022797703742981,
|
||||
"loss_toxic": 0.19917015731334686,
|
||||
"acc_pdi": 0.8888888888888888,
|
||||
"acc_ee": 0.5925925925925926,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 6.554836273193359,
|
||||
"loss_size": 0.04609166830778122,
|
||||
"loss_pdi": 0.4924769997596741,
|
||||
"loss_ee": 0.965587317943573,
|
||||
"loss_delivery": 3.978637933731079,
|
||||
"loss_biodist": 1.0135102272033691,
|
||||
"loss_toxic": 0.05853228643536568,
|
||||
"acc_pdi": 0.8888888888888888,
|
||||
"acc_ee": 0.5925925925925926,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 6.843129634857178,
|
||||
"loss_size": 0.07650057226419449,
|
||||
"loss_pdi": 0.43551138043403625,
|
||||
"loss_ee": 0.9353340864181519,
|
||||
"loss_delivery": 4.557775974273682,
|
||||
"loss_biodist": 0.7909315228462219,
|
||||
"loss_toxic": 0.047075945883989334,
|
||||
"acc_pdi": 0.8888888888888888,
|
||||
"acc_ee": 0.5925925925925926,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 6.711758613586426,
|
||||
"loss_size": 0.04316325858235359,
|
||||
"loss_pdi": 0.41873815655708313,
|
||||
"loss_ee": 1.0096691846847534,
|
||||
"loss_delivery": 4.517927169799805,
|
||||
"loss_biodist": 0.6788683533668518,
|
||||
"loss_toxic": 0.04339226707816124,
|
||||
"acc_pdi": 0.8888888888888888,
|
||||
"acc_ee": 0.5925925925925926,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 6.905030250549316,
|
||||
"loss_size": 0.045318666845560074,
|
||||
"loss_pdi": 0.38593801856040955,
|
||||
"loss_ee": 1.0019593238830566,
|
||||
"loss_delivery": 4.807835578918457,
|
||||
"loss_biodist": 0.6247215867042542,
|
||||
"loss_toxic": 0.039257097989320755,
|
||||
"acc_pdi": 0.8888888888888888,
|
||||
"acc_ee": 0.5185185185185185,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 6.417820930480957,
|
||||
"loss_size": 0.05034356936812401,
|
||||
"loss_pdi": 0.4149726331233978,
|
||||
"loss_ee": 0.9869357943534851,
|
||||
"loss_delivery": 4.405001640319824,
|
||||
"loss_biodist": 0.533240556716919,
|
||||
"loss_toxic": 0.02732720412313938,
|
||||
"acc_pdi": 0.8888888888888888,
|
||||
"acc_ee": 0.5925925925925926,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 6.608631610870361,
|
||||
"loss_size": 0.05222579464316368,
|
||||
"loss_pdi": 0.4375711679458618,
|
||||
"loss_ee": 1.0041171312332153,
|
||||
"loss_delivery": 4.578192234039307,
|
||||
"loss_biodist": 0.5125234723091125,
|
||||
"loss_toxic": 0.02400212176144123,
|
||||
"acc_pdi": 0.8888888888888888,
|
||||
"acc_ee": 0.5185185185185185,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 5.7676777839660645,
|
||||
"loss_size": 0.09589201211929321,
|
||||
"loss_pdi": 0.3261733949184418,
|
||||
"loss_ee": 0.9482788443565369,
|
||||
"loss_delivery": 3.856112003326416,
|
||||
"loss_biodist": 0.5298716425895691,
|
||||
"loss_toxic": 0.011350298300385475,
|
||||
"acc_pdi": 0.8888888888888888,
|
||||
"acc_ee": 0.5555555555555556,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 6.920990943908691,
|
||||
"loss_size": 0.05388057231903076,
|
||||
"loss_pdi": 0.39705148339271545,
|
||||
"loss_ee": 0.990842878818512,
|
||||
"loss_delivery": 5.025243282318115,
|
||||
"loss_biodist": 0.4346938133239746,
|
||||
"loss_toxic": 0.019278930500149727,
|
||||
"acc_pdi": 0.8888888888888888,
|
||||
"acc_ee": 0.5555555555555556,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 5.798760890960693,
|
||||
"loss_size": 0.09857960045337677,
|
||||
"loss_pdi": 0.33329641819000244,
|
||||
"loss_ee": 0.9614524245262146,
|
||||
"loss_delivery": 4.000489711761475,
|
||||
"loss_biodist": 0.39874210953712463,
|
||||
"loss_toxic": 0.006200834643095732,
|
||||
"acc_pdi": 0.8888888888888888,
|
||||
"acc_ee": 0.5555555555555556,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 6.575327396392822,
|
||||
"loss_size": 0.054699063301086426,
|
||||
"loss_pdi": 0.33702051639556885,
|
||||
"loss_ee": 0.9436452388763428,
|
||||
"loss_delivery": 4.817119121551514,
|
||||
"loss_biodist": 0.41582298278808594,
|
||||
"loss_toxic": 0.007020703982561827,
|
||||
"acc_pdi": 0.8888888888888888,
|
||||
"acc_ee": 0.5555555555555556,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 5.989306449890137,
|
||||
"loss_size": 0.09009546041488647,
|
||||
"loss_pdi": 0.3044246733188629,
|
||||
"loss_ee": 1.0130207538604736,
|
||||
"loss_delivery": 4.140576362609863,
|
||||
"loss_biodist": 0.4378862977027893,
|
||||
"loss_toxic": 0.0033026484306901693,
|
||||
"acc_pdi": 0.8888888888888888,
|
||||
"acc_ee": 0.5555555555555556,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 6.383339881896973,
|
||||
"loss_size": 0.1530081033706665,
|
||||
"loss_pdi": 0.29700207710266113,
|
||||
"loss_ee": 0.9943283796310425,
|
||||
"loss_delivery": 4.564785480499268,
|
||||
"loss_biodist": 0.37085360288619995,
|
||||
"loss_toxic": 0.003362649120390415,
|
||||
"acc_pdi": 0.8888888888888888,
|
||||
"acc_ee": 0.5185185185185185,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 6.233416557312012,
|
||||
"loss_size": 0.1473817676305771,
|
||||
"loss_pdi": 0.2754640281200409,
|
||||
"loss_ee": 0.9803684949874878,
|
||||
"loss_delivery": 4.443488597869873,
|
||||
"loss_biodist": 0.38424524664878845,
|
||||
"loss_toxic": 0.0024682653602212667,
|
||||
"acc_pdi": 0.8888888888888888,
|
||||
"acc_ee": 0.5925925925925926,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 6.094257354736328,
|
||||
"loss_size": 0.10127364844083786,
|
||||
"loss_pdi": 0.2960923910140991,
|
||||
"loss_ee": 1.0121080875396729,
|
||||
"loss_delivery": 4.2689008712768555,
|
||||
"loss_biodist": 0.4132467210292816,
|
||||
"loss_toxic": 0.0026356647722423077,
|
||||
"acc_pdi": 0.8888888888888888,
|
||||
"acc_ee": 0.5555555555555556,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 6.0315470695495605,
|
||||
"loss_size": 0.13236114382743835,
|
||||
"loss_pdi": 0.29554903507232666,
|
||||
"loss_ee": 0.9912998080253601,
|
||||
"loss_delivery": 4.2240777015686035,
|
||||
"loss_biodist": 0.3861069977283478,
|
||||
"loss_toxic": 0.002152827335521579,
|
||||
"acc_pdi": 0.8888888888888888,
|
||||
"acc_ee": 0.5925925925925926,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 6.13291597366333,
|
||||
"loss_size": 0.10603927820920944,
|
||||
"loss_pdi": 0.30880627036094666,
|
||||
"loss_ee": 1.0417256355285645,
|
||||
"loss_delivery": 4.337818622589111,
|
||||
"loss_biodist": 0.33598482608795166,
|
||||
"loss_toxic": 0.002541647758334875,
|
||||
"acc_pdi": 0.8888888888888888,
|
||||
"acc_ee": 0.5555555555555556,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 5.918347358703613,
|
||||
"loss_size": 0.11423231661319733,
|
||||
"loss_pdi": 0.2779754102230072,
|
||||
"loss_ee": 1.023812174797058,
|
||||
"loss_delivery": 4.137387275695801,
|
||||
"loss_biodist": 0.36283549666404724,
|
||||
"loss_toxic": 0.002104171784594655,
|
||||
"acc_pdi": 0.8888888888888888,
|
||||
"acc_ee": 0.5185185185185185,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 6.354115962982178,
|
||||
"loss_size": 0.1212025135755539,
|
||||
"loss_pdi": 0.2848753333091736,
|
||||
"loss_ee": 1.031553030014038,
|
||||
"loss_delivery": 4.554471969604492,
|
||||
"loss_biodist": 0.3598195016384125,
|
||||
"loss_toxic": 0.0021939175203442574,
|
||||
"acc_pdi": 0.8518518518518519,
|
||||
"acc_ee": 0.5555555555555556,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 5.881389141082764,
|
||||
"loss_size": 0.1030399352312088,
|
||||
"loss_pdi": 0.2791188657283783,
|
||||
"loss_ee": 1.0205037593841553,
|
||||
"loss_delivery": 4.111578464508057,
|
||||
"loss_biodist": 0.3657107949256897,
|
||||
"loss_toxic": 0.0014369667042046785,
|
||||
"acc_pdi": 0.8518518518518519,
|
||||
"acc_ee": 0.5555555555555556,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 6.1028852462768555,
|
||||
"loss_size": 0.10241233557462692,
|
||||
"loss_pdi": 0.300007700920105,
|
||||
"loss_ee": 1.0756882429122925,
|
||||
"loss_delivery": 4.258440971374512,
|
||||
"loss_biodist": 0.3646480441093445,
|
||||
"loss_toxic": 0.0016881312476471066,
|
||||
"acc_pdi": 0.8888888888888888,
|
||||
"acc_ee": 0.5185185185185185,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 6.128824234008789,
|
||||
"loss_size": 0.1437627077102661,
|
||||
"loss_pdi": 0.29325851798057556,
|
||||
"loss_ee": 1.0818182229995728,
|
||||
"loss_delivery": 4.236568450927734,
|
||||
"loss_biodist": 0.3719424605369568,
|
||||
"loss_toxic": 0.001474093529395759,
|
||||
"acc_pdi": 0.8888888888888888,
|
||||
"acc_ee": 0.5185185185185185,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 6.055476188659668,
|
||||
"loss_size": 0.13312266767024994,
|
||||
"loss_pdi": 0.28571468591690063,
|
||||
"loss_ee": 1.066524624824524,
|
||||
"loss_delivery": 4.214193820953369,
|
||||
"loss_biodist": 0.35442692041397095,
|
||||
"loss_toxic": 0.0014939504908397794,
|
||||
"acc_pdi": 0.8518518518518519,
|
||||
"acc_ee": 0.5185185185185185,
|
||||
"acc_toxic": 1.0
|
||||
}
|
||||
]
|
||||
}
|
||||
BIN
models/finetune_cv/fold_0/model.pt
Normal file
BIN
models/finetune_cv/fold_0/model.pt
Normal file
Binary file not shown.
426
models/finetune_cv/fold_1/history.json
Normal file
426
models/finetune_cv/fold_1/history.json
Normal file
@ -0,0 +1,426 @@
|
||||
{
|
||||
"train": [
|
||||
{
|
||||
"loss": 21.963233947753906,
|
||||
"loss_size": 16.82633171081543,
|
||||
"loss_pdi": 1.2230936765670777,
|
||||
"loss_ee": 1.0703922033309936,
|
||||
"loss_delivery": 1.0690569162368775,
|
||||
"loss_biodist": 1.1534382343292235,
|
||||
"loss_toxic": 0.6209211587905884
|
||||
},
|
||||
{
|
||||
"loss": 13.145495796203614,
|
||||
"loss_size": 8.676862716674805,
|
||||
"loss_pdi": 1.0655134558677672,
|
||||
"loss_ee": 0.8999906063079834,
|
||||
"loss_delivery": 0.8303895950317383,
|
||||
"loss_biodist": 1.122723388671875,
|
||||
"loss_toxic": 0.5500160694122315
|
||||
},
|
||||
{
|
||||
"loss": 7.351448345184326,
|
||||
"loss_size": 3.415665292739868,
|
||||
"loss_pdi": 0.8565655469894409,
|
||||
"loss_ee": 0.7837236523628235,
|
||||
"loss_delivery": 0.8804788589477539,
|
||||
"loss_biodist": 1.011645209789276,
|
||||
"loss_toxic": 0.40336963534355164
|
||||
},
|
||||
{
|
||||
"loss": 4.39948205947876,
|
||||
"loss_size": 0.9713698267936707,
|
||||
"loss_pdi": 0.6989291191101075,
|
||||
"loss_ee": 0.6805540442466735,
|
||||
"loss_delivery": 0.7624839186668396,
|
||||
"loss_biodist": 0.9798830866813659,
|
||||
"loss_toxic": 0.3062621414661407
|
||||
},
|
||||
{
|
||||
"loss": 3.375754451751709,
|
||||
"loss_size": 0.24608666747808455,
|
||||
"loss_pdi": 0.5557448148727417,
|
||||
"loss_ee": 0.6684133768081665,
|
||||
"loss_delivery": 0.7611681580543518,
|
||||
"loss_biodist": 0.919653308391571,
|
||||
"loss_toxic": 0.22468801140785216
|
||||
},
|
||||
{
|
||||
"loss": 2.9307605743408205,
|
||||
"loss_size": 0.1106911577284336,
|
||||
"loss_pdi": 0.5004462003707886,
|
||||
"loss_ee": 0.6227471172809601,
|
||||
"loss_delivery": 0.6758030593395233,
|
||||
"loss_biodist": 0.8190896153450012,
|
||||
"loss_toxic": 0.20198351740837098
|
||||
},
|
||||
{
|
||||
"loss": 2.731675052642822,
|
||||
"loss_size": 0.13740637749433518,
|
||||
"loss_pdi": 0.4836215674877167,
|
||||
"loss_ee": 0.5896897256374359,
|
||||
"loss_delivery": 0.5866121172904968,
|
||||
"loss_biodist": 0.7556124567985535,
|
||||
"loss_toxic": 0.17873288169503213
|
||||
},
|
||||
{
|
||||
"loss": 2.4887039184570314,
|
||||
"loss_size": 0.12009606957435608,
|
||||
"loss_pdi": 0.4361336886882782,
|
||||
"loss_ee": 0.597134268283844,
|
||||
"loss_delivery": 0.5648026138544082,
|
||||
"loss_biodist": 0.6326960444450378,
|
||||
"loss_toxic": 0.13784122765064238
|
||||
},
|
||||
{
|
||||
"loss": 2.1680586099624635,
|
||||
"loss_size": 0.12401954531669616,
|
||||
"loss_pdi": 0.40216060280799865,
|
||||
"loss_ee": 0.5528951227664948,
|
||||
"loss_delivery": 0.42899617552757263,
|
||||
"loss_biodist": 0.5442585527896882,
|
||||
"loss_toxic": 0.1157285787165165
|
||||
},
|
||||
{
|
||||
"loss": 2.1059993267059327,
|
||||
"loss_size": 0.13299092650413513,
|
||||
"loss_pdi": 0.38143277168273926,
|
||||
"loss_ee": 0.5274551689624787,
|
||||
"loss_delivery": 0.47739412933588027,
|
||||
"loss_biodist": 0.4953398108482361,
|
||||
"loss_toxic": 0.0913865402340889
|
||||
},
|
||||
{
|
||||
"loss": 1.9570286750793457,
|
||||
"loss_size": 0.1426382303237915,
|
||||
"loss_pdi": 0.38325140476226804,
|
||||
"loss_ee": 0.49524895548820497,
|
||||
"loss_delivery": 0.42715947031974794,
|
||||
"loss_biodist": 0.4287752747535706,
|
||||
"loss_toxic": 0.07995530962944031
|
||||
},
|
||||
{
|
||||
"loss": 1.8469573497772216,
|
||||
"loss_size": 0.14165955781936646,
|
||||
"loss_pdi": 0.36685559153556824,
|
||||
"loss_ee": 0.4988661766052246,
|
||||
"loss_delivery": 0.36661114990711213,
|
||||
"loss_biodist": 0.39747334718704225,
|
||||
"loss_toxic": 0.07549156174063683
|
||||
},
|
||||
{
|
||||
"loss": 1.6980855226516725,
|
||||
"loss_size": 0.11332993358373641,
|
||||
"loss_pdi": 0.350938493013382,
|
||||
"loss_ee": 0.47553136944770813,
|
||||
"loss_delivery": 0.30049399137496946,
|
||||
"loss_biodist": 0.3953311860561371,
|
||||
"loss_toxic": 0.062460555136203764
|
||||
},
|
||||
{
|
||||
"loss": 1.743706512451172,
|
||||
"loss_size": 0.12467859983444214,
|
||||
"loss_pdi": 0.3706244468688965,
|
||||
"loss_ee": 0.4802402436733246,
|
||||
"loss_delivery": 0.36484516113996507,
|
||||
"loss_biodist": 0.3557030588388443,
|
||||
"loss_toxic": 0.04761496149003506
|
||||
},
|
||||
{
|
||||
"loss": 1.7470735549926757,
|
||||
"loss_size": 0.10215002745389938,
|
||||
"loss_pdi": 0.3553147315979004,
|
||||
"loss_ee": 0.4548905730247498,
|
||||
"loss_delivery": 0.4480485826730728,
|
||||
"loss_biodist": 0.3265932142734528,
|
||||
"loss_toxic": 0.06007647253572941
|
||||
},
|
||||
{
|
||||
"loss": 1.7687433004379272,
|
||||
"loss_size": 0.10528398901224137,
|
||||
"loss_pdi": 0.35497177839279176,
|
||||
"loss_ee": 0.4946293234825134,
|
||||
"loss_delivery": 0.44853600263595583,
|
||||
"loss_biodist": 0.3113987982273102,
|
||||
"loss_toxic": 0.053923492506146434
|
||||
},
|
||||
{
|
||||
"loss": 1.573294997215271,
|
||||
"loss_size": 0.11145550012588501,
|
||||
"loss_pdi": 0.33941014409065245,
|
||||
"loss_ee": 0.42823529839515684,
|
||||
"loss_delivery": 0.34292849004268644,
|
||||
"loss_biodist": 0.3095307767391205,
|
||||
"loss_toxic": 0.04173475466668606
|
||||
},
|
||||
{
|
||||
"loss": 1.482050108909607,
|
||||
"loss_size": 0.13211917281150817,
|
||||
"loss_pdi": 0.31831381320953367,
|
||||
"loss_ee": 0.4258797198534012,
|
||||
"loss_delivery": 0.26612227857112886,
|
||||
"loss_biodist": 0.30344046354293824,
|
||||
"loss_toxic": 0.03617466017603874
|
||||
},
|
||||
{
|
||||
"loss": 1.5079625368118286,
|
||||
"loss_size": 0.1129397764801979,
|
||||
"loss_pdi": 0.3118207275867462,
|
||||
"loss_ee": 0.4255594819784164,
|
||||
"loss_delivery": 0.30544502288103104,
|
||||
"loss_biodist": 0.31328115463256834,
|
||||
"loss_toxic": 0.03891638442873955
|
||||
},
|
||||
{
|
||||
"loss": 1.4902493238449097,
|
||||
"loss_size": 0.09879767149686813,
|
||||
"loss_pdi": 0.333440762758255,
|
||||
"loss_ee": 0.430321592092514,
|
||||
"loss_delivery": 0.3070627197623253,
|
||||
"loss_biodist": 0.28984564244747163,
|
||||
"loss_toxic": 0.030780918896198273
|
||||
}
|
||||
],
|
||||
"val": [
|
||||
{
|
||||
"loss": 24.328961690266926,
|
||||
"loss_size": 15.672358830769857,
|
||||
"loss_pdi": 1.268057902654012,
|
||||
"loss_ee": 1.0569811463356018,
|
||||
"loss_delivery": 4.617272272706032,
|
||||
"loss_biodist": 1.0806464751561482,
|
||||
"loss_toxic": 0.633646289507548,
|
||||
"acc_pdi": 0.6105263157894737,
|
||||
"acc_ee": 0.5894736842105263,
|
||||
"acc_toxic": 0.8939393939393939
|
||||
},
|
||||
{
|
||||
"loss": 17.03301429748535,
|
||||
"loss_size": 8.649629751841227,
|
||||
"loss_pdi": 1.165820797284444,
|
||||
"loss_ee": 0.9437925020853678,
|
||||
"loss_delivery": 4.629274984200795,
|
||||
"loss_biodist": 1.0683060089747112,
|
||||
"loss_toxic": 0.5761909882227579,
|
||||
"acc_pdi": 0.6105263157894737,
|
||||
"acc_ee": 0.6736842105263158,
|
||||
"acc_toxic": 0.8939393939393939
|
||||
},
|
||||
{
|
||||
"loss": 11.635572751363119,
|
||||
"loss_size": 3.504341204961141,
|
||||
"loss_pdi": 1.0855141083399455,
|
||||
"loss_ee": 0.8674407601356506,
|
||||
"loss_delivery": 4.705501407384872,
|
||||
"loss_biodist": 1.0376905004183452,
|
||||
"loss_toxic": 0.43508487939834595,
|
||||
"acc_pdi": 0.6105263157894737,
|
||||
"acc_ee": 0.6736842105263158,
|
||||
"acc_toxic": 0.8939393939393939
|
||||
},
|
||||
{
|
||||
"loss": 9.058362166086832,
|
||||
"loss_size": 1.0461570421854656,
|
||||
"loss_pdi": 1.070031762123108,
|
||||
"loss_ee": 0.8463932275772095,
|
||||
"loss_delivery": 4.791346887747447,
|
||||
"loss_biodist": 0.9781110286712646,
|
||||
"loss_toxic": 0.32632239659627277,
|
||||
"acc_pdi": 0.6105263157894737,
|
||||
"acc_ee": 0.6736842105263158,
|
||||
"acc_toxic": 0.8939393939393939
|
||||
},
|
||||
{
|
||||
"loss": 8.418675899505615,
|
||||
"loss_size": 0.3764288102587064,
|
||||
"loss_pdi": 1.0916812817255657,
|
||||
"loss_ee": 0.8714254101117452,
|
||||
"loss_delivery": 4.8696667949358625,
|
||||
"loss_biodist": 0.9307892719904581,
|
||||
"loss_toxic": 0.278684730331103,
|
||||
"acc_pdi": 0.6105263157894737,
|
||||
"acc_ee": 0.6736842105263158,
|
||||
"acc_toxic": 0.8939393939393939
|
||||
},
|
||||
{
|
||||
"loss": 8.51748021443685,
|
||||
"loss_size": 0.33909208327531815,
|
||||
"loss_pdi": 1.103804111480713,
|
||||
"loss_ee": 0.8707688599824905,
|
||||
"loss_delivery": 5.0624091029167175,
|
||||
"loss_biodist": 0.8743396997451782,
|
||||
"loss_toxic": 0.2670666699608167,
|
||||
"acc_pdi": 0.6105263157894737,
|
||||
"acc_ee": 0.6736842105263158,
|
||||
"acc_toxic": 0.8939393939393939
|
||||
},
|
||||
{
|
||||
"loss": 8.701509237289429,
|
||||
"loss_size": 0.38883806640903157,
|
||||
"loss_pdi": 1.0901564558347066,
|
||||
"loss_ee": 0.8219001442193985,
|
||||
"loss_delivery": 5.329233412941297,
|
||||
"loss_biodist": 0.808117667833964,
|
||||
"loss_toxic": 0.2632630293567975,
|
||||
"acc_pdi": 0.6105263157894737,
|
||||
"acc_ee": 0.6736842105263158,
|
||||
"acc_toxic": 0.8939393939393939
|
||||
},
|
||||
{
|
||||
"loss": 8.602253516515097,
|
||||
"loss_size": 0.399209912866354,
|
||||
"loss_pdi": 1.035650501648585,
|
||||
"loss_ee": 0.8119546920061111,
|
||||
"loss_delivery": 5.297288862367471,
|
||||
"loss_biodist": 0.8003136416276296,
|
||||
"loss_toxic": 0.2578362462421258,
|
||||
"acc_pdi": 0.6105263157894737,
|
||||
"acc_ee": 0.6736842105263158,
|
||||
"acc_toxic": 0.8939393939393939
|
||||
},
|
||||
{
|
||||
"loss": 8.610430796941122,
|
||||
"loss_size": 0.3884888291358948,
|
||||
"loss_pdi": 0.9680223266283671,
|
||||
"loss_ee": 0.8063104202349981,
|
||||
"loss_delivery": 5.504999443888664,
|
||||
"loss_biodist": 0.7153328458468119,
|
||||
"loss_toxic": 0.22727691816786924,
|
||||
"acc_pdi": 0.6105263157894737,
|
||||
"acc_ee": 0.6736842105263158,
|
||||
"acc_toxic": 0.8939393939393939
|
||||
},
|
||||
{
|
||||
"loss": 8.894750118255615,
|
||||
"loss_size": 0.38015256201227504,
|
||||
"loss_pdi": 0.9849910040696462,
|
||||
"loss_ee": 0.8192636320988337,
|
||||
"loss_delivery": 5.8433875640233355,
|
||||
"loss_biodist": 0.6525928874810537,
|
||||
"loss_toxic": 0.21436312049627304,
|
||||
"acc_pdi": 0.6105263157894737,
|
||||
"acc_ee": 0.6736842105263158,
|
||||
"acc_toxic": 0.8939393939393939
|
||||
},
|
||||
{
|
||||
"loss": 8.684672435124716,
|
||||
"loss_size": 0.39142270882924396,
|
||||
"loss_pdi": 0.9926454623540243,
|
||||
"loss_ee": 0.8487897912661234,
|
||||
"loss_delivery": 5.675399616360664,
|
||||
"loss_biodist": 0.5763055086135864,
|
||||
"loss_toxic": 0.2001086367915074,
|
||||
"acc_pdi": 0.6105263157894737,
|
||||
"acc_ee": 0.6736842105263158,
|
||||
"acc_toxic": 0.8939393939393939
|
||||
},
|
||||
{
|
||||
"loss": 8.468807538350424,
|
||||
"loss_size": 0.37035099665323895,
|
||||
"loss_pdi": 0.9933059811592102,
|
||||
"loss_ee": 0.8365495651960373,
|
||||
"loss_delivery": 5.39086152613163,
|
||||
"loss_biodist": 0.6555034021536509,
|
||||
"loss_toxic": 0.22223659542699656,
|
||||
"acc_pdi": 0.6105263157894737,
|
||||
"acc_ee": 0.6736842105263158,
|
||||
"acc_toxic": 0.8939393939393939
|
||||
},
|
||||
{
|
||||
"loss": 8.48760732014974,
|
||||
"loss_size": 0.3547621878484885,
|
||||
"loss_pdi": 1.008083571990331,
|
||||
"loss_ee": 0.8507340376575788,
|
||||
"loss_delivery": 5.329072058200836,
|
||||
"loss_biodist": 0.7051869928836823,
|
||||
"loss_toxic": 0.23976873668531576,
|
||||
"acc_pdi": 0.6105263157894737,
|
||||
"acc_ee": 0.6736842105263158,
|
||||
"acc_toxic": 0.8939393939393939
|
||||
},
|
||||
{
|
||||
"loss": 8.534255782763163,
|
||||
"loss_size": 0.35214799270033836,
|
||||
"loss_pdi": 1.0083338419596355,
|
||||
"loss_ee": 0.8703259030977885,
|
||||
"loss_delivery": 5.4809657235940294,
|
||||
"loss_biodist": 0.6066243648529053,
|
||||
"loss_toxic": 0.21585797673712173,
|
||||
"acc_pdi": 0.6105263157894737,
|
||||
"acc_ee": 0.6842105263157895,
|
||||
"acc_toxic": 0.8939393939393939
|
||||
},
|
||||
{
|
||||
"loss": 8.59092911084493,
|
||||
"loss_size": 0.3520332872867584,
|
||||
"loss_pdi": 0.9944024880727133,
|
||||
"loss_ee": 0.8839219162861506,
|
||||
"loss_delivery": 5.593439628680547,
|
||||
"loss_biodist": 0.562449519832929,
|
||||
"loss_toxic": 0.20468231476843357,
|
||||
"acc_pdi": 0.6105263157894737,
|
||||
"acc_ee": 0.6736842105263158,
|
||||
"acc_toxic": 0.8939393939393939
|
||||
},
|
||||
{
|
||||
"loss": 8.581690510114035,
|
||||
"loss_size": 0.3468632685641448,
|
||||
"loss_pdi": 1.0153752664724986,
|
||||
"loss_ee": 0.884696863591671,
|
||||
"loss_delivery": 5.548932209610939,
|
||||
"loss_biodist": 0.576594889163971,
|
||||
"loss_toxic": 0.20922777770708004,
|
||||
"acc_pdi": 0.6105263157894737,
|
||||
"acc_ee": 0.6736842105263158,
|
||||
"acc_toxic": 0.8939393939393939
|
||||
},
|
||||
{
|
||||
"loss": 8.60028068224589,
|
||||
"loss_size": 0.34553587809205055,
|
||||
"loss_pdi": 1.0314316948254902,
|
||||
"loss_ee": 0.8696443388859431,
|
||||
"loss_delivery": 5.513105024894078,
|
||||
"loss_biodist": 0.6187789390484492,
|
||||
"loss_toxic": 0.2217849005634586,
|
||||
"acc_pdi": 0.6105263157894737,
|
||||
"acc_ee": 0.6947368421052632,
|
||||
"acc_toxic": 0.8939393939393939
|
||||
},
|
||||
{
|
||||
"loss": 8.721842130025228,
|
||||
"loss_size": 0.3432792164385319,
|
||||
"loss_pdi": 1.044082870086034,
|
||||
"loss_ee": 0.8888355021675428,
|
||||
"loss_delivery": 5.590167284011841,
|
||||
"loss_biodist": 0.6300752113262812,
|
||||
"loss_toxic": 0.22540184513976178,
|
||||
"acc_pdi": 0.6105263157894737,
|
||||
"acc_ee": 0.6736842105263158,
|
||||
"acc_toxic": 0.8939393939393939
|
||||
},
|
||||
{
|
||||
"loss": 8.821967244148254,
|
||||
"loss_size": 0.3423520748813947,
|
||||
"loss_pdi": 1.0627215206623077,
|
||||
"loss_ee": 0.9012102037668228,
|
||||
"loss_delivery": 5.6443866689999895,
|
||||
"loss_biodist": 0.6428664823373159,
|
||||
"loss_toxic": 0.2284308553983768,
|
||||
"acc_pdi": 0.6105263157894737,
|
||||
"acc_ee": 0.6631578947368421,
|
||||
"acc_toxic": 0.8939393939393939
|
||||
},
|
||||
{
|
||||
"loss": 8.798149506251017,
|
||||
"loss_size": 0.3439513569076856,
|
||||
"loss_pdi": 1.074403668443362,
|
||||
"loss_ee": 0.9037297517061234,
|
||||
"loss_delivery": 5.62445667386055,
|
||||
"loss_biodist": 0.6279164751370748,
|
||||
"loss_toxic": 0.22369086369872093,
|
||||
"acc_pdi": 0.6105263157894737,
|
||||
"acc_ee": 0.6631578947368421,
|
||||
"acc_toxic": 0.8939393939393939
|
||||
}
|
||||
]
|
||||
}
|
||||
BIN
models/finetune_cv/fold_1/model.pt
Normal file
BIN
models/finetune_cv/fold_1/model.pt
Normal file
Binary file not shown.
531
models/finetune_cv/fold_2/history.json
Normal file
531
models/finetune_cv/fold_2/history.json
Normal file
@ -0,0 +1,531 @@
|
||||
{
|
||||
"train": [
|
||||
{
|
||||
"loss": 22.972569465637207,
|
||||
"loss_size": 17.22947899500529,
|
||||
"loss_pdi": 1.3510672052701314,
|
||||
"loss_ee": 1.0506827433904011,
|
||||
"loss_delivery": 1.5721548050642014,
|
||||
"loss_biodist": 1.1304322481155396,
|
||||
"loss_toxic": 0.6387530366579691
|
||||
},
|
||||
{
|
||||
"loss": 12.718681335449219,
|
||||
"loss_size": 7.335077285766602,
|
||||
"loss_pdi": 1.198062241077423,
|
||||
"loss_ee": 0.97108127673467,
|
||||
"loss_delivery": 1.639556477467219,
|
||||
"loss_biodist": 1.0746847093105316,
|
||||
"loss_toxic": 0.50021959344546
|
||||
},
|
||||
{
|
||||
"loss": 6.867454210917155,
|
||||
"loss_size": 2.3153140544891357,
|
||||
"loss_pdi": 1.0159071584542592,
|
||||
"loss_ee": 0.853508859872818,
|
||||
"loss_delivery": 1.2862873176733653,
|
||||
"loss_biodist": 1.0416639745235443,
|
||||
"loss_toxic": 0.35477257271607715
|
||||
},
|
||||
{
|
||||
"loss": 4.856432318687439,
|
||||
"loss_size": 0.5409951706727346,
|
||||
"loss_pdi": 0.8652523259321848,
|
||||
"loss_ee": 0.7771940131982168,
|
||||
"loss_delivery": 1.413562481602033,
|
||||
"loss_biodist": 0.9977987806002299,
|
||||
"loss_toxic": 0.26162934054931003
|
||||
},
|
||||
{
|
||||
"loss": 4.253215591112773,
|
||||
"loss_size": 0.2641367167234421,
|
||||
"loss_pdi": 0.739859402179718,
|
||||
"loss_ee": 0.7256686190764109,
|
||||
"loss_delivery": 1.4241955528656642,
|
||||
"loss_biodist": 0.8935903211434683,
|
||||
"loss_toxic": 0.2057649294535319
|
||||
},
|
||||
{
|
||||
"loss": 3.8961705764134726,
|
||||
"loss_size": 0.2962125514944394,
|
||||
"loss_pdi": 0.682400623957316,
|
||||
"loss_ee": 0.6820215880870819,
|
||||
"loss_delivery": 1.2787245536843936,
|
||||
"loss_biodist": 0.7812575101852417,
|
||||
"loss_toxic": 0.1755537080268065
|
||||
},
|
||||
{
|
||||
"loss": 3.4790991942087808,
|
||||
"loss_size": 0.3047281603018443,
|
||||
"loss_pdi": 0.6409291823705038,
|
||||
"loss_ee": 0.6178905169169108,
|
||||
"loss_delivery": 1.1121559316913288,
|
||||
"loss_biodist": 0.6434484819571177,
|
||||
"loss_toxic": 0.15994682783881822
|
||||
},
|
||||
{
|
||||
"loss": 3.2075613339742026,
|
||||
"loss_size": 0.3421506683031718,
|
||||
"loss_pdi": 0.5879766543706259,
|
||||
"loss_ee": 0.5811398377021154,
|
||||
"loss_delivery": 1.0462109719713528,
|
||||
"loss_biodist": 0.520307645201683,
|
||||
"loss_toxic": 0.12977550799647966
|
||||
},
|
||||
{
|
||||
"loss": 2.861353278160095,
|
||||
"loss_size": 0.2742840300003688,
|
||||
"loss_pdi": 0.5437282969554266,
|
||||
"loss_ee": 0.5531725088755289,
|
||||
"loss_delivery": 0.9213679246604443,
|
||||
"loss_biodist": 0.4499489863713582,
|
||||
"loss_toxic": 0.11885150956610839
|
||||
},
|
||||
{
|
||||
"loss": 2.6909215847651162,
|
||||
"loss_size": 0.23881135260065398,
|
||||
"loss_pdi": 0.5229279547929764,
|
||||
"loss_ee": 0.5285524874925613,
|
||||
"loss_delivery": 0.8911051253477732,
|
||||
"loss_biodist": 0.4015616128842036,
|
||||
"loss_toxic": 0.10796305599311988
|
||||
},
|
||||
{
|
||||
"loss": 2.5927247206370034,
|
||||
"loss_size": 0.27356760079662007,
|
||||
"loss_pdi": 0.5166990955670675,
|
||||
"loss_ee": 0.5059170673290888,
|
||||
"loss_delivery": 0.8377179056406021,
|
||||
"loss_biodist": 0.3519642899433772,
|
||||
"loss_toxic": 0.10685871541500092
|
||||
},
|
||||
{
|
||||
"loss": 2.3971973856290183,
|
||||
"loss_size": 0.2688147674004237,
|
||||
"loss_pdi": 0.4851151605447133,
|
||||
"loss_ee": 0.47870688637097675,
|
||||
"loss_delivery": 0.7584750155607859,
|
||||
"loss_biodist": 0.3166690344611804,
|
||||
"loss_toxic": 0.08941652067005634
|
||||
},
|
||||
{
|
||||
"loss": 2.2271180947621665,
|
||||
"loss_size": 0.2559296215573947,
|
||||
"loss_pdi": 0.467803418636322,
|
||||
"loss_ee": 0.4819647620121638,
|
||||
"loss_delivery": 0.6487737223505974,
|
||||
"loss_biodist": 0.2930952211221059,
|
||||
"loss_toxic": 0.079551310899357
|
||||
},
|
||||
{
|
||||
"loss": 2.1467134952545166,
|
||||
"loss_size": 0.2658323546250661,
|
||||
"loss_pdi": 0.47287177046140033,
|
||||
"loss_ee": 0.4580538024504979,
|
||||
"loss_delivery": 0.6110207016269366,
|
||||
"loss_biodist": 0.26590356479088467,
|
||||
"loss_toxic": 0.07303123424450557
|
||||
},
|
||||
{
|
||||
"loss": 2.0699684421221414,
|
||||
"loss_size": 0.23655260602633157,
|
||||
"loss_pdi": 0.46446068088213605,
|
||||
"loss_ee": 0.43884341915448505,
|
||||
"loss_delivery": 0.5945644030968348,
|
||||
"loss_biodist": 0.26856863250335056,
|
||||
"loss_toxic": 0.06697871504972379
|
||||
},
|
||||
{
|
||||
"loss": 2.012367367744446,
|
||||
"loss_size": 0.20358355715870857,
|
||||
"loss_pdi": 0.44864421089490253,
|
||||
"loss_ee": 0.4260970900456111,
|
||||
"loss_delivery": 0.6111055202782154,
|
||||
"loss_biodist": 0.24829111248254776,
|
||||
"loss_toxic": 0.07464585608492295
|
||||
},
|
||||
{
|
||||
"loss": 1.9354575673739116,
|
||||
"loss_size": 0.19155597686767578,
|
||||
"loss_pdi": 0.43001438677310944,
|
||||
"loss_ee": 0.4029633104801178,
|
||||
"loss_delivery": 0.5866967861851057,
|
||||
"loss_biodist": 0.26284457246462506,
|
||||
"loss_toxic": 0.06138256782044967
|
||||
},
|
||||
{
|
||||
"loss": 1.9248821139335632,
|
||||
"loss_size": 0.19836385796467462,
|
||||
"loss_pdi": 0.43165912727514905,
|
||||
"loss_ee": 0.4223821411530177,
|
||||
"loss_delivery": 0.5774712382505337,
|
||||
"loss_biodist": 0.23008103668689728,
|
||||
"loss_toxic": 0.06492467441906531
|
||||
},
|
||||
{
|
||||
"loss": 1.7986130317052205,
|
||||
"loss_size": 0.1977602814634641,
|
||||
"loss_pdi": 0.4213625093301137,
|
||||
"loss_ee": 0.3969506522019704,
|
||||
"loss_delivery": 0.4972396679222584,
|
||||
"loss_biodist": 0.22815552850564322,
|
||||
"loss_toxic": 0.05714430411656698
|
||||
},
|
||||
{
|
||||
"loss": 1.8008437156677246,
|
||||
"loss_size": 0.20143492271502814,
|
||||
"loss_pdi": 0.4257240394751231,
|
||||
"loss_ee": 0.3939937750498454,
|
||||
"loss_delivery": 0.4996156108876069,
|
||||
"loss_biodist": 0.22945881386597952,
|
||||
"loss_toxic": 0.05061656702309847
|
||||
},
|
||||
{
|
||||
"loss": 1.8123606244723003,
|
||||
"loss_size": 0.23175274084011713,
|
||||
"loss_pdi": 0.41065867245197296,
|
||||
"loss_ee": 0.38645289838314056,
|
||||
"loss_delivery": 0.5105274474869171,
|
||||
"loss_biodist": 0.22759289046128592,
|
||||
"loss_toxic": 0.04537593169758717
|
||||
},
|
||||
{
|
||||
"loss": 1.85766206185023,
|
||||
"loss_size": 0.2237167110045751,
|
||||
"loss_pdi": 0.4198872745037079,
|
||||
"loss_ee": 0.39036936064561206,
|
||||
"loss_delivery": 0.5356200908621153,
|
||||
"loss_biodist": 0.2327907457947731,
|
||||
"loss_toxic": 0.055277835965777435
|
||||
},
|
||||
{
|
||||
"loss": 1.7299150824546814,
|
||||
"loss_size": 0.16866947089632353,
|
||||
"loss_pdi": 0.40182044357061386,
|
||||
"loss_ee": 0.37123599648475647,
|
||||
"loss_delivery": 0.5183743331581354,
|
||||
"loss_biodist": 0.22459317495425543,
|
||||
"loss_toxic": 0.04522162117063999
|
||||
},
|
||||
{
|
||||
"loss": 1.8115381598472595,
|
||||
"loss_size": 0.21021889025966325,
|
||||
"loss_pdi": 0.3938516428073247,
|
||||
"loss_ee": 0.3856282929579417,
|
||||
"loss_delivery": 0.5463737193495035,
|
||||
"loss_biodist": 0.22467835744222006,
|
||||
"loss_toxic": 0.05078731415172418
|
||||
},
|
||||
{
|
||||
"loss": 1.7609570423762004,
|
||||
"loss_size": 0.20672637100021043,
|
||||
"loss_pdi": 0.3868243644634883,
|
||||
"loss_ee": 0.3775654385487239,
|
||||
"loss_delivery": 0.5160359914104143,
|
||||
"loss_biodist": 0.22730938345193863,
|
||||
"loss_toxic": 0.04649555139864484
|
||||
}
|
||||
],
|
||||
"val": [
|
||||
{
|
||||
"loss": 19.896042142595565,
|
||||
"loss_size": 14.947636876787458,
|
||||
"loss_pdi": 1.3514722074781145,
|
||||
"loss_ee": 1.0372784308024816,
|
||||
"loss_delivery": 0.5157596128327506,
|
||||
"loss_biodist": 1.3665738276072912,
|
||||
"loss_toxic": 0.6773212381771633,
|
||||
"acc_pdi": 0.22564102564102564,
|
||||
"acc_ee": 0.4512820512820513,
|
||||
"acc_toxic": 0.7073170731707317
|
||||
},
|
||||
{
|
||||
"loss": 10.277108192443848,
|
||||
"loss_size": 5.728530270712716,
|
||||
"loss_pdi": 1.2047701733452933,
|
||||
"loss_ee": 1.013599353177207,
|
||||
"loss_delivery": 0.5329383058207375,
|
||||
"loss_biodist": 1.3288453817367554,
|
||||
"loss_toxic": 0.4684244394302368,
|
||||
"acc_pdi": 0.358974358974359,
|
||||
"acc_ee": 0.4461538461538462,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 4.947442190987723,
|
||||
"loss_size": 0.9055690084184919,
|
||||
"loss_pdi": 0.9325166344642639,
|
||||
"loss_ee": 1.071527932371412,
|
||||
"loss_delivery": 0.5525430504764829,
|
||||
"loss_biodist": 1.2514750446592058,
|
||||
"loss_toxic": 0.23381062703473227,
|
||||
"acc_pdi": 0.7076923076923077,
|
||||
"acc_ee": 0.4205128205128205,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 4.072668756757464,
|
||||
"loss_size": 0.18621658267719404,
|
||||
"loss_pdi": 0.7640595691544669,
|
||||
"loss_ee": 1.2155327456338065,
|
||||
"loss_delivery": 0.5523517067943301,
|
||||
"loss_biodist": 1.214196733066014,
|
||||
"loss_toxic": 0.14031140506267548,
|
||||
"acc_pdi": 0.7076923076923077,
|
||||
"acc_ee": 0.4205128205128205,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 4.053883859089443,
|
||||
"loss_size": 0.24764748875583922,
|
||||
"loss_pdi": 0.6859285916600909,
|
||||
"loss_ee": 1.3019903557641166,
|
||||
"loss_delivery": 0.5384655041354043,
|
||||
"loss_biodist": 1.191110406603132,
|
||||
"loss_toxic": 0.08874151536396571,
|
||||
"acc_pdi": 0.7076923076923077,
|
||||
"acc_ee": 0.4205128205128205,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.9850601128169467,
|
||||
"loss_size": 0.18243093735405377,
|
||||
"loss_pdi": 0.6606386282614299,
|
||||
"loss_ee": 1.2955879313605172,
|
||||
"loss_delivery": 0.6222422846726009,
|
||||
"loss_biodist": 1.1586603011403764,
|
||||
"loss_toxic": 0.06550001353025436,
|
||||
"acc_pdi": 0.7076923076923077,
|
||||
"acc_ee": 0.4205128205128205,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.6691461631229947,
|
||||
"loss_size": 0.18682856378810747,
|
||||
"loss_pdi": 0.6505168399640492,
|
||||
"loss_ee": 1.2279302733285087,
|
||||
"loss_delivery": 0.5776888344969068,
|
||||
"loss_biodist": 0.963392470564161,
|
||||
"loss_toxic": 0.06278917672378677,
|
||||
"acc_pdi": 0.7076923076923077,
|
||||
"acc_ee": 0.4205128205128205,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.6420267650059293,
|
||||
"loss_size": 0.1952320017984935,
|
||||
"loss_pdi": 0.6510439600263324,
|
||||
"loss_ee": 1.1954282522201538,
|
||||
"loss_delivery": 0.7644538623946053,
|
||||
"loss_biodist": 0.7812093198299408,
|
||||
"loss_toxic": 0.05465935756053243,
|
||||
"acc_pdi": 0.7076923076923077,
|
||||
"acc_ee": 0.35384615384615387,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.707563672746931,
|
||||
"loss_size": 0.2168926394411496,
|
||||
"loss_pdi": 0.6468359615121569,
|
||||
"loss_ee": 1.2361225570951189,
|
||||
"loss_delivery": 0.8388645563806806,
|
||||
"loss_biodist": 0.7232611009052822,
|
||||
"loss_toxic": 0.04558686592749187,
|
||||
"acc_pdi": 0.7076923076923077,
|
||||
"acc_ee": 0.37435897435897436,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.5122547830854143,
|
||||
"loss_size": 0.2614529473440988,
|
||||
"loss_pdi": 0.6344352598701205,
|
||||
"loss_ee": 1.2337199449539185,
|
||||
"loss_delivery": 0.7557642417294639,
|
||||
"loss_biodist": 0.5974612619195666,
|
||||
"loss_toxic": 0.029421218537858555,
|
||||
"acc_pdi": 0.7076923076923077,
|
||||
"acc_ee": 0.38461538461538464,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 4.085699932915824,
|
||||
"loss_size": 0.2200212436062949,
|
||||
"loss_pdi": 0.6225023801837649,
|
||||
"loss_ee": 1.1934180770601546,
|
||||
"loss_delivery": 1.4556497505732946,
|
||||
"loss_biodist": 0.5685178296906608,
|
||||
"loss_toxic": 0.02559054403432778,
|
||||
"acc_pdi": 0.7076923076923077,
|
||||
"acc_ee": 0.4205128205128205,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.8555148329053606,
|
||||
"loss_size": 0.25200862703578814,
|
||||
"loss_pdi": 0.6161004496472222,
|
||||
"loss_ee": 1.2155306509562902,
|
||||
"loss_delivery": 1.1912458751882826,
|
||||
"loss_biodist": 0.5578728743961879,
|
||||
"loss_toxic": 0.022756420208939483,
|
||||
"acc_pdi": 0.7128205128205128,
|
||||
"acc_ee": 0.4153846153846154,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 4.016897848674229,
|
||||
"loss_size": 0.24823468178510666,
|
||||
"loss_pdi": 0.6098369508981705,
|
||||
"loss_ee": 1.2048260143824987,
|
||||
"loss_delivery": 1.3509162408964974,
|
||||
"loss_biodist": 0.5822246244975499,
|
||||
"loss_toxic": 0.020859350051198686,
|
||||
"acc_pdi": 0.717948717948718,
|
||||
"acc_ee": 0.4153846153846154,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.9899745328085765,
|
||||
"loss_size": 0.27859273659331457,
|
||||
"loss_pdi": 0.6051683936800275,
|
||||
"loss_ee": 1.1875721216201782,
|
||||
"loss_delivery": 1.369072552238192,
|
||||
"loss_biodist": 0.5321473862443652,
|
||||
"loss_toxic": 0.017421354805784568,
|
||||
"acc_pdi": 0.717948717948718,
|
||||
"acc_ee": 0.40512820512820513,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 4.112551552908761,
|
||||
"loss_size": 0.23502862347023828,
|
||||
"loss_pdi": 0.6127274334430695,
|
||||
"loss_ee": 1.2102909428732735,
|
||||
"loss_delivery": 1.5001615626471383,
|
||||
"loss_biodist": 0.5411617543016162,
|
||||
"loss_toxic": 0.013181165459432773,
|
||||
"acc_pdi": 0.7025641025641025,
|
||||
"acc_ee": 0.4205128205128205,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 4.031796966280256,
|
||||
"loss_size": 0.27996559121779035,
|
||||
"loss_pdi": 0.6218498295971325,
|
||||
"loss_ee": 1.264663577079773,
|
||||
"loss_delivery": 1.225053642477308,
|
||||
"loss_biodist": 0.6304309921605247,
|
||||
"loss_toxic": 0.009833223053387232,
|
||||
"acc_pdi": 0.7128205128205128,
|
||||
"acc_ee": 0.4153846153846154,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 4.108100175857544,
|
||||
"loss_size": 0.2646343271647181,
|
||||
"loss_pdi": 0.6244613996573857,
|
||||
"loss_ee": 1.2785721336092268,
|
||||
"loss_delivery": 1.2817653375012534,
|
||||
"loss_biodist": 0.6491539776325226,
|
||||
"loss_toxic": 0.009512946475297213,
|
||||
"acc_pdi": 0.7025641025641025,
|
||||
"acc_ee": 0.4,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 4.237571648188999,
|
||||
"loss_size": 0.24608339369297028,
|
||||
"loss_pdi": 0.6387959729347911,
|
||||
"loss_ee": 1.263997495174408,
|
||||
"loss_delivery": 1.3677963316440582,
|
||||
"loss_biodist": 0.7113959235804421,
|
||||
"loss_toxic": 0.00950257752888969,
|
||||
"acc_pdi": 0.7025641025641025,
|
||||
"acc_ee": 0.40512820512820513,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 4.677373443331037,
|
||||
"loss_size": 0.25485377439430784,
|
||||
"loss_pdi": 0.6646602579525539,
|
||||
"loss_ee": 1.2957249539239066,
|
||||
"loss_delivery": 1.609152581010546,
|
||||
"loss_biodist": 0.8443670613425118,
|
||||
"loss_toxic": 0.008614770535911833,
|
||||
"acc_pdi": 0.7025641025641025,
|
||||
"acc_ee": 0.4,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 4.661478791918073,
|
||||
"loss_size": 0.22112409876925604,
|
||||
"loss_pdi": 0.6840873499001775,
|
||||
"loss_ee": 1.345719371523176,
|
||||
"loss_delivery": 1.5750049437795366,
|
||||
"loss_biodist": 0.8278610365731376,
|
||||
"loss_toxic": 0.0076820029810603175,
|
||||
"acc_pdi": 0.6974358974358974,
|
||||
"acc_ee": 0.39487179487179486,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 4.7805344717843195,
|
||||
"loss_size": 0.23743291412081038,
|
||||
"loss_pdi": 0.6911627639617238,
|
||||
"loss_ee": 1.3796877009528024,
|
||||
"loss_delivery": 1.6191245743206568,
|
||||
"loss_biodist": 0.8459444258894239,
|
||||
"loss_toxic": 0.00718201809961881,
|
||||
"acc_pdi": 0.7076923076923077,
|
||||
"acc_ee": 0.4,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 5.050536905016218,
|
||||
"loss_size": 0.2340430138366563,
|
||||
"loss_pdi": 0.7009050281984466,
|
||||
"loss_ee": 1.3757387569972448,
|
||||
"loss_delivery": 1.760018629687173,
|
||||
"loss_biodist": 0.9723425933292934,
|
||||
"loss_toxic": 0.007489000846232686,
|
||||
"acc_pdi": 0.7128205128205128,
|
||||
"acc_ee": 0.39487179487179486,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 5.172625916344779,
|
||||
"loss_size": 0.23549717558281763,
|
||||
"loss_pdi": 0.6980976568801063,
|
||||
"loss_ee": 1.357451047216143,
|
||||
"loss_delivery": 1.893591480595725,
|
||||
"loss_biodist": 0.9802024279321943,
|
||||
"loss_toxic": 0.0077861944612647805,
|
||||
"acc_pdi": 0.717948717948718,
|
||||
"acc_ee": 0.38974358974358975,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 5.048826490129743,
|
||||
"loss_size": 0.2420537450483867,
|
||||
"loss_pdi": 0.7013733184763363,
|
||||
"loss_ee": 1.353548560823713,
|
||||
"loss_delivery": 1.7931698901312692,
|
||||
"loss_biodist": 0.9509889696325574,
|
||||
"loss_toxic": 0.007692053714501006,
|
||||
"acc_pdi": 0.7128205128205128,
|
||||
"acc_ee": 0.38974358974358975,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 4.951304980686733,
|
||||
"loss_size": 0.24394649054322923,
|
||||
"loss_pdi": 0.7197230202811105,
|
||||
"loss_ee": 1.3789095027106149,
|
||||
"loss_delivery": 1.6561105762209212,
|
||||
"loss_biodist": 0.9460095167160034,
|
||||
"loss_toxic": 0.006605847106714334,
|
||||
"acc_pdi": 0.7076923076923077,
|
||||
"acc_ee": 0.38974358974358975,
|
||||
"acc_toxic": 1.0
|
||||
}
|
||||
]
|
||||
}
|
||||
BIN
models/finetune_cv/fold_2/model.pt
Normal file
BIN
models/finetune_cv/fold_2/model.pt
Normal file
Binary file not shown.
447
models/finetune_cv/fold_3/history.json
Normal file
447
models/finetune_cv/fold_3/history.json
Normal file
@ -0,0 +1,447 @@
|
||||
{
|
||||
"train": [
|
||||
{
|
||||
"loss": 19.317236709594727,
|
||||
"loss_size": 14.108779287338256,
|
||||
"loss_pdi": 1.2223044037818909,
|
||||
"loss_ee": 1.1201724767684937,
|
||||
"loss_delivery": 1.0989456713199615,
|
||||
"loss_biodist": 1.2243955612182618,
|
||||
"loss_toxic": 0.5426396489143371
|
||||
},
|
||||
{
|
||||
"loss": 7.058736562728882,
|
||||
"loss_size": 2.407193088531494,
|
||||
"loss_pdi": 1.0396445631980895,
|
||||
"loss_ee": 1.043432891368866,
|
||||
"loss_delivery": 1.183875671029091,
|
||||
"loss_biodist": 1.0435532987117768,
|
||||
"loss_toxic": 0.3410369783639908
|
||||
},
|
||||
{
|
||||
"loss": 4.30960967540741,
|
||||
"loss_size": 0.3954987198114395,
|
||||
"loss_pdi": 0.8135693371295929,
|
||||
"loss_ee": 0.9796469569206238,
|
||||
"loss_delivery": 1.0894740536808967,
|
||||
"loss_biodist": 0.8407172799110413,
|
||||
"loss_toxic": 0.1907033085823059
|
||||
},
|
||||
{
|
||||
"loss": 3.814995551109314,
|
||||
"loss_size": 0.3033123552799225,
|
||||
"loss_pdi": 0.7126355469226837,
|
||||
"loss_ee": 0.9527663111686706,
|
||||
"loss_delivery": 1.0963637247681617,
|
||||
"loss_biodist": 0.613499540090561,
|
||||
"loss_toxic": 0.13641793020069598
|
||||
},
|
||||
{
|
||||
"loss": 3.4925455331802366,
|
||||
"loss_size": 0.31617238447070123,
|
||||
"loss_pdi": 0.6569374144077301,
|
||||
"loss_ee": 0.9212668299674988,
|
||||
"loss_delivery": 1.0250366538763047,
|
||||
"loss_biodist": 0.4527473896741867,
|
||||
"loss_toxic": 0.12038486637175083
|
||||
},
|
||||
{
|
||||
"loss": 3.255272912979126,
|
||||
"loss_size": 0.27022836059331895,
|
||||
"loss_pdi": 0.6289663434028625,
|
||||
"loss_ee": 0.9047561466693879,
|
||||
"loss_delivery": 0.9742588266730309,
|
||||
"loss_biodist": 0.3838836058974266,
|
||||
"loss_toxic": 0.09317965283989907
|
||||
},
|
||||
{
|
||||
"loss": 3.281973719596863,
|
||||
"loss_size": 0.26578598394989966,
|
||||
"loss_pdi": 0.5881433010101318,
|
||||
"loss_ee": 0.8700660288333892,
|
||||
"loss_delivery": 1.1519657298922539,
|
||||
"loss_biodist": 0.3240764126181602,
|
||||
"loss_toxic": 0.08193630240857601
|
||||
},
|
||||
{
|
||||
"loss": 2.7810576915740968,
|
||||
"loss_size": 0.2505718767642975,
|
||||
"loss_pdi": 0.5593925356864929,
|
||||
"loss_ee": 0.8196510195732116,
|
||||
"loss_delivery": 0.7932070523500443,
|
||||
"loss_biodist": 0.27653754949569703,
|
||||
"loss_toxic": 0.08169759791344404
|
||||
},
|
||||
{
|
||||
"loss": 2.644732141494751,
|
||||
"loss_size": 0.27979295402765275,
|
||||
"loss_pdi": 0.5457461476325989,
|
||||
"loss_ee": 0.7845215618610382,
|
||||
"loss_delivery": 0.722122372686863,
|
||||
"loss_biodist": 0.2437703028321266,
|
||||
"loss_toxic": 0.06877880096435547
|
||||
},
|
||||
{
|
||||
"loss": 2.5743841886520387,
|
||||
"loss_size": 0.21236803606152535,
|
||||
"loss_pdi": 0.5281321376562118,
|
||||
"loss_ee": 0.7772053182125092,
|
||||
"loss_delivery": 0.7842913195490837,
|
||||
"loss_biodist": 0.20931598618626596,
|
||||
"loss_toxic": 0.0630713876336813
|
||||
},
|
||||
{
|
||||
"loss": 2.493379771709442,
|
||||
"loss_size": 0.2545281477272511,
|
||||
"loss_pdi": 0.514763566851616,
|
||||
"loss_ee": 0.7416582465171814,
|
||||
"loss_delivery": 0.7315813854336739,
|
||||
"loss_biodist": 0.18844463676214218,
|
||||
"loss_toxic": 0.062403830140829085
|
||||
},
|
||||
{
|
||||
"loss": 2.3714203119277952,
|
||||
"loss_size": 0.21288565024733544,
|
||||
"loss_pdi": 0.5149440914392471,
|
||||
"loss_ee": 0.7432775914669036,
|
||||
"loss_delivery": 0.6615208894014358,
|
||||
"loss_biodist": 0.17799324095249175,
|
||||
"loss_toxic": 0.06079882858321071
|
||||
},
|
||||
{
|
||||
"loss": 2.3138927936553957,
|
||||
"loss_size": 0.22406778559088708,
|
||||
"loss_pdi": 0.5060430943965912,
|
||||
"loss_ee": 0.7270951688289642,
|
||||
"loss_delivery": 0.6268678307533264,
|
||||
"loss_biodist": 0.17946239709854125,
|
||||
"loss_toxic": 0.050356499617919326
|
||||
},
|
||||
{
|
||||
"loss": 2.2404407501220702,
|
||||
"loss_size": 0.23460092321038245,
|
||||
"loss_pdi": 0.4892877459526062,
|
||||
"loss_ee": 0.6908941030502319,
|
||||
"loss_delivery": 0.6124202072620392,
|
||||
"loss_biodist": 0.16842604279518128,
|
||||
"loss_toxic": 0.044811736792325974
|
||||
},
|
||||
{
|
||||
"loss": 2.2448294520378114,
|
||||
"loss_size": 0.21119624376296997,
|
||||
"loss_pdi": 0.479864901304245,
|
||||
"loss_ee": 0.6906192302703857,
|
||||
"loss_delivery": 0.6555144399404526,
|
||||
"loss_biodist": 0.16310803219676018,
|
||||
"loss_toxic": 0.044526621932163835
|
||||
},
|
||||
{
|
||||
"loss": 2.1580574989318846,
|
||||
"loss_size": 0.18697498068213464,
|
||||
"loss_pdi": 0.48660930395126345,
|
||||
"loss_ee": 0.6810935467481614,
|
||||
"loss_delivery": 0.6051739566028118,
|
||||
"loss_biodist": 0.15406969040632248,
|
||||
"loss_toxic": 0.04413598729297519
|
||||
},
|
||||
{
|
||||
"loss": 2.114891529083252,
|
||||
"loss_size": 0.17799586579203605,
|
||||
"loss_pdi": 0.4589719235897064,
|
||||
"loss_ee": 0.6686563313007354,
|
||||
"loss_delivery": 0.6179293170571327,
|
||||
"loss_biodist": 0.1526280902326107,
|
||||
"loss_toxic": 0.03870999766513705
|
||||
},
|
||||
{
|
||||
"loss": 2.1680126667022703,
|
||||
"loss_size": 0.18272313922643663,
|
||||
"loss_pdi": 0.47693236321210863,
|
||||
"loss_ee": 0.6723115026950837,
|
||||
"loss_delivery": 0.6574018053710461,
|
||||
"loss_biodist": 0.14306045994162558,
|
||||
"loss_toxic": 0.03558342705946416
|
||||
},
|
||||
{
|
||||
"loss": 2.0243090748786927,
|
||||
"loss_size": 0.19451010078191758,
|
||||
"loss_pdi": 0.46296934187412264,
|
||||
"loss_ee": 0.6654580652713775,
|
||||
"loss_delivery": 0.5195972554385662,
|
||||
"loss_biodist": 0.1416195034980774,
|
||||
"loss_toxic": 0.04015482016839087
|
||||
},
|
||||
{
|
||||
"loss": 1.980038857460022,
|
||||
"loss_size": 0.19992023780941964,
|
||||
"loss_pdi": 0.4373833805322647,
|
||||
"loss_ee": 0.6562270969152451,
|
||||
"loss_delivery": 0.5170416861772538,
|
||||
"loss_biodist": 0.13248837292194365,
|
||||
"loss_toxic": 0.03697808152064681
|
||||
},
|
||||
{
|
||||
"loss": 2.0073827385902403,
|
||||
"loss_size": 0.17545675858855247,
|
||||
"loss_pdi": 0.43559625148773196,
|
||||
"loss_ee": 0.6394164443016053,
|
||||
"loss_delivery": 0.5809337809681893,
|
||||
"loss_biodist": 0.13504885137081146,
|
||||
"loss_toxic": 0.040930699557065964
|
||||
}
|
||||
],
|
||||
"val": [
|
||||
{
|
||||
"loss": 10.945204257965088,
|
||||
"loss_size": 6.681218147277832,
|
||||
"loss_pdi": 1.0216107964515686,
|
||||
"loss_ee": 1.0486068725585938,
|
||||
"loss_delivery": 0.4687899202108383,
|
||||
"loss_biodist": 1.215298354625702,
|
||||
"loss_toxic": 0.5096809715032578,
|
||||
"acc_pdi": 0.8823529411764706,
|
||||
"acc_ee": 0.6470588235294118,
|
||||
"acc_toxic": 0.851063829787234
|
||||
},
|
||||
{
|
||||
"loss": 4.0456085205078125,
|
||||
"loss_size": 0.3470493406057358,
|
||||
"loss_pdi": 0.7843169867992401,
|
||||
"loss_ee": 0.8336820006370544,
|
||||
"loss_delivery": 0.42519159615039825,
|
||||
"loss_biodist": 1.1528617143630981,
|
||||
"loss_toxic": 0.5025068372488022,
|
||||
"acc_pdi": 0.8823529411764706,
|
||||
"acc_ee": 0.8431372549019608,
|
||||
"acc_toxic": 0.851063829787234
|
||||
},
|
||||
{
|
||||
"loss": 3.553251266479492,
|
||||
"loss_size": 0.07565776817500591,
|
||||
"loss_pdi": 0.5630811750888824,
|
||||
"loss_ee": 0.6947644650936127,
|
||||
"loss_delivery": 0.4020952582359314,
|
||||
"loss_biodist": 1.1898333430290222,
|
||||
"loss_toxic": 0.627819336950779,
|
||||
"acc_pdi": 0.8823529411764706,
|
||||
"acc_ee": 0.8431372549019608,
|
||||
"acc_toxic": 0.851063829787234
|
||||
},
|
||||
{
|
||||
"loss": 3.2876256704330444,
|
||||
"loss_size": 0.056399866938591,
|
||||
"loss_pdi": 0.579200953245163,
|
||||
"loss_ee": 0.5947848558425903,
|
||||
"loss_delivery": 0.4561047703027725,
|
||||
"loss_biodist": 1.0357274413108826,
|
||||
"loss_toxic": 0.5654077678918839,
|
||||
"acc_pdi": 0.8823529411764706,
|
||||
"acc_ee": 0.8431372549019608,
|
||||
"acc_toxic": 0.851063829787234
|
||||
},
|
||||
{
|
||||
"loss": 3.2625988721847534,
|
||||
"loss_size": 0.11296019703149796,
|
||||
"loss_pdi": 0.5352367609739304,
|
||||
"loss_ee": 0.6021667718887329,
|
||||
"loss_delivery": 0.5046610683202744,
|
||||
"loss_biodist": 1.0080225467681885,
|
||||
"loss_toxic": 0.4995514266192913,
|
||||
"acc_pdi": 0.8823529411764706,
|
||||
"acc_ee": 0.8431372549019608,
|
||||
"acc_toxic": 0.851063829787234
|
||||
},
|
||||
{
|
||||
"loss": 3.165306806564331,
|
||||
"loss_size": 0.09038551151752472,
|
||||
"loss_pdi": 0.5058617442846298,
|
||||
"loss_ee": 0.6476156711578369,
|
||||
"loss_delivery": 0.43027013540267944,
|
||||
"loss_biodist": 1.0445645153522491,
|
||||
"loss_toxic": 0.4466092698276043,
|
||||
"acc_pdi": 0.8823529411764706,
|
||||
"acc_ee": 0.8431372549019608,
|
||||
"acc_toxic": 0.851063829787234
|
||||
},
|
||||
{
|
||||
"loss": 3.2408690452575684,
|
||||
"loss_size": 0.22243183851242065,
|
||||
"loss_pdi": 0.4985402673482895,
|
||||
"loss_ee": 0.571824312210083,
|
||||
"loss_delivery": 0.43825456500053406,
|
||||
"loss_biodist": 0.9937507510185242,
|
||||
"loss_toxic": 0.5160673335194588,
|
||||
"acc_pdi": 0.8823529411764706,
|
||||
"acc_ee": 0.8431372549019608,
|
||||
"acc_toxic": 0.851063829787234
|
||||
},
|
||||
{
|
||||
"loss": 3.5194804668426514,
|
||||
"loss_size": 0.36968255043029785,
|
||||
"loss_pdi": 0.4991031885147095,
|
||||
"loss_ee": 0.5797468274831772,
|
||||
"loss_delivery": 0.5644859671592712,
|
||||
"loss_biodist": 0.9723091125488281,
|
||||
"loss_toxic": 0.5341527052223682,
|
||||
"acc_pdi": 0.8823529411764706,
|
||||
"acc_ee": 0.8431372549019608,
|
||||
"acc_toxic": 0.851063829787234
|
||||
},
|
||||
{
|
||||
"loss": 3.539685606956482,
|
||||
"loss_size": 0.38313066959381104,
|
||||
"loss_pdi": 0.528433233499527,
|
||||
"loss_ee": 0.5810057669878006,
|
||||
"loss_delivery": 0.44039086997509,
|
||||
"loss_biodist": 1.0017918348312378,
|
||||
"loss_toxic": 0.6049331650137901,
|
||||
"acc_pdi": 0.8823529411764706,
|
||||
"acc_ee": 0.8431372549019608,
|
||||
"acc_toxic": 0.851063829787234
|
||||
},
|
||||
{
|
||||
"loss": 3.9403724670410156,
|
||||
"loss_size": 0.6225972771644592,
|
||||
"loss_pdi": 0.5688649713993073,
|
||||
"loss_ee": 0.6205386221408844,
|
||||
"loss_delivery": 0.6095166206359863,
|
||||
"loss_biodist": 0.9419751763343811,
|
||||
"loss_toxic": 0.576879795640707,
|
||||
"acc_pdi": 0.803921568627451,
|
||||
"acc_ee": 0.8431372549019608,
|
||||
"acc_toxic": 0.851063829787234
|
||||
},
|
||||
{
|
||||
"loss": 3.8653494119644165,
|
||||
"loss_size": 0.6294703483581543,
|
||||
"loss_pdi": 0.5615053772926331,
|
||||
"loss_ee": 0.6072992980480194,
|
||||
"loss_delivery": 0.47824281454086304,
|
||||
"loss_biodist": 0.964938759803772,
|
||||
"loss_toxic": 0.6238927394151688,
|
||||
"acc_pdi": 0.8431372549019608,
|
||||
"acc_ee": 0.8431372549019608,
|
||||
"acc_toxic": 0.851063829787234
|
||||
},
|
||||
{
|
||||
"loss": 4.109289169311523,
|
||||
"loss_size": 0.753549188375473,
|
||||
"loss_pdi": 0.6232334971427917,
|
||||
"loss_ee": 0.6752453744411469,
|
||||
"loss_delivery": 0.4541686922311783,
|
||||
"loss_biodist": 1.0001222491264343,
|
||||
"loss_toxic": 0.6029700562357903,
|
||||
"acc_pdi": 0.7450980392156863,
|
||||
"acc_ee": 0.7254901960784313,
|
||||
"acc_toxic": 0.851063829787234
|
||||
},
|
||||
{
|
||||
"loss": 4.084217548370361,
|
||||
"loss_size": 0.6689053475856781,
|
||||
"loss_pdi": 0.5947604179382324,
|
||||
"loss_ee": 0.6819752305746078,
|
||||
"loss_delivery": 0.5174736380577087,
|
||||
"loss_biodist": 0.9870622158050537,
|
||||
"loss_toxic": 0.6340407878160477,
|
||||
"acc_pdi": 0.7843137254901961,
|
||||
"acc_ee": 0.7647058823529411,
|
||||
"acc_toxic": 0.851063829787234
|
||||
},
|
||||
{
|
||||
"loss": 3.762814164161682,
|
||||
"loss_size": 0.5682831406593323,
|
||||
"loss_pdi": 0.5777421444654465,
|
||||
"loss_ee": 0.7156199663877487,
|
||||
"loss_delivery": 0.44971026480197906,
|
||||
"loss_biodist": 0.9156049191951752,
|
||||
"loss_toxic": 0.5358536541461945,
|
||||
"acc_pdi": 0.7450980392156863,
|
||||
"acc_ee": 0.7058823529411765,
|
||||
"acc_toxic": 0.851063829787234
|
||||
},
|
||||
{
|
||||
"loss": 4.147223711013794,
|
||||
"loss_size": 0.6644828915596008,
|
||||
"loss_pdi": 0.5911359935998917,
|
||||
"loss_ee": 0.713784396648407,
|
||||
"loss_delivery": 0.500703439116478,
|
||||
"loss_biodist": 1.0310384333133698,
|
||||
"loss_toxic": 0.6460786163806915,
|
||||
"acc_pdi": 0.7647058823529411,
|
||||
"acc_ee": 0.7058823529411765,
|
||||
"acc_toxic": 0.851063829787234
|
||||
},
|
||||
{
|
||||
"loss": 4.157853841781616,
|
||||
"loss_size": 0.7414849102497101,
|
||||
"loss_pdi": 0.630668580532074,
|
||||
"loss_ee": 0.6938402056694031,
|
||||
"loss_delivery": 0.50765261054039,
|
||||
"loss_biodist": 0.9891600012779236,
|
||||
"loss_toxic": 0.5950475558638573,
|
||||
"acc_pdi": 0.7450980392156863,
|
||||
"acc_ee": 0.7450980392156863,
|
||||
"acc_toxic": 0.851063829787234
|
||||
},
|
||||
{
|
||||
"loss": 4.2473719120025635,
|
||||
"loss_size": 0.8058429956436157,
|
||||
"loss_pdi": 0.5982940196990967,
|
||||
"loss_ee": 0.7209844589233398,
|
||||
"loss_delivery": 0.5253763496875763,
|
||||
"loss_biodist": 0.973820835351944,
|
||||
"loss_toxic": 0.6230533868074417,
|
||||
"acc_pdi": 0.803921568627451,
|
||||
"acc_ee": 0.6862745098039216,
|
||||
"acc_toxic": 0.851063829787234
|
||||
},
|
||||
{
|
||||
"loss": 3.90485680103302,
|
||||
"loss_size": 0.5400884747505188,
|
||||
"loss_pdi": 0.5671974420547485,
|
||||
"loss_ee": 0.7085212767124176,
|
||||
"loss_delivery": 0.5078988373279572,
|
||||
"loss_biodist": 0.9940473735332489,
|
||||
"loss_toxic": 0.5871035009622574,
|
||||
"acc_pdi": 0.803921568627451,
|
||||
"acc_ee": 0.7450980392156863,
|
||||
"acc_toxic": 0.851063829787234
|
||||
},
|
||||
{
|
||||
"loss": 4.193094968795776,
|
||||
"loss_size": 0.764777421951294,
|
||||
"loss_pdi": 0.5734306275844574,
|
||||
"loss_ee": 0.7070393562316895,
|
||||
"loss_delivery": 0.5335722267627716,
|
||||
"loss_biodist": 1.0060182809829712,
|
||||
"loss_toxic": 0.6082571670413017,
|
||||
"acc_pdi": 0.8235294117647058,
|
||||
"acc_ee": 0.7450980392156863,
|
||||
"acc_toxic": 0.851063829787234
|
||||
},
|
||||
{
|
||||
"loss": 4.225732326507568,
|
||||
"loss_size": 0.7807798981666565,
|
||||
"loss_pdi": 0.57969930768013,
|
||||
"loss_ee": 0.7046914398670197,
|
||||
"loss_delivery": 0.5619150400161743,
|
||||
"loss_biodist": 1.0033797025680542,
|
||||
"loss_toxic": 0.5952669233083725,
|
||||
"acc_pdi": 0.7843137254901961,
|
||||
"acc_ee": 0.7450980392156863,
|
||||
"acc_toxic": 0.851063829787234
|
||||
},
|
||||
{
|
||||
"loss": 4.122166633605957,
|
||||
"loss_size": 0.727344423532486,
|
||||
"loss_pdi": 0.5855642706155777,
|
||||
"loss_ee": 0.7065156400203705,
|
||||
"loss_delivery": 0.5201583206653595,
|
||||
"loss_biodist": 0.9953437149524689,
|
||||
"loss_toxic": 0.5872401967644691,
|
||||
"acc_pdi": 0.803921568627451,
|
||||
"acc_ee": 0.7450980392156863,
|
||||
"acc_toxic": 0.851063829787234
|
||||
}
|
||||
]
|
||||
}
|
||||
BIN
models/finetune_cv/fold_3/model.pt
Normal file
BIN
models/finetune_cv/fold_3/model.pt
Normal file
Binary file not shown.
384
models/finetune_cv/fold_4/history.json
Normal file
384
models/finetune_cv/fold_4/history.json
Normal file
@ -0,0 +1,384 @@
|
||||
{
|
||||
"train": [
|
||||
{
|
||||
"loss": 19.51876787705855,
|
||||
"loss_size": 14.430041096427225,
|
||||
"loss_pdi": 1.3163191405209629,
|
||||
"loss_ee": 1.080450177192688,
|
||||
"loss_delivery": 0.8366337662393396,
|
||||
"loss_biodist": 1.2393495603041216,
|
||||
"loss_toxic": 0.6159739873626016
|
||||
},
|
||||
{
|
||||
"loss": 7.450059110468084,
|
||||
"loss_size": 3.0973785898902197,
|
||||
"loss_pdi": 1.0574429847977378,
|
||||
"loss_ee": 0.972686382857236,
|
||||
"loss_delivery": 0.8131051341241057,
|
||||
"loss_biodist": 1.0609121160073713,
|
||||
"loss_toxic": 0.4485339197245511
|
||||
},
|
||||
{
|
||||
"loss": 4.450224074450406,
|
||||
"loss_size": 0.4227097576314753,
|
||||
"loss_pdi": 0.8444447679953142,
|
||||
"loss_ee": 0.9470934163440358,
|
||||
"loss_delivery": 1.0848771997473456,
|
||||
"loss_biodist": 0.8853748061440208,
|
||||
"loss_toxic": 0.26572414352135226
|
||||
},
|
||||
{
|
||||
"loss": 3.697209119796753,
|
||||
"loss_size": 0.27599087357521057,
|
||||
"loss_pdi": 0.7228363969109275,
|
||||
"loss_ee": 0.9080322655764493,
|
||||
"loss_delivery": 0.8549216186458414,
|
||||
"loss_biodist": 0.7190906730565158,
|
||||
"loss_toxic": 0.21633729677308688
|
||||
},
|
||||
{
|
||||
"loss": 3.5547448938543145,
|
||||
"loss_size": 0.29429094425656577,
|
||||
"loss_pdi": 0.6620089682665738,
|
||||
"loss_ee": 0.8595339439131997,
|
||||
"loss_delivery": 0.9978644847869873,
|
||||
"loss_biodist": 0.5649206800894304,
|
||||
"loss_toxic": 0.17612605541944504
|
||||
},
|
||||
{
|
||||
"loss": 3.069189115004106,
|
||||
"loss_size": 0.30328830602494156,
|
||||
"loss_pdi": 0.6169473230838776,
|
||||
"loss_ee": 0.8132463910362937,
|
||||
"loss_delivery": 0.72551099143245,
|
||||
"loss_biodist": 0.46287189017642627,
|
||||
"loss_toxic": 0.14732415906407617
|
||||
},
|
||||
{
|
||||
"loss": 3.1349260156804863,
|
||||
"loss_size": 0.27535233172503387,
|
||||
"loss_pdi": 0.5893999202684923,
|
||||
"loss_ee": 0.807813747362657,
|
||||
"loss_delivery": 0.9538742283528502,
|
||||
"loss_biodist": 0.4018077091737227,
|
||||
"loss_toxic": 0.1066781035201116
|
||||
},
|
||||
{
|
||||
"loss": 2.6963415037501943,
|
||||
"loss_size": 0.276088840582154,
|
||||
"loss_pdi": 0.5490301495248621,
|
||||
"loss_ee": 0.757920276034962,
|
||||
"loss_delivery": 0.6706068068742752,
|
||||
"loss_biodist": 0.34504769336093555,
|
||||
"loss_toxic": 0.09764780781485817
|
||||
},
|
||||
{
|
||||
"loss": 2.418043158271096,
|
||||
"loss_size": 0.25799565559083765,
|
||||
"loss_pdi": 0.5286295684901151,
|
||||
"loss_ee": 0.73835120417855,
|
||||
"loss_delivery": 0.5078353543173183,
|
||||
"loss_biodist": 0.2939920425415039,
|
||||
"loss_toxic": 0.09123929352922873
|
||||
},
|
||||
{
|
||||
"loss": 2.294130650433627,
|
||||
"loss_size": 0.20914554325017062,
|
||||
"loss_pdi": 0.5159178945151243,
|
||||
"loss_ee": 0.7331724437800321,
|
||||
"loss_delivery": 0.50414734875614,
|
||||
"loss_biodist": 0.2559647980061444,
|
||||
"loss_toxic": 0.07578264227644964
|
||||
},
|
||||
{
|
||||
"loss": 2.260723189874129,
|
||||
"loss_size": 0.21194299920038742,
|
||||
"loss_pdi": 0.51129734787074,
|
||||
"loss_ee": 0.71018939668482,
|
||||
"loss_delivery": 0.506003974513574,
|
||||
"loss_biodist": 0.24507361785932022,
|
||||
"loss_toxic": 0.07621594924818385
|
||||
},
|
||||
{
|
||||
"loss": 2.180013732476668,
|
||||
"loss_size": 0.21782933243296362,
|
||||
"loss_pdi": 0.5094848789952018,
|
||||
"loss_ee": 0.6991291533816945,
|
||||
"loss_delivery": 0.4610004154118625,
|
||||
"loss_biodist": 0.22516351396387274,
|
||||
"loss_toxic": 0.06740644057704644
|
||||
},
|
||||
{
|
||||
"loss": 2.131091995672746,
|
||||
"loss_size": 0.22081551971760663,
|
||||
"loss_pdi": 0.4984923790801655,
|
||||
"loss_ee": 0.6744041009382769,
|
||||
"loss_delivery": 0.44286114688624034,
|
||||
"loss_biodist": 0.2276345125653527,
|
||||
"loss_toxic": 0.0668843225999312
|
||||
},
|
||||
{
|
||||
"loss": 2.075855114243247,
|
||||
"loss_size": 0.1967273937030272,
|
||||
"loss_pdi": 0.4761518023230813,
|
||||
"loss_ee": 0.6580501876094125,
|
||||
"loss_delivery": 0.4651151258837093,
|
||||
"loss_biodist": 0.22098628905686465,
|
||||
"loss_toxic": 0.05882429365407337
|
||||
},
|
||||
{
|
||||
"loss": 2.070832209153609,
|
||||
"loss_size": 0.22619221427223898,
|
||||
"loss_pdi": 0.46735330332409253,
|
||||
"loss_ee": 0.658088050105355,
|
||||
"loss_delivery": 0.46364551173015073,
|
||||
"loss_biodist": 0.1992648494514552,
|
||||
"loss_toxic": 0.0562882690097798
|
||||
},
|
||||
{
|
||||
"loss": 2.0497749502008613,
|
||||
"loss_size": 0.2018005665053021,
|
||||
"loss_pdi": 0.44933846592903137,
|
||||
"loss_ee": 0.6409396637569774,
|
||||
"loss_delivery": 0.4817944710904902,
|
||||
"loss_biodist": 0.21351950141516599,
|
||||
"loss_toxic": 0.0623823038556359
|
||||
},
|
||||
{
|
||||
"loss": 1.998817953196439,
|
||||
"loss_size": 0.19020642136985605,
|
||||
"loss_pdi": 0.4579390991817821,
|
||||
"loss_ee": 0.6345709101720289,
|
||||
"loss_delivery": 0.45773128284649417,
|
||||
"loss_biodist": 0.20633393864740024,
|
||||
"loss_toxic": 0.052036324177276
|
||||
},
|
||||
{
|
||||
"loss": 1.9732873006300493,
|
||||
"loss_size": 0.18214904246005145,
|
||||
"loss_pdi": 0.46314583312381397,
|
||||
"loss_ee": 0.6480948545716025,
|
||||
"loss_delivery": 0.43798652359030465,
|
||||
"loss_biodist": 0.19162344119765543,
|
||||
"loss_toxic": 0.05028762956234542
|
||||
}
|
||||
],
|
||||
"val": [
|
||||
{
|
||||
"loss": 11.670351346333822,
|
||||
"loss_size": 7.447449843088786,
|
||||
"loss_pdi": 1.082938313484192,
|
||||
"loss_ee": 0.9422469735145569,
|
||||
"loss_delivery": 0.7185012102127075,
|
||||
"loss_biodist": 0.945093979438146,
|
||||
"loss_toxic": 0.5341211756070455,
|
||||
"acc_pdi": 0.8484848484848485,
|
||||
"acc_ee": 0.8333333333333334,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.7281323273976645,
|
||||
"loss_size": 0.4031377931435903,
|
||||
"loss_pdi": 0.7136062383651733,
|
||||
"loss_ee": 0.7757165829340616,
|
||||
"loss_delivery": 0.6889261901378632,
|
||||
"loss_biodist": 0.8803721169630686,
|
||||
"loss_toxic": 0.2663734555244446,
|
||||
"acc_pdi": 0.8484848484848485,
|
||||
"acc_ee": 0.8333333333333334,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 2.996154228846232,
|
||||
"loss_size": 0.08968868106603622,
|
||||
"loss_pdi": 0.5251734455426534,
|
||||
"loss_ee": 0.705430785814921,
|
||||
"loss_delivery": 0.7259814739227295,
|
||||
"loss_biodist": 0.8580058316389719,
|
||||
"loss_toxic": 0.09187404563029607,
|
||||
"acc_pdi": 0.8484848484848485,
|
||||
"acc_ee": 0.8333333333333334,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.0241867701212564,
|
||||
"loss_size": 0.06373066206773122,
|
||||
"loss_pdi": 0.5075281461079916,
|
||||
"loss_ee": 0.6890556613604227,
|
||||
"loss_delivery": 0.8740459084510803,
|
||||
"loss_biodist": 0.8290574749310812,
|
||||
"loss_toxic": 0.060768917202949524,
|
||||
"acc_pdi": 0.8484848484848485,
|
||||
"acc_ee": 0.8333333333333334,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.2901877562204995,
|
||||
"loss_size": 0.20215384662151337,
|
||||
"loss_pdi": 0.5433482428391775,
|
||||
"loss_ee": 0.7523069183031718,
|
||||
"loss_delivery": 0.9290379285812378,
|
||||
"loss_biodist": 0.8056556979815165,
|
||||
"loss_toxic": 0.057684975365797676,
|
||||
"acc_pdi": 0.8484848484848485,
|
||||
"acc_ee": 0.8333333333333334,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.11362091700236,
|
||||
"loss_size": 0.11133595556020737,
|
||||
"loss_pdi": 0.5679469505945841,
|
||||
"loss_ee": 0.8252793351809183,
|
||||
"loss_delivery": 0.8343843619028727,
|
||||
"loss_biodist": 0.7218613227208456,
|
||||
"loss_toxic": 0.052812947581211724,
|
||||
"acc_pdi": 0.8484848484848485,
|
||||
"acc_ee": 0.8181818181818182,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.247321446736654,
|
||||
"loss_size": 0.09277657171090443,
|
||||
"loss_pdi": 0.6567125717798868,
|
||||
"loss_ee": 1.0444208979606628,
|
||||
"loss_delivery": 0.7062844236691793,
|
||||
"loss_biodist": 0.6986102362473806,
|
||||
"loss_toxic": 0.04851680745681127,
|
||||
"acc_pdi": 0.8181818181818182,
|
||||
"acc_ee": 0.45454545454545453,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.168424208958944,
|
||||
"loss_size": 0.05177713930606842,
|
||||
"loss_pdi": 0.5932339330514272,
|
||||
"loss_ee": 0.968136191368103,
|
||||
"loss_delivery": 0.7594618300596873,
|
||||
"loss_biodist": 0.7631273567676544,
|
||||
"loss_toxic": 0.032687741021315254,
|
||||
"acc_pdi": 0.8484848484848485,
|
||||
"acc_ee": 0.45454545454545453,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.1226733525594077,
|
||||
"loss_size": 0.16442706187566122,
|
||||
"loss_pdi": 0.4861932198206584,
|
||||
"loss_ee": 0.8927785356839498,
|
||||
"loss_delivery": 0.809323231379191,
|
||||
"loss_biodist": 0.7391951779524485,
|
||||
"loss_toxic": 0.030756143853068352,
|
||||
"acc_pdi": 0.8484848484848485,
|
||||
"acc_ee": 0.4696969696969697,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.4750285943349204,
|
||||
"loss_size": 0.07188746457298596,
|
||||
"loss_pdi": 0.635328451792399,
|
||||
"loss_ee": 1.0510863463083904,
|
||||
"loss_delivery": 0.9019280473391215,
|
||||
"loss_biodist": 0.7850770453612009,
|
||||
"loss_toxic": 0.02972123461465041,
|
||||
"acc_pdi": 0.8181818181818182,
|
||||
"acc_ee": 0.3787878787878788,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.3856334686279297,
|
||||
"loss_size": 0.0857744167248408,
|
||||
"loss_pdi": 0.5498887598514557,
|
||||
"loss_ee": 0.9292206565539042,
|
||||
"loss_delivery": 1.003569980462392,
|
||||
"loss_biodist": 0.7962689697742462,
|
||||
"loss_toxic": 0.020910644593338173,
|
||||
"acc_pdi": 0.8333333333333334,
|
||||
"acc_ee": 0.3787878787878788,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.594546397527059,
|
||||
"loss_size": 0.058230139315128326,
|
||||
"loss_pdi": 0.7094775040944418,
|
||||
"loss_ee": 1.134681224822998,
|
||||
"loss_delivery": 0.8755488395690918,
|
||||
"loss_biodist": 0.7928893665472666,
|
||||
"loss_toxic": 0.023719362293680508,
|
||||
"acc_pdi": 0.6363636363636364,
|
||||
"acc_ee": 0.21212121212121213,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.34331480662028,
|
||||
"loss_size": 0.09952588627735774,
|
||||
"loss_pdi": 0.5444782872994741,
|
||||
"loss_ee": 0.9434934655825297,
|
||||
"loss_delivery": 0.9477877616882324,
|
||||
"loss_biodist": 0.7897172272205353,
|
||||
"loss_toxic": 0.018312191901107628,
|
||||
"acc_pdi": 0.7878787878787878,
|
||||
"acc_ee": 0.36363636363636365,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.4212222894032798,
|
||||
"loss_size": 0.08121616393327713,
|
||||
"loss_pdi": 0.5517565310001373,
|
||||
"loss_ee": 1.0685155193010967,
|
||||
"loss_delivery": 0.874200721581777,
|
||||
"loss_biodist": 0.828464408715566,
|
||||
"loss_toxic": 0.017068898615737755,
|
||||
"acc_pdi": 0.8181818181818182,
|
||||
"acc_ee": 0.25757575757575757,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.6395487785339355,
|
||||
"loss_size": 0.07888962080081303,
|
||||
"loss_pdi": 0.5913220842679342,
|
||||
"loss_ee": 1.0437468489011128,
|
||||
"loss_delivery": 1.0660852392514546,
|
||||
"loss_biodist": 0.8421931266784668,
|
||||
"loss_toxic": 0.01731194742023945,
|
||||
"acc_pdi": 0.7878787878787878,
|
||||
"acc_ee": 0.2727272727272727,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.5140305360158286,
|
||||
"loss_size": 0.0599971575041612,
|
||||
"loss_pdi": 0.5561938285827637,
|
||||
"loss_ee": 1.0674984057744343,
|
||||
"loss_delivery": 0.9739653070767721,
|
||||
"loss_biodist": 0.8400343159834543,
|
||||
"loss_toxic": 0.01634151643762986,
|
||||
"acc_pdi": 0.7424242424242424,
|
||||
"acc_ee": 0.22727272727272727,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.628753344217936,
|
||||
"loss_size": 0.08887146785855293,
|
||||
"loss_pdi": 0.582503984371821,
|
||||
"loss_ee": 1.114095131556193,
|
||||
"loss_delivery": 0.9745903412501017,
|
||||
"loss_biodist": 0.8516232868035635,
|
||||
"loss_toxic": 0.017069284183283646,
|
||||
"acc_pdi": 0.7121212121212122,
|
||||
"acc_ee": 0.19696969696969696,
|
||||
"acc_toxic": 1.0
|
||||
},
|
||||
{
|
||||
"loss": 3.6391177972157798,
|
||||
"loss_size": 0.08282352735598882,
|
||||
"loss_pdi": 0.581497848033905,
|
||||
"loss_ee": 1.141685386498769,
|
||||
"loss_delivery": 0.9548555612564087,
|
||||
"loss_biodist": 0.8643284440040588,
|
||||
"loss_toxic": 0.013927079737186432,
|
||||
"acc_pdi": 0.7424242424242424,
|
||||
"acc_ee": 0.16666666666666666,
|
||||
"acc_toxic": 1.0
|
||||
}
|
||||
]
|
||||
}
|
||||
BIN
models/finetune_cv/fold_4/model.pt
Normal file
BIN
models/finetune_cv/fold_4/model.pt
Normal file
Binary file not shown.
294
models/finetune_cv/test_results.json
Normal file
294
models/finetune_cv/test_results.json
Normal file
@ -0,0 +1,294 @@
|
||||
{
|
||||
"fold_results": [
|
||||
{
|
||||
"fold_idx": 0,
|
||||
"n_samples": 95,
|
||||
"size": {
|
||||
"n": 95,
|
||||
"rmse": 0.5909209144067168,
|
||||
"mae": 0.376253614927593,
|
||||
"r2": 0.005712927161997228
|
||||
},
|
||||
"delivery": {
|
||||
"n": 66,
|
||||
"rmse": 1.3280577883458438,
|
||||
"mae": 0.5195405159964028,
|
||||
"r2": 0.03195999739694366
|
||||
},
|
||||
"pdi": {
|
||||
"n": 95,
|
||||
"accuracy": 0.6105263157894737,
|
||||
"precision": 0.20350877192982456,
|
||||
"recall": 0.3333333333333333,
|
||||
"f1": 0.25272331154684097
|
||||
},
|
||||
"ee": {
|
||||
"n": 95,
|
||||
"accuracy": 0.6736842105263158,
|
||||
"precision": 0.22456140350877193,
|
||||
"recall": 0.3333333333333333,
|
||||
"f1": 0.26834381551362685
|
||||
},
|
||||
"toxic": {
|
||||
"n": 66,
|
||||
"accuracy": 0.8939393939393939,
|
||||
"precision": 0.44696969696969696,
|
||||
"recall": 0.5,
|
||||
"f1": 0.472
|
||||
},
|
||||
"biodist": {
|
||||
"n": 66,
|
||||
"kl_divergence": 0.851655784204727,
|
||||
"js_divergence": 0.21404831573756974
|
||||
}
|
||||
},
|
||||
{
|
||||
"fold_idx": 1,
|
||||
"n_samples": 195,
|
||||
"size": {
|
||||
"n": 193,
|
||||
"rmse": 0.4425801645813746,
|
||||
"mae": 0.26432527161632796,
|
||||
"r2": -0.026225211870033682
|
||||
},
|
||||
"delivery": {
|
||||
"n": 123,
|
||||
"rmse": 0.7771322048436382,
|
||||
"mae": 0.6133777339870822,
|
||||
"r2": -0.128644776760948
|
||||
},
|
||||
"pdi": {
|
||||
"n": 195,
|
||||
"accuracy": 0.7076923076923077,
|
||||
"precision": 0.35384615384615387,
|
||||
"recall": 0.5,
|
||||
"f1": 0.4144144144144144
|
||||
},
|
||||
"ee": {
|
||||
"n": 195,
|
||||
"accuracy": 0.4205128205128205,
|
||||
"precision": 0.14017094017094017,
|
||||
"recall": 0.3333333333333333,
|
||||
"f1": 0.19735258724428398
|
||||
},
|
||||
"toxic": {
|
||||
"n": 123,
|
||||
"accuracy": 1.0,
|
||||
"precision": 1.0,
|
||||
"recall": 1.0,
|
||||
"f1": 1.0
|
||||
},
|
||||
"biodist": {
|
||||
"n": 123,
|
||||
"kl_divergence": 0.9336461102028436,
|
||||
"js_divergence": 0.24870266224462317
|
||||
}
|
||||
},
|
||||
{
|
||||
"fold_idx": 2,
|
||||
"n_samples": 51,
|
||||
"size": {
|
||||
"n": 51,
|
||||
"rmse": 0.6473513298834871,
|
||||
"mae": 0.5600235602434944,
|
||||
"r2": -9.27515642706235
|
||||
},
|
||||
"delivery": {
|
||||
"n": 44,
|
||||
"rmse": 0.7721077356414991,
|
||||
"mae": 0.6167582499593581,
|
||||
"r2": -0.4822886602727561
|
||||
},
|
||||
"pdi": {
|
||||
"n": 51,
|
||||
"accuracy": 0.8823529411764706,
|
||||
"precision": 0.29411764705882354,
|
||||
"recall": 0.3333333333333333,
|
||||
"f1": 0.3125
|
||||
},
|
||||
"ee": {
|
||||
"n": 51,
|
||||
"accuracy": 0.8431372549019608,
|
||||
"precision": 0.28104575163398693,
|
||||
"recall": 0.3333333333333333,
|
||||
"f1": 0.3049645390070922
|
||||
},
|
||||
"toxic": {
|
||||
"n": 47,
|
||||
"accuracy": 0.851063829787234,
|
||||
"precision": 0.425531914893617,
|
||||
"recall": 0.5,
|
||||
"f1": 0.4597701149425288
|
||||
},
|
||||
"biodist": {
|
||||
"n": 45,
|
||||
"kl_divergence": 1.1049896129018548,
|
||||
"js_divergence": 0.25485248115851133
|
||||
}
|
||||
},
|
||||
{
|
||||
"fold_idx": 3,
|
||||
"n_samples": 66,
|
||||
"size": {
|
||||
"n": 66,
|
||||
"rmse": 0.2407212117920812,
|
||||
"mae": 0.19363613562150436,
|
||||
"r2": -0.11204941379936861
|
||||
},
|
||||
"delivery": {
|
||||
"n": 62,
|
||||
"rmse": 1.0041711455927012,
|
||||
"mae": 0.7132550483914993,
|
||||
"r2": -0.63265374674746
|
||||
},
|
||||
"pdi": {
|
||||
"n": 66,
|
||||
"accuracy": 0.8484848484848485,
|
||||
"precision": 0.42424242424242425,
|
||||
"recall": 0.5,
|
||||
"f1": 0.4590163934426229
|
||||
},
|
||||
"ee": {
|
||||
"n": 66,
|
||||
"accuracy": 0.8181818181818182,
|
||||
"precision": 0.27692307692307694,
|
||||
"recall": 0.32727272727272727,
|
||||
"f1": 0.3
|
||||
},
|
||||
"toxic": {
|
||||
"n": 62,
|
||||
"accuracy": 1.0,
|
||||
"precision": 1.0,
|
||||
"recall": 1.0,
|
||||
"f1": 1.0
|
||||
},
|
||||
"biodist": {
|
||||
"n": 62,
|
||||
"kl_divergence": 0.9677978984139058,
|
||||
"js_divergence": 0.2020309307244639
|
||||
}
|
||||
},
|
||||
{
|
||||
"fold_idx": 4,
|
||||
"n_samples": 27,
|
||||
"size": {
|
||||
"n": 27,
|
||||
"rmse": 0.23392834445509142,
|
||||
"mae": 0.19066280788845485,
|
||||
"r2": -0.2667651950955112
|
||||
},
|
||||
"delivery": {
|
||||
"n": 15,
|
||||
"rmse": 1.9603892288630869,
|
||||
"mae": 1.3892907698949177,
|
||||
"r2": -0.29760739742916287
|
||||
},
|
||||
"pdi": {
|
||||
"n": 27,
|
||||
"accuracy": 0.8888888888888888,
|
||||
"precision": 0.4444444444444444,
|
||||
"recall": 0.5,
|
||||
"f1": 0.47058823529411764
|
||||
},
|
||||
"ee": {
|
||||
"n": 27,
|
||||
"accuracy": 0.5925925925925926,
|
||||
"precision": 0.19753086419753085,
|
||||
"recall": 0.3333333333333333,
|
||||
"f1": 0.24806201550387597
|
||||
},
|
||||
"toxic": {
|
||||
"n": 15,
|
||||
"accuracy": 1.0,
|
||||
"precision": 1.0,
|
||||
"recall": 1.0,
|
||||
"f1": 1.0
|
||||
},
|
||||
"biodist": {
|
||||
"n": 15,
|
||||
"kl_divergence": 0.9389607012315264,
|
||||
"js_divergence": 0.2470218476598176
|
||||
}
|
||||
}
|
||||
],
|
||||
"summary_stats": {
|
||||
"size": {
|
||||
"rmse_mean": 0.43110039302375025,
|
||||
"rmse_std": 0.17179051271013462,
|
||||
"r2_mean": -1.9348966641330534,
|
||||
"r2_std": 3.6713441784129
|
||||
},
|
||||
"delivery": {
|
||||
"rmse_mean": 1.1683716206573538,
|
||||
"rmse_std": 0.4449374578352648,
|
||||
"r2_mean": -0.30184691676267666,
|
||||
"r2_std": 0.23809090378746706
|
||||
},
|
||||
"pdi": {
|
||||
"accuracy_mean": 0.7875890604063979,
|
||||
"accuracy_std": 0.11016791908756088,
|
||||
"f1_mean": 0.3818484709395992,
|
||||
"f1_std": 0.08529090446864619
|
||||
},
|
||||
"ee": {
|
||||
"accuracy_mean": 0.6696217393431015,
|
||||
"accuracy_std": 0.15503740047242787,
|
||||
"f1_mean": 0.2637445914537758,
|
||||
"f1_std": 0.039213602228007696
|
||||
},
|
||||
"toxic": {
|
||||
"accuracy_mean": 0.9490006447453256,
|
||||
"accuracy_std": 0.06391582554207781,
|
||||
"f1_mean": 0.7863540229885058,
|
||||
"f1_std": 0.26169039387919035
|
||||
},
|
||||
"biodist": {
|
||||
"kl_mean": 0.9594100213909715,
|
||||
"kl_std": 0.08240959093662605,
|
||||
"js_mean": 0.23333124750499712,
|
||||
"js_std": 0.021158533549255752
|
||||
}
|
||||
},
|
||||
"overall": {
|
||||
"size": {
|
||||
"n_samples": 432,
|
||||
"mse": 0.22604480336185886,
|
||||
"rmse": 0.47544169291497657,
|
||||
"mae": 0.3084443360567093,
|
||||
"r2": -0.2313078534105617
|
||||
},
|
||||
"delivery": {
|
||||
"n_samples": 310,
|
||||
"mse": 1.0873755440675295,
|
||||
"rmse": 1.0427730069710903,
|
||||
"mae": 0.6513989447841361,
|
||||
"r2": -0.09443640807387799
|
||||
},
|
||||
"pdi": {
|
||||
"n_samples": 434,
|
||||
"accuracy": 0.7396313364055299,
|
||||
"precision": 0.18490783410138248,
|
||||
"recall": 0.25,
|
||||
"f1": 0.21258278145695364
|
||||
},
|
||||
"ee": {
|
||||
"n_samples": 434,
|
||||
"accuracy": 0.5967741935483871,
|
||||
"precision": 0.1993841416474211,
|
||||
"recall": 0.33205128205128204,
|
||||
"f1": 0.24915824915824913
|
||||
},
|
||||
"toxic": {
|
||||
"n_samples": 313,
|
||||
"accuracy": 0.9552715654952076,
|
||||
"precision": 0.4776357827476038,
|
||||
"recall": 0.5,
|
||||
"f1": 0.48856209150326796
|
||||
},
|
||||
"biodist": {
|
||||
"n_samples": 311,
|
||||
"kl_divergence": 0.9481034280166569,
|
||||
"js_divergence": 0.23285280825310384
|
||||
}
|
||||
}
|
||||
}
|
||||
Binary file not shown.
@ -1,310 +1,206 @@
|
||||
{
|
||||
"train": [
|
||||
{
|
||||
"loss": 0.8244676398801744,
|
||||
"n_samples": 8721
|
||||
"loss": 0.7730368412685099,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.6991508170533461,
|
||||
"n_samples": 8721
|
||||
"loss": 0.658895703010919,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.6388374940987616,
|
||||
"n_samples": 8721
|
||||
"loss": 0.6059015260392299,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.6008581508669937,
|
||||
"n_samples": 8721
|
||||
"loss": 0.5744731174349416,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.584832567446085,
|
||||
"n_samples": 8721
|
||||
"loss": 0.5452056020458733,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.5481657371815157,
|
||||
"n_samples": 8721
|
||||
"loss": 0.5138543470936083,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.5368926340308079,
|
||||
"n_samples": 8721
|
||||
"loss": 0.4885380559178135,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.5210388793613561,
|
||||
"n_samples": 8721
|
||||
"loss": 0.47587182296687974,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.49758357966374045,
|
||||
"n_samples": 8721
|
||||
"loss": 0.4671051038255316,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.49256294099457043,
|
||||
"n_samples": 8721
|
||||
"loss": 0.46794115915756107,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.4697267088016886,
|
||||
"n_samples": 8721
|
||||
"loss": 0.4293930456997915,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.45763822707571084,
|
||||
"n_samples": 8721
|
||||
"loss": 0.42624105651716415,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.4495221330627172,
|
||||
"n_samples": 8721
|
||||
"loss": 0.4131358770446828,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.446159594079631,
|
||||
"n_samples": 8721
|
||||
"loss": 0.3946074267790835,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.4327090857889029,
|
||||
"n_samples": 8721
|
||||
"loss": 0.3898155013755344,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.4249273364101852,
|
||||
"n_samples": 8721
|
||||
"loss": 0.37861797005733383,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.4216959138704459,
|
||||
"n_samples": 8721
|
||||
"loss": 0.3775682858392304,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.416526201182502,
|
||||
"n_samples": 8721
|
||||
"loss": 0.3800349080262064,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.40368679039741573,
|
||||
"n_samples": 8721
|
||||
"loss": 0.36302345173031675,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.4051084730032182,
|
||||
"n_samples": 8721
|
||||
"loss": 0.3429561740842766,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.38971701020385785,
|
||||
"n_samples": 8721
|
||||
"loss": 0.3445638883004898,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.39155546386038786,
|
||||
"n_samples": 8721
|
||||
"loss": 0.318970229203733,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.37976963541784114,
|
||||
"n_samples": 8721
|
||||
"loss": 0.30179278279904437,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.36484339719805037,
|
||||
"n_samples": 8721
|
||||
"loss": 0.2887343142006437,
|
||||
"n_samples": 6783
|
||||
},
|
||||
{
|
||||
"loss": 0.36232607571196496,
|
||||
"n_samples": 8721
|
||||
},
|
||||
{
|
||||
"loss": 0.3345973272380199,
|
||||
"n_samples": 8721
|
||||
},
|
||||
{
|
||||
"loss": 0.31767916518768957,
|
||||
"n_samples": 8721
|
||||
},
|
||||
{
|
||||
"loss": 0.32065429246052457,
|
||||
"n_samples": 8721
|
||||
},
|
||||
{
|
||||
"loss": 0.3171297926146043,
|
||||
"n_samples": 8721
|
||||
},
|
||||
{
|
||||
"loss": 0.3122120894173009,
|
||||
"n_samples": 8721
|
||||
},
|
||||
{
|
||||
"loss": 0.3135035038404461,
|
||||
"n_samples": 8721
|
||||
},
|
||||
{
|
||||
"loss": 0.2987745178222875,
|
||||
"n_samples": 8721
|
||||
},
|
||||
{
|
||||
"loss": 0.2914867957853393,
|
||||
"n_samples": 8721
|
||||
},
|
||||
{
|
||||
"loss": 0.2983839795507705,
|
||||
"n_samples": 8721
|
||||
},
|
||||
{
|
||||
"loss": 0.2826709597875678,
|
||||
"n_samples": 8721
|
||||
},
|
||||
{
|
||||
"loss": 0.2731766632569382,
|
||||
"n_samples": 8721
|
||||
},
|
||||
{
|
||||
"loss": 0.27726896305742266,
|
||||
"n_samples": 8721
|
||||
},
|
||||
{
|
||||
"loss": 0.27864557847067956,
|
||||
"n_samples": 8721
|
||||
"loss": 0.29240367556855545,
|
||||
"n_samples": 6783
|
||||
}
|
||||
],
|
||||
"val": [
|
||||
{
|
||||
"loss": 0.7601077516012517,
|
||||
"n_samples": 969
|
||||
"loss": 0.7350345371841441,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.7119935319611901,
|
||||
"n_samples": 969
|
||||
"loss": 0.7165568811318536,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.6461842978148269,
|
||||
"n_samples": 969
|
||||
"loss": 0.7251406249862214,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.7006978391063226,
|
||||
"n_samples": 969
|
||||
"loss": 0.6836505264587159,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.6533874032943979,
|
||||
"n_samples": 969
|
||||
"loss": 0.6747132955771933,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.6413641451743611,
|
||||
"n_samples": 969
|
||||
"loss": 0.6691136244936912,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.6168395132979742,
|
||||
"n_samples": 969
|
||||
"loss": 0.6337480902323249,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.6095251602162025,
|
||||
"n_samples": 969
|
||||
"loss": 0.6600317959527934,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.5887809592626905,
|
||||
"n_samples": 969
|
||||
"loss": 0.6439923948855346,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.5655298325376368,
|
||||
"n_samples": 969
|
||||
"loss": 0.643800035575267,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.5809201743872788,
|
||||
"n_samples": 969
|
||||
"loss": 0.6181512585221839,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.5897585974912033,
|
||||
"n_samples": 969
|
||||
"loss": 0.6442458634939151,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.5732012489662573,
|
||||
"n_samples": 969
|
||||
"loss": 0.6344759362359862,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.5607388911786094,
|
||||
"n_samples": 969
|
||||
"loss": 0.6501405371457472,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.5717580675371414,
|
||||
"n_samples": 969
|
||||
"loss": 0.6098835162990152,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.5553950037657291,
|
||||
"n_samples": 969
|
||||
"loss": 0.6366627322138894,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.5778171792857049,
|
||||
"n_samples": 969
|
||||
"loss": 0.6171610150646417,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.5602665468127734,
|
||||
"n_samples": 969
|
||||
"loss": 0.6358801012273748,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.5475307451359259,
|
||||
"n_samples": 969
|
||||
"loss": 0.6239976831059871,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.551515599314827,
|
||||
"n_samples": 969
|
||||
"loss": 0.6683828232827201,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.5755438121541243,
|
||||
"n_samples": 969
|
||||
"loss": 0.6655785786478143,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.5798238261811381,
|
||||
"n_samples": 969
|
||||
"loss": 0.6152775046503088,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.5739961433828923,
|
||||
"n_samples": 969
|
||||
"loss": 0.6202247662153858,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.5742599932540312,
|
||||
"n_samples": 969
|
||||
"loss": 0.648199727435189,
|
||||
"n_samples": 2907
|
||||
},
|
||||
{
|
||||
"loss": 0.5834948123885382,
|
||||
"n_samples": 969
|
||||
},
|
||||
{
|
||||
"loss": 0.554078846570139,
|
||||
"n_samples": 969
|
||||
},
|
||||
{
|
||||
"loss": 0.5714933996322354,
|
||||
"n_samples": 969
|
||||
},
|
||||
{
|
||||
"loss": 0.5384107524350331,
|
||||
"n_samples": 969
|
||||
},
|
||||
{
|
||||
"loss": 0.570854394451568,
|
||||
"n_samples": 969
|
||||
},
|
||||
{
|
||||
"loss": 0.5767292551642478,
|
||||
"n_samples": 969
|
||||
},
|
||||
{
|
||||
"loss": 0.5660079547556808,
|
||||
"n_samples": 969
|
||||
},
|
||||
{
|
||||
"loss": 0.5608972411514312,
|
||||
"n_samples": 969
|
||||
},
|
||||
{
|
||||
"loss": 0.5620947442987263,
|
||||
"n_samples": 969
|
||||
},
|
||||
{
|
||||
"loss": 0.5706970894361305,
|
||||
"n_samples": 969
|
||||
},
|
||||
{
|
||||
"loss": 0.5702376298690974,
|
||||
"n_samples": 969
|
||||
},
|
||||
{
|
||||
"loss": 0.5758474825259579,
|
||||
"n_samples": 969
|
||||
},
|
||||
{
|
||||
"loss": 0.5673816067284844,
|
||||
"n_samples": 969
|
||||
},
|
||||
{
|
||||
"loss": 0.5671441179879925,
|
||||
"n_samples": 969
|
||||
"loss": 0.6473217075085124,
|
||||
"n_samples": 2907
|
||||
}
|
||||
]
|
||||
}
|
||||
@ -1,232 +1,224 @@
|
||||
"""处理 cross-validation 数据脚本:将 CV splits 转换为模型所需的 parquet 格式"""
|
||||
"""内部数据 Cross-Validation 划分脚本:基于 Amine 的分组划分"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import typer
|
||||
from loguru import logger
|
||||
|
||||
from lnp_ml.config import EXTERNAL_DATA_DIR, PROCESSED_DATA_DIR
|
||||
from lnp_ml.config import INTERIM_DATA_DIR, PROCESSED_DATA_DIR
|
||||
from lnp_ml.dataset import (
|
||||
LNPDatasetConfig,
|
||||
process_dataframe,
|
||||
SMILES_COL,
|
||||
COMP_COLS,
|
||||
HELP_COLS,
|
||||
TARGET_REGRESSION,
|
||||
TARGET_CLASSIFICATION_PDI,
|
||||
TARGET_CLASSIFICATION_EE,
|
||||
TARGET_TOXIC,
|
||||
TARGET_BIODIST,
|
||||
get_phys_cols,
|
||||
get_exp_cols,
|
||||
EXP_ONEHOT_SPECS,
|
||||
PHYS_ONEHOT_SPECS,
|
||||
)
|
||||
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
# CV extra_x 列名到模型列名的映射
|
||||
CV_COL_MAPPING = {
|
||||
# Batch_or_individual_or_barcoded -> Sample_organization_type (for Value_name related)
|
||||
"Batch_or_individual_or_barcoded_Barcoded": "Batch_or_individual_or_barcoded_Barcoded",
|
||||
"Batch_or_individual_or_barcoded_Individual": "Batch_or_individual_or_barcoded_Individual",
|
||||
# Helper_lipid_ID_None 不在模型中使用,忽略
|
||||
}
|
||||
|
||||
|
||||
def load_cv_split(cv_dir: Path) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
||||
def amine_based_cv_split(
|
||||
df: pd.DataFrame,
|
||||
n_folds: int = 5,
|
||||
seed: int = 42,
|
||||
amine_col: str = "Amine",
|
||||
) -> List[dict]:
|
||||
"""
|
||||
加载单个 CV split 的数据。
|
||||
基于 Amine 列进行 Cross-Validation 划分。
|
||||
|
||||
步骤:
|
||||
1. 按 amine_col 分组
|
||||
2. 打乱分组顺序
|
||||
3. 将分组 round-robin 分配到 n_folds 个容器
|
||||
4. 对于每个 fold i:
|
||||
- validation = container[i]
|
||||
- test = container[(i+1) % n_folds]
|
||||
- train = 其余所有
|
||||
|
||||
Args:
|
||||
cv_dir: CV split 目录,如 cv_0/
|
||||
df: 输入 DataFrame
|
||||
n_folds: 折数
|
||||
seed: 随机种子
|
||||
amine_col: 用于分组的列名
|
||||
|
||||
Returns:
|
||||
(train_df, valid_df, test_df) 合并后的 DataFrame
|
||||
List of dicts,每个 dict 包含 train_df, val_df, test_df
|
||||
"""
|
||||
splits = {}
|
||||
for split_name in ["train", "valid", "test"]:
|
||||
# 加载主数据(smiles, quantified_delivery)
|
||||
main_path = cv_dir / f"{split_name}.csv"
|
||||
extra_x_path = cv_dir / f"{split_name}_extra_x.csv"
|
||||
metadata_path = cv_dir / f"{split_name}_metadata.csv"
|
||||
# 获取唯一的 amine 并打乱
|
||||
unique_amines = df[amine_col].unique()
|
||||
rng = np.random.RandomState(seed)
|
||||
rng.shuffle(unique_amines)
|
||||
|
||||
logger.info(f"Found {len(unique_amines)} unique amines")
|
||||
|
||||
# Round-robin 分配到 n_folds 个容器
|
||||
containers = [[] for _ in range(n_folds)]
|
||||
for i, amine in enumerate(unique_amines):
|
||||
containers[i % n_folds].append(amine)
|
||||
|
||||
# 打印每个容器的大小
|
||||
for i, container in enumerate(containers):
|
||||
container_samples = df[df[amine_col].isin(container)]
|
||||
logger.info(f" Container {i}: {len(container)} amines, {len(container_samples)} samples")
|
||||
|
||||
# 生成每个 fold 的数据
|
||||
fold_splits = []
|
||||
for i in range(n_folds):
|
||||
val_amines = set(containers[i])
|
||||
test_amines = set(containers[(i + 1) % n_folds])
|
||||
train_amines = set()
|
||||
for j in range(n_folds):
|
||||
if j != i and j != (i + 1) % n_folds:
|
||||
train_amines.update(containers[j])
|
||||
|
||||
if not main_path.exists():
|
||||
raise FileNotFoundError(f"Missing {main_path}")
|
||||
train_df = df[df[amine_col].isin(train_amines)].reset_index(drop=True)
|
||||
val_df = df[df[amine_col].isin(val_amines)].reset_index(drop=True)
|
||||
test_df = df[df[amine_col].isin(test_amines)].reset_index(drop=True)
|
||||
|
||||
main_df = pd.read_csv(main_path)
|
||||
fold_splits.append({
|
||||
"train": train_df,
|
||||
"val": val_df,
|
||||
"test": test_df,
|
||||
})
|
||||
|
||||
# 加载 extra_x(已 one-hot 编码的特征)
|
||||
if extra_x_path.exists():
|
||||
extra_x_df = pd.read_csv(extra_x_path)
|
||||
# 确保行数一致
|
||||
assert len(main_df) == len(extra_x_df), f"Row count mismatch: {len(main_df)} vs {len(extra_x_df)}"
|
||||
# 合并(按行索引)
|
||||
df = pd.concat([main_df, extra_x_df], axis=1)
|
||||
else:
|
||||
df = main_df
|
||||
logger.warning(f" {split_name}_extra_x.csv not found, using main data only")
|
||||
|
||||
# 可选:从 metadata 获取额外信息
|
||||
if metadata_path.exists():
|
||||
metadata_df = pd.read_csv(metadata_path)
|
||||
# 提取需要的列(如 Purity, Mix_type, Value_name 等)
|
||||
for col in ["Purity", "Mix_type", "Value_name", "Target_or_delivered_gene"]:
|
||||
if col in metadata_df.columns and col not in df.columns:
|
||||
df[col] = metadata_df[col]
|
||||
|
||||
splits[split_name] = df
|
||||
logger.info(
|
||||
f"Fold {i}: train={len(train_df)} ({len(train_amines)} amines), "
|
||||
f"val={len(val_df)} ({len(val_amines)} amines), "
|
||||
f"test={len(test_df)} ({len(test_amines)} amines)"
|
||||
)
|
||||
|
||||
return splits["train"], splits["valid"], splits["test"]
|
||||
|
||||
|
||||
def process_cv_dataframe(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
处理 CV 数据的 DataFrame,对齐到模型所需的列格式。
|
||||
|
||||
CV 数据的 extra_x 已经包含大部分 one-hot 编码,但需要:
|
||||
1. 添加缺失的 one-hot 列(设为 0)
|
||||
2. 从 metadata 中生成 phys token 的 one-hot 列(Purity, Mix_type, Cargo_type, Target_or_delivered_gene)
|
||||
3. 生成 Value_name 的 one-hot 列
|
||||
"""
|
||||
df = df.copy()
|
||||
|
||||
# 1. 处理 comp 列
|
||||
for col in COMP_COLS:
|
||||
if col in df.columns:
|
||||
df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0.0)
|
||||
else:
|
||||
df[col] = 0.0
|
||||
|
||||
# 2. 处理 help 列
|
||||
for col in HELP_COLS:
|
||||
if col not in df.columns:
|
||||
df[col] = 0.0
|
||||
else:
|
||||
df[col] = df[col].fillna(0.0).astype(float)
|
||||
|
||||
# 3. 处理 phys token 的 one-hot 列
|
||||
for col, values in PHYS_ONEHOT_SPECS.items():
|
||||
for v in values:
|
||||
onehot_col = f"{col}_{v}"
|
||||
if onehot_col not in df.columns:
|
||||
# 尝试从原始列生成
|
||||
if col in df.columns:
|
||||
df[onehot_col] = (df[col] == v).astype(float)
|
||||
else:
|
||||
df[onehot_col] = 0.0
|
||||
else:
|
||||
df[onehot_col] = df[onehot_col].fillna(0.0).astype(float)
|
||||
|
||||
# 4. 处理 exp token 的 one-hot 列
|
||||
for col, values in EXP_ONEHOT_SPECS.items():
|
||||
for v in values:
|
||||
onehot_col = f"{col}_{v}"
|
||||
if onehot_col not in df.columns:
|
||||
# 尝试从原始列生成
|
||||
if col in df.columns:
|
||||
df[onehot_col] = (df[col] == v).astype(float)
|
||||
else:
|
||||
df[onehot_col] = 0.0
|
||||
else:
|
||||
df[onehot_col] = df[onehot_col].fillna(0.0).astype(float)
|
||||
|
||||
# 5. 处理 quantified_delivery
|
||||
if "quantified_delivery" in df.columns:
|
||||
df["quantified_delivery"] = pd.to_numeric(df["quantified_delivery"], errors="coerce")
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def get_feature_columns() -> List[str]:
|
||||
"""获取所有特征列名"""
|
||||
config = LNPDatasetConfig()
|
||||
return (
|
||||
["smiles"]
|
||||
+ config.comp_cols
|
||||
+ config.phys_cols
|
||||
+ config.help_cols
|
||||
+ config.exp_cols
|
||||
+ ["quantified_delivery"]
|
||||
)
|
||||
return fold_splits
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
data_dir: Path = EXTERNAL_DATA_DIR / "all_amine_split_for_LiON",
|
||||
input_path: Path = INTERIM_DATA_DIR / "internal_corrected.csv",
|
||||
output_dir: Path = PROCESSED_DATA_DIR / "cv",
|
||||
n_folds: int = 5,
|
||||
seed: int = 42,
|
||||
amine_col: str = "Amine",
|
||||
):
|
||||
"""
|
||||
处理 cross-validation 数据,生成模型所需的 parquet 文件。
|
||||
基于 Amine 分组进行 Cross-Validation 数据划分。
|
||||
|
||||
采用类似 scaffold splitting 的思路,将相同 Amine 的数据放在同一组,
|
||||
确保训练集和测试集之间没有 Amine 泄露。
|
||||
|
||||
划分比例约为 train:val:test ≈ 3:1:1
|
||||
|
||||
输出结构:
|
||||
- processed/cv/fold_0/train.parquet
|
||||
- processed/cv/fold_0/valid.parquet
|
||||
- processed/cv/fold_0/val.parquet
|
||||
- processed/cv/fold_0/test.parquet
|
||||
- processed/cv/fold_1/...
|
||||
- processed/cv/feature_columns.txt
|
||||
"""
|
||||
logger.info(f"Processing CV data from {data_dir}")
|
||||
logger.info(f"Loading data from {input_path}")
|
||||
df = pd.read_csv(input_path)
|
||||
logger.info(f"Loaded {len(df)} samples")
|
||||
|
||||
# 获取所有 cv_* 目录
|
||||
cv_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("cv_")])
|
||||
|
||||
if len(cv_dirs) == 0:
|
||||
logger.error(f"No cv_* directories found in {data_dir}")
|
||||
# 检查 amine 列是否存在
|
||||
if amine_col not in df.columns:
|
||||
logger.error(f"Column '{amine_col}' not found in data. Available columns: {list(df.columns)}")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if len(cv_dirs) != n_folds:
|
||||
logger.warning(f"Expected {n_folds} folds, found {len(cv_dirs)}")
|
||||
# 处理数据(列对齐、one-hot 生成等)
|
||||
logger.info("Processing dataframe...")
|
||||
df = process_dataframe(df)
|
||||
|
||||
logger.info(f"Found {len(cv_dirs)} folds: {[d.name for d in cv_dirs]}")
|
||||
# 确保 Amine 列仍然存在(process_dataframe 可能不会保留它)
|
||||
# 重新加载原始数据获取 Amine 列
|
||||
original_df = pd.read_csv(input_path)
|
||||
if amine_col in original_df.columns and amine_col not in df.columns:
|
||||
df[amine_col] = original_df[amine_col].values
|
||||
|
||||
feature_cols = get_feature_columns()
|
||||
# 定义要保留的列
|
||||
phys_cols = get_phys_cols()
|
||||
exp_cols = get_exp_cols()
|
||||
|
||||
keep_cols = (
|
||||
[SMILES_COL]
|
||||
+ COMP_COLS
|
||||
+ phys_cols
|
||||
+ HELP_COLS
|
||||
+ exp_cols
|
||||
+ TARGET_REGRESSION
|
||||
+ TARGET_CLASSIFICATION_PDI
|
||||
+ TARGET_CLASSIFICATION_EE
|
||||
+ [TARGET_TOXIC]
|
||||
+ TARGET_BIODIST
|
||||
)
|
||||
|
||||
# 只保留存在的列
|
||||
keep_cols = [c for c in keep_cols if c in df.columns]
|
||||
|
||||
# 进行 CV 划分
|
||||
logger.info(f"\nPerforming {n_folds}-fold amine-based CV split (seed={seed})...")
|
||||
fold_splits = amine_based_cv_split(df, n_folds=n_folds, seed=seed, amine_col=amine_col)
|
||||
|
||||
# 保存每个 fold
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for i, cv_dir in enumerate(cv_dirs):
|
||||
fold_name = f"fold_{i}"
|
||||
fold_output_dir = output_dir / fold_name
|
||||
fold_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info(f"Processing {cv_dir.name} -> {fold_name}")
|
||||
|
||||
# 加载数据
|
||||
train_df, valid_df, test_df = load_cv_split(cv_dir)
|
||||
|
||||
logger.info(f" Loaded: train={len(train_df)}, valid={len(valid_df)}, test={len(test_df)}")
|
||||
|
||||
# 处理数据
|
||||
train_df = process_cv_dataframe(train_df)
|
||||
valid_df = process_cv_dataframe(valid_df)
|
||||
test_df = process_cv_dataframe(test_df)
|
||||
|
||||
# 确保所有列存在
|
||||
for col in feature_cols:
|
||||
for df in [train_df, valid_df, test_df]:
|
||||
if col not in df.columns:
|
||||
df[col] = 0.0 if col != "smiles" else ""
|
||||
for i, split in enumerate(fold_splits):
|
||||
fold_dir = output_dir / f"fold_{i}"
|
||||
fold_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 只保留需要的列
|
||||
train_df = train_df[feature_cols]
|
||||
valid_df = valid_df[feature_cols]
|
||||
test_df = test_df[feature_cols]
|
||||
train_df = split["train"][keep_cols].reset_index(drop=True)
|
||||
val_df = split["val"][keep_cols].reset_index(drop=True)
|
||||
test_df = split["test"][keep_cols].reset_index(drop=True)
|
||||
|
||||
# 保存
|
||||
train_df.to_parquet(fold_output_dir / "train.parquet", index=False)
|
||||
valid_df.to_parquet(fold_output_dir / "valid.parquet", index=False)
|
||||
test_df.to_parquet(fold_output_dir / "test.parquet", index=False)
|
||||
train_df.to_parquet(fold_dir / "train.parquet", index=False)
|
||||
val_df.to_parquet(fold_dir / "val.parquet", index=False)
|
||||
test_df.to_parquet(fold_dir / "test.parquet", index=False)
|
||||
|
||||
logger.success(f" Saved to {fold_output_dir}")
|
||||
logger.success(f"Saved fold {i} to {fold_dir}")
|
||||
|
||||
# 保存特征列配置
|
||||
cols_path = output_dir / "feature_columns.txt"
|
||||
with open(cols_path, "w") as f:
|
||||
f.write("\n".join(feature_cols))
|
||||
logger.success(f"Saved feature columns to {cols_path}")
|
||||
# 保存列名配置
|
||||
config_path = output_dir / "feature_columns.txt"
|
||||
with open(config_path, "w") as f:
|
||||
f.write("# Feature columns configuration\n\n")
|
||||
f.write(f"# SMILES\n{SMILES_COL}\n\n")
|
||||
f.write(f"# comp token [{len(COMP_COLS)}]\n")
|
||||
f.write("\n".join(COMP_COLS) + "\n\n")
|
||||
f.write(f"# phys token [{len(phys_cols)}]\n")
|
||||
f.write("\n".join(phys_cols) + "\n\n")
|
||||
f.write(f"# help token [{len(HELP_COLS)}]\n")
|
||||
f.write("\n".join(HELP_COLS) + "\n\n")
|
||||
f.write(f"# exp token [{len(exp_cols)}]\n")
|
||||
f.write("\n".join(exp_cols) + "\n\n")
|
||||
f.write("# Targets\n")
|
||||
f.write("## Regression\n")
|
||||
f.write("\n".join(TARGET_REGRESSION) + "\n")
|
||||
f.write("## PDI classification\n")
|
||||
f.write("\n".join(TARGET_CLASSIFICATION_PDI) + "\n")
|
||||
f.write("## EE classification\n")
|
||||
f.write("\n".join(TARGET_CLASSIFICATION_EE) + "\n")
|
||||
f.write("## Toxic\n")
|
||||
f.write(f"{TARGET_TOXIC}\n")
|
||||
f.write("## Biodistribution\n")
|
||||
f.write("\n".join(TARGET_BIODIST) + "\n")
|
||||
|
||||
logger.success(f"Saved feature config to {config_path}")
|
||||
|
||||
# 打印汇总
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("CV DATA PROCESSING COMPLETE")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Output directory: {output_dir}")
|
||||
logger.info(f"Number of folds: {len(cv_dirs)}")
|
||||
logger.info(f"Number of folds: {n_folds}")
|
||||
logger.info(f"Splitting method: Amine-based (column: {amine_col})")
|
||||
logger.info(f"Random seed: {seed}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
234
scripts/process_external_cv.py
Normal file
234
scripts/process_external_cv.py
Normal file
@ -0,0 +1,234 @@
|
||||
"""处理 cross-validation 数据脚本:将 CV splits 转换为模型所需的 parquet 格式"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import pandas as pd
|
||||
import typer
|
||||
from loguru import logger
|
||||
|
||||
from lnp_ml.config import EXTERNAL_DATA_DIR, PROCESSED_DATA_DIR
|
||||
from lnp_ml.dataset import (
|
||||
LNPDatasetConfig,
|
||||
COMP_COLS,
|
||||
HELP_COLS,
|
||||
get_phys_cols,
|
||||
get_exp_cols,
|
||||
EXP_ONEHOT_SPECS,
|
||||
PHYS_ONEHOT_SPECS,
|
||||
)
|
||||
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
# CV extra_x 列名到模型列名的映射
|
||||
CV_COL_MAPPING = {
|
||||
# Batch_or_individual_or_barcoded -> Sample_organization_type (for Value_name related)
|
||||
"Batch_or_individual_or_barcoded_Barcoded": "Batch_or_individual_or_barcoded_Barcoded",
|
||||
"Batch_or_individual_or_barcoded_Individual": "Batch_or_individual_or_barcoded_Individual",
|
||||
# Helper_lipid_ID_None 不在模型中使用,忽略
|
||||
}
|
||||
|
||||
|
||||
def load_cv_split(cv_dir: Path) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
||||
"""
|
||||
加载单个 CV split 的数据。
|
||||
|
||||
Args:
|
||||
cv_dir: CV split 目录,如 cv_0/
|
||||
|
||||
Returns:
|
||||
(train_df, valid_df, test_df) 合并后的 DataFrame
|
||||
"""
|
||||
splits = {}
|
||||
for split_name in ["train", "valid", "test"]:
|
||||
# 加载主数据(smiles, quantified_delivery)
|
||||
main_path = cv_dir / f"{split_name}.csv"
|
||||
extra_x_path = cv_dir / f"{split_name}_extra_x.csv"
|
||||
metadata_path = cv_dir / f"{split_name}_metadata.csv"
|
||||
|
||||
if not main_path.exists():
|
||||
raise FileNotFoundError(f"Missing {main_path}")
|
||||
|
||||
main_df = pd.read_csv(main_path)
|
||||
|
||||
# 加载 extra_x(已 one-hot 编码的特征)
|
||||
if extra_x_path.exists():
|
||||
extra_x_df = pd.read_csv(extra_x_path)
|
||||
# 确保行数一致
|
||||
assert len(main_df) == len(extra_x_df), f"Row count mismatch: {len(main_df)} vs {len(extra_x_df)}"
|
||||
# 合并(按行索引)
|
||||
df = pd.concat([main_df, extra_x_df], axis=1)
|
||||
else:
|
||||
df = main_df
|
||||
logger.warning(f" {split_name}_extra_x.csv not found, using main data only")
|
||||
|
||||
# 可选:从 metadata 获取额外信息
|
||||
if metadata_path.exists():
|
||||
metadata_df = pd.read_csv(metadata_path)
|
||||
# 提取需要的列(如 Purity, Mix_type, Value_name 等)
|
||||
for col in ["Purity", "Mix_type", "Value_name", "Target_or_delivered_gene"]:
|
||||
if col in metadata_df.columns and col not in df.columns:
|
||||
df[col] = metadata_df[col]
|
||||
|
||||
splits[split_name] = df
|
||||
|
||||
return splits["train"], splits["valid"], splits["test"]
|
||||
|
||||
|
||||
def process_cv_dataframe(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
处理 CV 数据的 DataFrame,对齐到模型所需的列格式。
|
||||
|
||||
CV 数据的 extra_x 已经包含大部分 one-hot 编码,但需要:
|
||||
1. 添加缺失的 one-hot 列(设为 0)
|
||||
2. 从 metadata 中生成 phys token 的 one-hot 列(Purity, Mix_type, Cargo_type, Target_or_delivered_gene)
|
||||
3. 生成 Value_name 的 one-hot 列
|
||||
"""
|
||||
df = df.copy()
|
||||
|
||||
# 1. 处理 comp 列
|
||||
for col in COMP_COLS:
|
||||
if col in df.columns:
|
||||
df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0.0)
|
||||
else:
|
||||
df[col] = 0.0
|
||||
|
||||
# 2. 处理 help 列
|
||||
for col in HELP_COLS:
|
||||
if col not in df.columns:
|
||||
df[col] = 0.0
|
||||
else:
|
||||
df[col] = df[col].fillna(0.0).astype(float)
|
||||
|
||||
# 3. 处理 phys token 的 one-hot 列
|
||||
for col, values in PHYS_ONEHOT_SPECS.items():
|
||||
for v in values:
|
||||
onehot_col = f"{col}_{v}"
|
||||
if onehot_col not in df.columns:
|
||||
# 尝试从原始列生成
|
||||
if col in df.columns:
|
||||
df[onehot_col] = (df[col] == v).astype(float)
|
||||
else:
|
||||
df[onehot_col] = 0.0
|
||||
else:
|
||||
df[onehot_col] = df[onehot_col].fillna(0.0).astype(float)
|
||||
|
||||
# 4. 处理 exp token 的 one-hot 列
|
||||
for col, values in EXP_ONEHOT_SPECS.items():
|
||||
for v in values:
|
||||
onehot_col = f"{col}_{v}"
|
||||
if onehot_col not in df.columns:
|
||||
# 尝试从原始列生成
|
||||
if col in df.columns:
|
||||
df[onehot_col] = (df[col] == v).astype(float)
|
||||
else:
|
||||
df[onehot_col] = 0.0
|
||||
else:
|
||||
df[onehot_col] = df[onehot_col].fillna(0.0).astype(float)
|
||||
|
||||
# 5. 处理 quantified_delivery
|
||||
if "quantified_delivery" in df.columns:
|
||||
df["quantified_delivery"] = pd.to_numeric(df["quantified_delivery"], errors="coerce")
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def get_feature_columns() -> List[str]:
|
||||
"""获取所有特征列名"""
|
||||
config = LNPDatasetConfig()
|
||||
return (
|
||||
["smiles"]
|
||||
+ config.comp_cols
|
||||
+ config.phys_cols
|
||||
+ config.help_cols
|
||||
+ config.exp_cols
|
||||
+ ["quantified_delivery"]
|
||||
)
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
data_dir: Path = EXTERNAL_DATA_DIR / "all_amine_split_for_LiON",
|
||||
output_dir: Path = PROCESSED_DATA_DIR / "pretrain_cv",
|
||||
n_folds: int = 5,
|
||||
):
|
||||
"""
|
||||
处理 cross-validation 数据,生成模型所需的 parquet 文件。
|
||||
|
||||
输出结构:
|
||||
- processed/pretrain_cv/fold_0/train.parquet
|
||||
- processed/pretrain_cv/fold_0/valid.parquet
|
||||
- processed/pretrain_cv/fold_0/test.parquet
|
||||
- processed/pretrain_cv/fold_1/...
|
||||
- processed/pretrain_cv/feature_columns.txt
|
||||
"""
|
||||
logger.info(f"Processing CV data from {data_dir}")
|
||||
|
||||
# 获取所有 cv_* 目录
|
||||
cv_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("cv_")])
|
||||
|
||||
if len(cv_dirs) == 0:
|
||||
logger.error(f"No cv_* directories found in {data_dir}")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if len(cv_dirs) != n_folds:
|
||||
logger.warning(f"Expected {n_folds} folds, found {len(cv_dirs)}")
|
||||
|
||||
logger.info(f"Found {len(cv_dirs)} folds: {[d.name for d in cv_dirs]}")
|
||||
|
||||
feature_cols = get_feature_columns()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for i, cv_dir in enumerate(cv_dirs):
|
||||
fold_name = f"fold_{i}"
|
||||
fold_output_dir = output_dir / fold_name
|
||||
fold_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info(f"Processing {cv_dir.name} -> {fold_name}")
|
||||
|
||||
# 加载数据
|
||||
train_df, valid_df, test_df = load_cv_split(cv_dir)
|
||||
|
||||
logger.info(f" Loaded: train={len(train_df)}, valid={len(valid_df)}, test={len(test_df)}")
|
||||
|
||||
# 处理数据
|
||||
train_df = process_cv_dataframe(train_df)
|
||||
valid_df = process_cv_dataframe(valid_df)
|
||||
test_df = process_cv_dataframe(test_df)
|
||||
|
||||
# 确保所有列存在
|
||||
for col in feature_cols:
|
||||
for df in [train_df, valid_df, test_df]:
|
||||
if col not in df.columns:
|
||||
df[col] = 0.0 if col != "smiles" else ""
|
||||
|
||||
# 只保留需要的列
|
||||
train_df = train_df[feature_cols]
|
||||
valid_df = valid_df[feature_cols]
|
||||
test_df = test_df[feature_cols]
|
||||
|
||||
# 保存
|
||||
train_df.to_parquet(fold_output_dir / "train.parquet", index=False)
|
||||
valid_df.to_parquet(fold_output_dir / "valid.parquet", index=False)
|
||||
test_df.to_parquet(fold_output_dir / "test.parquet", index=False)
|
||||
|
||||
logger.success(f" Saved to {fold_output_dir}")
|
||||
|
||||
# 保存特征列配置
|
||||
cols_path = output_dir / "feature_columns.txt"
|
||||
with open(cols_path, "w") as f:
|
||||
f.write("\n".join(feature_cols))
|
||||
logger.success(f"Saved feature columns to {cols_path}")
|
||||
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("CV DATA PROCESSING COMPLETE")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Output directory: {output_dir}")
|
||||
logger.info(f"Number of folds: {len(cv_dirs)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user