diff --git a/Makefile b/Makefile index f0abbf9..b064847 100644 --- a/Makefile +++ b/Makefile @@ -73,6 +73,11 @@ data: requirements data_pretrain: requirements $(PYTHON_INTERPRETER) scripts/process_external.py +## Process CV data for cross-validation pretrain (external/all_amine_split_for_LiON -> processed/cv) +.PHONY: data_cv +data_cv: requirements + $(PYTHON_INTERPRETER) scripts/process_data_cv.py + # MPNN 支持:使用 USE_MPNN=1 启用 MPNN encoder # 例如:make pretrain USE_MPNN=1 MPNN_FLAG = $(if $(USE_MPNN),--use-mpnn,) @@ -91,6 +96,16 @@ pretrain: requirements test_pretrain: requirements $(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain test $(MPNN_FLAG) +## Pretrain with cross-validation (5-fold) +.PHONY: pretrain_cv +pretrain_cv: requirements + $(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv main $(MPNN_FLAG) + +## Evaluate CV pretrain models on test sets +.PHONY: test_cv +test_cv: requirements + $(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv test $(MPNN_FLAG) + ## Train model (multi-task, from scratch) .PHONY: train train: requirements diff --git a/data/processed/cv/feature_columns.txt b/data/processed/cv/feature_columns.txt new file mode 100644 index 0000000..09d6735 --- /dev/null +++ b/data/processed/cv/feature_columns.txt @@ -0,0 +1,55 @@ +smiles +Cationic_Lipid_to_mRNA_weight_ratio +Cationic_Lipid_Mol_Ratio +Phospholipid_Mol_Ratio +Cholesterol_Mol_Ratio +PEG_Lipid_Mol_Ratio +Purity_Pure +Purity_Crude +Mix_type_Microfluidic +Mix_type_Pipetting +Cargo_type_mRNA +Cargo_type_pDNA +Cargo_type_siRNA +Target_or_delivered_gene_FFL +Target_or_delivered_gene_Peptide_barcode +Target_or_delivered_gene_hEPO +Target_or_delivered_gene_FVII +Target_or_delivered_gene_GFP +Helper_lipid_ID_DOPE +Helper_lipid_ID_DOTAP +Helper_lipid_ID_DSPC +Helper_lipid_ID_MDOA +Model_type_A549 +Model_type_BDMC +Model_type_BMDM +Model_type_HBEC_ALI +Model_type_HEK293T +Model_type_HeLa +Model_type_IGROV1 +Model_type_Mouse +Model_type_RAW264p7 +Delivery_target_body 
+Delivery_target_dendritic_cell +Delivery_target_generic_cell +Delivery_target_liver +Delivery_target_lung +Delivery_target_lung_epithelium +Delivery_target_macrophage +Delivery_target_muscle +Delivery_target_spleen +Route_of_administration_in_vitro +Route_of_administration_intramuscular +Route_of_administration_intratracheal +Route_of_administration_intravenous +Batch_or_individual_or_barcoded_Barcoded +Batch_or_individual_or_barcoded_Individual +Value_name_log_luminescence +Value_name_luminescence +Value_name_FFL_silencing +Value_name_Peptide_abundance +Value_name_hEPO +Value_name_FVII_silencing +Value_name_GFP_delivery +Value_name_Discretized_luminescence +quantified_delivery \ No newline at end of file diff --git a/data/processed/cv/fold_0/test.parquet b/data/processed/cv/fold_0/test.parquet new file mode 100644 index 0000000..064ece0 Binary files /dev/null and b/data/processed/cv/fold_0/test.parquet differ diff --git a/data/processed/cv/fold_0/train.parquet b/data/processed/cv/fold_0/train.parquet new file mode 100644 index 0000000..f593aea Binary files /dev/null and b/data/processed/cv/fold_0/train.parquet differ diff --git a/data/processed/cv/fold_0/valid.parquet b/data/processed/cv/fold_0/valid.parquet new file mode 100644 index 0000000..dbf73d9 Binary files /dev/null and b/data/processed/cv/fold_0/valid.parquet differ diff --git a/data/processed/cv/fold_1/test.parquet b/data/processed/cv/fold_1/test.parquet new file mode 100644 index 0000000..dbf73d9 Binary files /dev/null and b/data/processed/cv/fold_1/test.parquet differ diff --git a/data/processed/cv/fold_1/train.parquet b/data/processed/cv/fold_1/train.parquet new file mode 100644 index 0000000..3c1b676 Binary files /dev/null and b/data/processed/cv/fold_1/train.parquet differ diff --git a/data/processed/cv/fold_1/valid.parquet b/data/processed/cv/fold_1/valid.parquet new file mode 100644 index 0000000..5572ce3 Binary files /dev/null and b/data/processed/cv/fold_1/valid.parquet differ diff --git 
a/data/processed/cv/fold_2/test.parquet b/data/processed/cv/fold_2/test.parquet new file mode 100644 index 0000000..5572ce3 Binary files /dev/null and b/data/processed/cv/fold_2/test.parquet differ diff --git a/data/processed/cv/fold_2/train.parquet b/data/processed/cv/fold_2/train.parquet new file mode 100644 index 0000000..3679fb2 Binary files /dev/null and b/data/processed/cv/fold_2/train.parquet differ diff --git a/data/processed/cv/fold_2/valid.parquet b/data/processed/cv/fold_2/valid.parquet new file mode 100644 index 0000000..f91efae Binary files /dev/null and b/data/processed/cv/fold_2/valid.parquet differ diff --git a/data/processed/cv/fold_3/test.parquet b/data/processed/cv/fold_3/test.parquet new file mode 100644 index 0000000..f91efae Binary files /dev/null and b/data/processed/cv/fold_3/test.parquet differ diff --git a/data/processed/cv/fold_3/train.parquet b/data/processed/cv/fold_3/train.parquet new file mode 100644 index 0000000..f2834b8 Binary files /dev/null and b/data/processed/cv/fold_3/train.parquet differ diff --git a/data/processed/cv/fold_3/valid.parquet b/data/processed/cv/fold_3/valid.parquet new file mode 100644 index 0000000..d41cc4b Binary files /dev/null and b/data/processed/cv/fold_3/valid.parquet differ diff --git a/data/processed/cv/fold_4/test.parquet b/data/processed/cv/fold_4/test.parquet new file mode 100644 index 0000000..d41cc4b Binary files /dev/null and b/data/processed/cv/fold_4/test.parquet differ diff --git a/data/processed/cv/fold_4/train.parquet b/data/processed/cv/fold_4/train.parquet new file mode 100644 index 0000000..f214720 Binary files /dev/null and b/data/processed/cv/fold_4/train.parquet differ diff --git a/data/processed/cv/fold_4/valid.parquet b/data/processed/cv/fold_4/valid.parquet new file mode 100644 index 0000000..064ece0 Binary files /dev/null and b/data/processed/cv/fold_4/valid.parquet differ diff --git a/lnp_ml/modeling/pretrain_cv.py b/lnp_ml/modeling/pretrain_cv.py new file mode 100644 index 
"""Cross-validation based pretraining script for the LNP delivery task.

Trains one model per CV fold (data prepared by ``make data_cv``) and, via the
``test`` command, evaluates each fold's model on its held-out test split.
"""

import json
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from loguru import logger
from tqdm import tqdm
from sklearn.metrics import mean_squared_error, r2_score
import typer

from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
from lnp_ml.dataset import ExternalDeliveryDataset, collate_fn
# Moved up from mid-file: imports belong in the top-level import block (PEP 8).
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN


# Default location of the pretrained MPNN ensemble checkpoints.
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"


def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
    """Auto-discover the ``model.pt`` checkpoints of an MPNN ensemble.

    Args:
        base_dir: Root directory scanned for ``cv_*/fold_*/model_*/model.pt``.

    Returns:
        Sorted list of checkpoint paths, as strings.

    Raises:
        FileNotFoundError: If no checkpoint matches the expected layout.
    """
    model_paths = sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt"))
    if not model_paths:
        raise FileNotFoundError(f"No model.pt files found in {base_dir}")
    return [str(p) for p in model_paths]


app = typer.Typer()


class EarlyStopping:
    """Early-stopping helper driven by validation loss.

    ``__call__`` returns True once ``patience`` consecutive epochs have
    failed to improve on the best observed loss by more than ``min_delta``.
    """

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0  # epochs since last improvement
        self.best_loss = float("inf")

    def __call__(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True to stop training."""
        if val_loss < self.best_loss - self.min_delta:
            # BUGFIX: the original never updated best_loss here, so every
            # epoch compared against +inf, the counter always reset, and
            # early stopping could never trigger.
            self.best_loss = val_loss
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience


def warmup_cache(model: nn.Module, smiles_list: List[str], batch_size: int = 256) -> None:
    """Pre-warm the model's RDKit feature cache for a list of SMILES.

    Runs the RDKit encoder over all unique SMILES in batches so later
    training epochs hit the cache instead of recomputing descriptors.
    NOTE(review): reads the private ``rdkit_encoder._cache`` attribute;
    assumes the encoder memoizes per-SMILES features — confirm in models.py.
    """
    unique_smiles = list(set(smiles_list))
    logger.info(f"Warming up RDKit cache for {len(unique_smiles)} unique SMILES...")

    for i in tqdm(range(0, len(unique_smiles), batch_size), desc="Cache warmup"):
        batch = unique_smiles[i:i + batch_size]
        model.rdkit_encoder(batch)

    logger.success(f"Cache warmup complete. Cached {len(model.rdkit_encoder._cache)} SMILES.")


def train_epoch_delivery(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int = 0,
) -> Dict[str, float]:
    """Run one training epoch on the delivery task only.

    Loss is masked MSE over samples with a delivery label; gradients are
    clipped to norm 1.0. Returns the sample-weighted average loss and the
    number of labelled samples seen.
    """
    model.train()
    total_loss = 0.0
    n_samples = 0

    pbar = tqdm(loader, desc=f"Epoch {epoch+1} [Train]", leave=False)
    for batch in pbar:
        smiles = batch["smiles"]
        tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
        targets = batch["targets"]["delivery"].to(device)
        mask = batch["mask"]["delivery"].to(device)

        optimizer.zero_grad()

        pred = model.forward_delivery(smiles, tabular).squeeze(-1)

        # Skip batches with no labelled delivery samples.
        if mask.any():
            loss = nn.functional.mse_loss(pred[mask], targets[mask])
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item() * mask.sum().item()
            n_samples += mask.sum().item()

        pbar.set_postfix({"loss": total_loss / max(n_samples, 1)})

    avg_loss = total_loss / max(n_samples, 1)
    return {"loss": avg_loss, "n_samples": n_samples}


@torch.no_grad()
def validate_delivery(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
) -> Dict[str, float]:
    """Evaluate the model on the delivery task.

    Returns masked MSE loss, sample count, and (when any labelled samples
    exist) RMSE and R² over all pooled predictions.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    all_preds = []
    all_targets = []

    for batch in loader:
        smiles = batch["smiles"]
        tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
        targets = batch["targets"]["delivery"].to(device)
        mask = batch["mask"]["delivery"].to(device)

        pred = model.forward_delivery(smiles, tabular).squeeze(-1)

        if mask.any():
            loss = nn.functional.mse_loss(pred[mask], targets[mask])
            total_loss += loss.item() * mask.sum().item()
            n_samples += mask.sum().item()
            all_preds.extend(pred[mask].cpu().numpy().tolist())
            all_targets.extend(targets[mask].cpu().numpy().tolist())

    avg_loss = total_loss / max(n_samples, 1)

    # Extra metrics only make sense with at least one labelled sample.
    metrics = {"loss": avg_loss, "n_samples": n_samples}
    if len(all_preds) > 0:
        all_preds = np.array(all_preds)
        all_targets = np.array(all_targets)
        metrics["rmse"] = float(np.sqrt(mean_squared_error(all_targets, all_preds)))
        metrics["r2"] = float(r2_score(all_targets, all_preds))

    return metrics


def train_fold(
    fold_idx: int,
    train_loader: DataLoader,
    val_loader: DataLoader,
    model: nn.Module,
    device: torch.device,
    output_dir: Path,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    epochs: int = 50,
    patience: int = 10,
    config: Optional[Dict] = None,
) -> Dict:
    """Train one fold's model, saving the best checkpoint and history.

    The checkpoint (best-val-loss weights, config, fold index) is written to
    ``output_dir/fold_{fold_idx}/model.pt`` and the per-epoch history to
    ``history.json`` alongside it. Returns a summary dict for this fold.
    """
    logger.info(f"\n{'='*60}")
    logger.info(f"Training Fold {fold_idx}")
    logger.info(f"{'='*60}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )
    early_stopping = EarlyStopping(patience=patience)

    best_val_loss = float("inf")
    best_state = None
    # BUGFIX: track RMSE/R² at the best-val-loss epoch. The original
    # reported history[-1] (the LAST epoch) as "best", which is wrong
    # whenever training degrades after the best epoch.
    best_val_rmse = 0.0
    best_val_r2 = 0.0
    history = []

    for epoch in range(epochs):
        train_metrics = train_epoch_delivery(model, train_loader, optimizer, device, epoch)
        val_metrics = validate_delivery(model, val_loader, device)

        current_lr = optimizer.param_groups[0]["lr"]
        logger.info(
            f"Fold {fold_idx} Epoch {epoch+1}/{epochs} | "
            f"Train Loss: {train_metrics['loss']:.4f} | "
            f"Val Loss: {val_metrics['loss']:.4f} | "
            f"Val RMSE: {val_metrics.get('rmse', 0):.4f} | "
            f"Val R²: {val_metrics.get('r2', 0):.4f} | "
            f"LR: {current_lr:.2e}"
        )

        history.append({
            "epoch": epoch + 1,
            "train_loss": train_metrics["loss"],
            "val_loss": val_metrics["loss"],
            "val_rmse": val_metrics.get("rmse", 0),
            "val_r2": val_metrics.get("r2", 0),
            "lr": current_lr,
        })

        scheduler.step(val_metrics["loss"])

        if val_metrics["loss"] < best_val_loss:
            best_val_loss = val_metrics["loss"]
            best_val_rmse = val_metrics.get("rmse", 0)
            best_val_r2 = val_metrics.get("r2", 0)
            # Snapshot weights on CPU so the best state survives later epochs.
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            logger.info(f"  -> New best val_loss: {best_val_loss:.4f}")

        if early_stopping(val_metrics["loss"]):
            logger.info(f"Early stopping at epoch {epoch+1}")
            break

    # Guard against epochs == 0 (or no improvement recorded): save current
    # weights rather than a None state dict.
    if best_state is None:
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    # Persist the best checkpoint for this fold.
    fold_output_dir = output_dir / f"fold_{fold_idx}"
    fold_output_dir.mkdir(parents=True, exist_ok=True)

    checkpoint_path = fold_output_dir / "model.pt"
    torch.save({
        "model_state_dict": best_state,
        "config": config,
        "best_val_loss": best_val_loss,
        "fold_idx": fold_idx,
    }, checkpoint_path)
    logger.success(f"Saved fold {fold_idx} model to {checkpoint_path}")

    # Persist per-epoch training history.
    history_path = fold_output_dir / "history.json"
    with open(history_path, "w") as f:
        json.dump(history, f, indent=2)

    return {
        "fold_idx": fold_idx,
        "best_val_loss": best_val_loss,
        "best_val_rmse": best_val_rmse,
        "best_val_r2": best_val_r2,
        "epochs_trained": len(history),
    }


def create_model(
    d_model: int = 256,
    num_heads: int = 8,
    n_attn_layers: int = 4,
    fusion_strategy: str = "attention",
    head_hidden_dim: int = 128,
    dropout: float = 0.1,
    use_mpnn: bool = False,
    mpnn_ensemble_paths: Optional[List[str]] = None,
    mpnn_device: str = "cpu",
) -> nn.Module:
    """Instantiate an LNP model, with or without the MPNN encoder.

    Returns ``LNPModel`` when ``use_mpnn`` is True (wiring in the ensemble
    checkpoints), otherwise ``LNPModelWithoutMPNN`` with the same trunk
    hyperparameters.
    """
    if use_mpnn:
        return LNPModel(
            d_model=d_model,
            num_heads=num_heads,
            n_attn_layers=n_attn_layers,
            fusion_strategy=fusion_strategy,
            head_hidden_dim=head_hidden_dim,
            dropout=dropout,
            mpnn_ensemble_paths=mpnn_ensemble_paths,
            mpnn_device=mpnn_device,
        )
    else:
        return LNPModelWithoutMPNN(
            d_model=d_model,
            num_heads=num_heads,
            n_attn_layers=n_attn_layers,
            fusion_strategy=fusion_strategy,
            head_hidden_dim=head_hidden_dim,
            dropout=dropout,
        )


@app.command()
def main(
    data_dir: Path = PROCESSED_DATA_DIR / "cv",
    output_dir: Path = MODELS_DIR / "pretrain_cv",
    # Model hyperparameters
    d_model: int = 256,
    num_heads: int = 8,
    n_attn_layers: int = 4,
    fusion_strategy: str = "attention",
    head_hidden_dim: int = 128,
    dropout: float = 0.1,
    # MPNN options
    use_mpnn: bool = False,
    mpnn_checkpoint: Optional[str] = None,
    mpnn_ensemble_paths: Optional[str] = None,
    mpnn_device: str = "cpu",
    # Training options
    batch_size: int = 64,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    epochs: int = 50,
    patience: int = 10,
    # Device
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Pretrain the LNP model with 5-fold cross-validation (delivery task only).

    Each fold trains an independently initialized model, saved to
    output_dir/fold_x/model.pt. Use --use-mpnn to enable the MPNN encoder.
    """
    logger.info(f"Using device: {device}")
    device = torch.device(device)

    # Resolve MPNN checkpoint paths: explicit list > single checkpoint > auto-detect.
    mpnn_paths = None
    if use_mpnn:
        if mpnn_ensemble_paths:
            mpnn_paths = mpnn_ensemble_paths.split(",")
            logger.info(f"Using provided MPNN ensemble paths: {len(mpnn_paths)} models")
        elif mpnn_checkpoint:
            mpnn_paths = [mpnn_checkpoint]
            logger.info(f"Using single MPNN checkpoint: {mpnn_checkpoint}")
        else:
            logger.info(f"Auto-detecting MPNN ensemble from {DEFAULT_MPNN_ENSEMBLE_DIR}")
            mpnn_paths = find_mpnn_ensemble_paths()
            logger.info(f"Found {len(mpnn_paths)} MPNN models")

    # Discover fold directories produced by the CV data-processing step.
    fold_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("fold_")])

    if not fold_dirs:
        logger.error(f"No fold_* directories found in {data_dir}")
        logger.info("Please run 'make data_cv' first to process CV data.")
        raise typer.Exit(1)

    logger.info(f"Found {len(fold_dirs)} folds: {[d.name for d in fold_dirs]}")

    output_dir.mkdir(parents=True, exist_ok=True)

    # Shared model/training configuration, saved into every checkpoint.
    config = {
        "d_model": d_model,
        "num_heads": num_heads,
        "n_attn_layers": n_attn_layers,
        "fusion_strategy": fusion_strategy,
        "head_hidden_dim": head_hidden_dim,
        "dropout": dropout,
        "use_mpnn": use_mpnn,
        "mpnn_ensemble_paths": mpnn_paths,
        "lr": lr,
        "weight_decay": weight_decay,
        "batch_size": batch_size,
        "epochs": epochs,
        "patience": patience,
    }

    config_path = output_dir / "config.json"
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)
    logger.info(f"Saved config to {config_path}")

    fold_results = []
    all_smiles = set()

    # Collect all SMILES up front so one cache warmup covers every fold.
    for fold_dir in fold_dirs:
        for split in ["train", "valid"]:
            df = pd.read_parquet(fold_dir / f"{split}.parquet")
            all_smiles.update(df["smiles"].tolist())

    # ROBUSTNESS: warm the cache on the FIRST processed fold, not on
    # fold_idx == 0 — the original skipped warmup entirely if fold_0 was
    # missing or folds were numbered differently.
    warmed_up = False

    for fold_dir in fold_dirs:
        fold_idx = int(fold_dir.name.split("_")[1])

        # Load this fold's train/validation splits.
        train_df = pd.read_parquet(fold_dir / "train.parquet")
        val_df = pd.read_parquet(fold_dir / "valid.parquet")

        logger.info(f"\nFold {fold_idx}: train={len(train_df)}, val={len(val_df)}")

        train_dataset = ExternalDeliveryDataset(train_df)
        val_dataset = ExternalDeliveryDataset(val_df)

        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
        )
        val_loader = DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
        )

        # Fresh model per fold (independent initialization).
        model = create_model(
            d_model=d_model,
            num_heads=num_heads,
            n_attn_layers=n_attn_layers,
            fusion_strategy=fusion_strategy,
            head_hidden_dim=head_hidden_dim,
            dropout=dropout,
            use_mpnn=use_mpnn,
            mpnn_ensemble_paths=mpnn_paths,
            mpnn_device=device.type,
        )
        model = model.to(device)

        if not warmed_up:
            warmup_cache(model, list(all_smiles), batch_size=256)
            warmed_up = True

        logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

        result = train_fold(
            fold_idx=fold_idx,
            train_loader=train_loader,
            val_loader=val_loader,
            model=model,
            device=device,
            output_dir=output_dir,
            lr=lr,
            weight_decay=weight_decay,
            epochs=epochs,
            patience=patience,
            config=config,
        )
        fold_results.append(result)

    # Aggregate cross-fold statistics.
    logger.info("\n" + "=" * 60)
    logger.info("CROSS-VALIDATION TRAINING COMPLETE")
    logger.info("=" * 60)

    val_losses = [r["best_val_loss"] for r in fold_results]
    val_rmses = [r["best_val_rmse"] for r in fold_results]
    val_r2s = [r["best_val_r2"] for r in fold_results]

    logger.info(f"\n[Per-Fold Results]")
    for r in fold_results:
        logger.info(
            f"  Fold {r['fold_idx']}: "
            f"Val Loss={r['best_val_loss']:.4f}, "
            f"RMSE={r['best_val_rmse']:.4f}, "
            f"R²={r['best_val_r2']:.4f}, "
            f"Epochs={r['epochs_trained']}"
        )

    logger.info(f"\n[Summary Statistics]")
    logger.info(f"  Val Loss: {np.mean(val_losses):.4f} ± {np.std(val_losses):.4f}")
    logger.info(f"  Val RMSE: {np.mean(val_rmses):.4f} ± {np.std(val_rmses):.4f}")
    logger.info(f"  Val R²: {np.mean(val_r2s):.4f} ± {np.std(val_r2s):.4f}")

    # Persist CV summary.
    cv_results = {
        "fold_results": fold_results,
        "summary": {
            "val_loss_mean": float(np.mean(val_losses)),
            "val_loss_std": float(np.std(val_losses)),
            "val_rmse_mean": float(np.mean(val_rmses)),
            "val_rmse_std": float(np.std(val_rmses)),
            "val_r2_mean": float(np.mean(val_r2s)),
            "val_r2_std": float(np.std(val_r2s)),
        },
        "config": config,
    }

    results_path = output_dir / "cv_results.json"
    with open(results_path, "w") as f:
        json.dump(cv_results, f, indent=2)
    logger.success(f"Saved CV results to {results_path}")


@app.command()
def test(
    data_dir: Path = PROCESSED_DATA_DIR / "cv",
    model_dir: Path = MODELS_DIR / "pretrain_cv",
    output_path: Path = MODELS_DIR / "pretrain_cv" / "test_results.json",
    batch_size: int = 64,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Evaluate the CV-pretrained models on their test sets.

    Each fold's model is evaluated on that fold's own test split.
    """
    logger.info(f"Using device: {device}")
    device = torch.device(device)

    fold_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("fold_")])

    if not fold_dirs:
        logger.error(f"No fold_* directories found in {data_dir}")
        raise typer.Exit(1)

    logger.info(f"Found {len(fold_dirs)} folds")

    fold_results = []
    all_preds = []
    all_targets = []

    for fold_dir in fold_dirs:
        fold_idx = int(fold_dir.name.split("_")[1])
        model_path = model_dir / f"fold_{fold_idx}" / "model.pt"
        test_path = fold_dir / "test.parquet"

        if not model_path.exists():
            logger.warning(f"Fold {fold_idx}: model not found at {model_path}, skipping")
            continue

        if not test_path.exists():
            logger.warning(f"Fold {fold_idx}: test data not found at {test_path}, skipping")
            continue

        # Rebuild the model from the checkpoint's saved config.
        checkpoint = torch.load(model_path, map_location=device)
        config = checkpoint["config"]

        use_mpnn = config.get("use_mpnn", False)
        mpnn_paths = config.get("mpnn_ensemble_paths")

        if use_mpnn and not mpnn_paths:
            mpnn_paths = find_mpnn_ensemble_paths()

        model = create_model(
            d_model=config["d_model"],
            num_heads=config["num_heads"],
            n_attn_layers=config["n_attn_layers"],
            fusion_strategy=config["fusion_strategy"],
            head_hidden_dim=config["head_hidden_dim"],
            dropout=config["dropout"],
            use_mpnn=use_mpnn,
            mpnn_ensemble_paths=mpnn_paths,
            mpnn_device=device.type,
        )
        model.load_state_dict(checkpoint["model_state_dict"])
        model = model.to(device)
        model.eval()

        test_df = pd.read_parquet(test_path)
        test_dataset = ExternalDeliveryDataset(test_df)
        test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
        )

        # Collect masked predictions/targets for this fold.
        fold_preds = []
        fold_targets = []

        with torch.no_grad():
            for batch in test_loader:
                smiles = batch["smiles"]
                tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
                targets = batch["targets"]["delivery"].to(device)
                mask = batch["mask"]["delivery"].to(device)

                pred = model.forward_delivery(smiles, tabular).squeeze(-1)

                if mask.any():
                    fold_preds.extend(pred[mask].cpu().numpy().tolist())
                    fold_targets.extend(targets[mask].cpu().numpy().tolist())

        # ROBUSTNESS: a fold with no labelled test samples would crash
        # r2_score / corrcoef on empty arrays in the original.
        if not fold_preds:
            logger.warning(f"Fold {fold_idx}: no labelled test samples, skipping")
            continue

        fold_preds = np.array(fold_preds)
        fold_targets = np.array(fold_targets)

        mse = float(mean_squared_error(fold_targets, fold_preds))
        rmse = float(np.sqrt(mse))
        r2 = float(r2_score(fold_targets, fold_preds))
        mae = float(np.mean(np.abs(fold_targets - fold_preds)))
        corr = float(np.corrcoef(fold_targets, fold_preds)[0, 1])

        fold_results.append({
            "fold_idx": fold_idx,
            "n_samples": len(fold_preds),
            "mse": mse,
            "rmse": rmse,
            "mae": mae,
            "r2": r2,
            "correlation": corr,
        })

        all_preds.extend(fold_preds.tolist())
        all_targets.extend(fold_targets.tolist())

        logger.info(
            f"Fold {fold_idx}: n={len(fold_preds)}, "
            f"RMSE={rmse:.4f}, R²={r2:.4f}, MAE={mae:.4f}, Corr={corr:.4f}"
        )

    # ROBUSTNESS: fail clearly if every fold was skipped instead of crashing
    # on empty pooled arrays below.
    if not fold_results:
        logger.error("No folds were evaluated (missing models/data?)")
        raise typer.Exit(1)

    # Pooled metrics over all folds' samples.
    all_preds = np.array(all_preds)
    all_targets = np.array(all_targets)

    overall_mse = float(mean_squared_error(all_targets, all_preds))
    overall_rmse = float(np.sqrt(overall_mse))
    overall_r2 = float(r2_score(all_targets, all_preds))
    overall_mae = float(np.mean(np.abs(all_targets - all_preds)))
    overall_corr = float(np.corrcoef(all_targets, all_preds)[0, 1])

    rmses = [r["rmse"] for r in fold_results]
    r2s = [r["r2"] for r in fold_results]

    logger.info("\n" + "=" * 60)
    logger.info("CV TEST EVALUATION RESULTS")
    logger.info("=" * 60)

    logger.info(f"\n[Summary Statistics (across {len(fold_results)} folds)]")
    logger.info(f"  RMSE: {np.mean(rmses):.4f} ± {np.std(rmses):.4f}")
    logger.info(f"  R²: {np.mean(r2s):.4f} ± {np.std(r2s):.4f}")

    logger.info(f"\n[Overall (all {len(all_preds)} samples pooled)]")
    logger.info(f"  RMSE: {overall_rmse:.4f}")
    logger.info(f"  R²: {overall_r2:.4f}")
    logger.info(f"  MAE: {overall_mae:.4f}")
    logger.info(f"  Correlation: {overall_corr:.4f}")

    results = {
        "fold_results": fold_results,
        "summary_stats": {
            "rmse_mean": float(np.mean(rmses)),
            "rmse_std": float(np.std(rmses)),
            "r2_mean": float(np.mean(r2s)),
            "r2_std": float(np.std(r2s)),
        },
        "overall": {
            "n_samples": len(all_preds),
            "mse": overall_mse,
            "rmse": overall_rmse,
            "mae": overall_mae,
            "r2": overall_r2,
            "correlation": overall_corr,
        },
    }

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(results, f, indent=2)
    logger.success(f"Saved test results to {output_path}")


if __name__ == "__main__":
    app()
"loss_toxic": 0.30692412704229355 - }, - { - "loss": 14.08805501461029, - "loss_size": 10.007415056228638, - "loss_pdi": 0.8400777131319046, - "loss_ee": 0.8545695245265961, - "loss_delivery": 1.2719034925103188, - "loss_biodist": 0.8452144712209702, - "loss_toxic": 0.26887485571205616 - }, - { - "loss": 12.719679474830627, - "loss_size": 8.955224633216858, - "loss_pdi": 0.7929210364818573, - "loss_ee": 0.8508650660514832, - "loss_delivery": 1.0332960579544306, - "loss_biodist": 0.8195833638310432, - "loss_toxic": 0.2677891217172146 - }, - { - "loss": 11.719684600830078, - "loss_size": 8.014833927154541, - "loss_pdi": 0.7591684088110924, - "loss_ee": 0.8181160017848015, - "loss_delivery": 1.0873776115477085, - "loss_biodist": 0.7861169949173927, - "loss_toxic": 0.25407182052731514 - }, - { - "loss": 10.390779733657837, - "loss_size": 6.543184518814087, - "loss_pdi": 0.7359360381960869, - "loss_ee": 0.8175727277994156, - "loss_delivery": 1.3261936996132135, - "loss_biodist": 0.74173803627491, - "loss_toxic": 0.22615496441721916 - }, - { - "loss": 9.279924273490906, - "loss_size": 5.779706597328186, - "loss_pdi": 0.7089737802743912, - "loss_ee": 0.8029346913099289, - "loss_delivery": 1.0412847138941288, - "loss_biodist": 0.7369969710707664, - "loss_toxic": 0.21002776641398668 - }, - { - "loss": 8.56308901309967, - "loss_size": 4.867785036563873, - "loss_pdi": 0.6987971663475037, - "loss_ee": 0.7856745272874832, - "loss_delivery": 1.3176121786236763, - "loss_biodist": 0.694796048104763, - "loss_toxic": 0.19842391926795244 - }, - { - "loss": 7.576077699661255, - "loss_size": 4.09323313832283, - "loss_pdi": 0.6696358025074005, - "loss_ee": 0.7626904547214508, - "loss_delivery": 1.1467845663428307, - "loss_biodist": 0.6884731277823448, - "loss_toxic": 0.2152608297765255 - }, - { - "loss": 6.771063804626465, - "loss_size": 3.47174733877182, - "loss_pdi": 0.6470936611294746, - "loss_ee": 0.7558285966515541, - "loss_delivery": 1.0831638108938932, - "loss_biodist": 
0.6334695816040039, - "loss_toxic": 0.17976099345833063 - }, - { - "loss": 5.920511543750763, - "loss_size": 2.759486570954323, - "loss_pdi": 0.6359409764409065, - "loss_ee": 0.7437839284539223, - "loss_delivery": 0.9403424672782421, - "loss_biodist": 0.6445382535457611, - "loss_toxic": 0.1964192632585764 - }, - { - "loss": 5.423723220825195, - "loss_size": 2.331495225429535, - "loss_pdi": 0.6351363807916641, - "loss_ee": 0.7376566380262375, - "loss_delivery": 0.9581479858607054, - "loss_biodist": 0.5950389876961708, - "loss_toxic": 0.16624799743294716 - }, - { - "loss": 5.267421782016754, - "loss_size": 1.8453679084777832, - "loss_pdi": 0.6130843386054039, - "loss_ee": 0.7249858379364014, - "loss_delivery": 1.3575894925743341, - "loss_biodist": 0.5606166198849678, - "loss_toxic": 0.16577763762325048 - }, - { - "loss": 4.479124903678894, - "loss_size": 1.4466428607702255, - "loss_pdi": 0.5825827047228813, - "loss_ee": 0.7108604460954666, - "loss_delivery": 1.0510947555303574, - "loss_biodist": 0.5334530249238014, - "loss_toxic": 0.15449106134474277 - }, - { - "loss": 4.296796411275864, - "loss_size": 1.391601286828518, - "loss_pdi": 0.5907623171806335, - "loss_ee": 0.7128616869449615, - "loss_delivery": 0.9416903126984835, - "loss_biodist": 0.5074036121368408, - "loss_toxic": 0.15247714705765247 - }, - { - "loss": 4.6071557104587555, - "loss_size": 1.3102504722774029, - "loss_pdi": 0.5903710909187794, - "loss_ee": 0.6987340599298477, - "loss_delivery": 1.3774405717849731, - "loss_biodist": 0.49313804879784584, - "loss_toxic": 0.13722141925245523 - }, - { - "loss": 3.9111519753932953, - "loss_size": 1.058611437678337, - "loss_pdi": 0.5584814138710499, - "loss_ee": 0.6937372237443924, - "loss_delivery": 1.0242730379104614, - "loss_biodist": 0.446026936173439, - "loss_toxic": 0.13002192694693804 - }, - { - "loss": 3.6934784650802612, - "loss_size": 0.9869462251663208, - "loss_pdi": 0.5499657578766346, - "loss_ee": 0.6769986003637314, - "loss_delivery": 
0.932591987773776, - "loss_biodist": 0.434326708316803, - "loss_toxic": 0.11264921864494681 - }, - { - "loss": 3.609632909297943, - "loss_size": 0.9397555366158485, - "loss_pdi": 0.5484890341758728, - "loss_ee": 0.6419399157166481, - "loss_delivery": 0.9404868837445974, - "loss_biodist": 0.4236420728266239, - "loss_toxic": 0.11531943175941706 - }, - { - "loss": 3.476574957370758, - "loss_size": 0.8621037006378174, - "loss_pdi": 0.5306216962635517, - "loss_ee": 0.6655778735876083, - "loss_delivery": 0.9486417435109615, - "loss_biodist": 0.3750348165631294, - "loss_toxic": 0.09459509095177054 - }, - { - "loss": 3.2607316374778748, - "loss_size": 0.764001876115799, - "loss_pdi": 0.5195549540221691, - "loss_ee": 0.6295172236859798, - "loss_delivery": 0.8859440544620156, - "loss_biodist": 0.37055982276797295, - "loss_toxic": 0.09115382377058268 - }, - { - "loss": 3.3650560081005096, - "loss_size": 0.6877520456910133, - "loss_pdi": 0.5163812413811684, - "loss_ee": 0.6349174529314041, - "loss_delivery": 1.0614477675408125, - "loss_biodist": 0.3767742030322552, - "loss_toxic": 0.08778328960761428 - }, - { - "loss": 3.1449066549539566, - "loss_size": 0.7033557221293449, - "loss_pdi": 0.5112948007881641, - "loss_ee": 0.6182029843330383, - "loss_delivery": 0.8769390396773815, - "loss_biodist": 0.35101368278265, - "loss_toxic": 0.08410049765370786 - }, - { - "loss": 3.198059171438217, - "loss_size": 0.738056268543005, - "loss_pdi": 0.5025036409497261, - "loss_ee": 0.6464762836694717, - "loss_delivery": 0.8959850911051035, - "loss_biodist": 0.336670720949769, - "loss_toxic": 0.07836716645397246 - }, - { - "loss": 3.0968509018421173, - "loss_size": 0.6075208149850368, - "loss_pdi": 0.49481675028800964, - "loss_ee": 0.6283772960305214, - "loss_delivery": 0.9538011457771063, - "loss_biodist": 0.33442937210202217, - "loss_toxic": 0.0779055068269372 - }, - { - "loss": 3.0471641421318054, - "loss_size": 0.5784552656114101, - "loss_pdi": 0.5156531557440758, - "loss_ee": 
0.6174600012600422, - "loss_delivery": 0.9347921572625637, - "loss_biodist": 0.31574787199497223, - "loss_toxic": 0.0850557005032897 - }, - { - "loss": 3.074565827846527, - "loss_size": 0.6493570990860462, - "loss_pdi": 0.5111677534878254, - "loss_ee": 0.6072432287037373, - "loss_delivery": 0.9208926521241665, - "loss_biodist": 0.30962972342967987, - "loss_toxic": 0.07627540593966842 - }, - { - "loss": 3.1479561775922775, - "loss_size": 0.7691093385219574, - "loss_pdi": 0.5043143332004547, - "loss_ee": 0.6039227098226547, - "loss_delivery": 0.9247475061565638, - "loss_biodist": 0.2875036410987377, - "loss_toxic": 0.05835856357589364 - }, - { - "loss": 3.2086785435676575, - "loss_size": 0.6118294671177864, - "loss_pdi": 0.5001381486654282, - "loss_ee": 0.6337357684969902, - "loss_delivery": 1.0925203282386065, - "loss_biodist": 0.30640872195363045, - "loss_toxic": 0.06404614564962685 - }, - { - "loss": 2.903267800807953, - "loss_size": 0.6084516011178493, - "loss_pdi": 0.5031622871756554, - "loss_ee": 0.6048430800437927, - "loss_delivery": 0.8152581034228206, - "loss_biodist": 0.3017992302775383, - "loss_toxic": 0.06975363730452955 - }, - { - "loss": 2.7227470725774765, - "loss_size": 0.47578552551567554, - "loss_pdi": 0.48987653106451035, - "loss_ee": 0.5831519588828087, - "loss_delivery": 0.8346315119415522, - "loss_biodist": 0.2752541806548834, - "loss_toxic": 0.06404732668306679 - }, - { - "loss": 2.9256694465875626, - "loss_size": 0.5421494208276272, - "loss_pdi": 0.4676680266857147, - "loss_ee": 0.5828219950199127, - "loss_delivery": 1.007380524650216, - "loss_biodist": 0.25513165071606636, - "loss_toxic": 0.0705178261268884 - }, - { - "loss": 2.776222825050354, - "loss_size": 0.560955211520195, - "loss_pdi": 0.5120605081319809, - "loss_ee": 0.5816002264618874, - "loss_delivery": 0.7777703925967216, - "loss_biodist": 0.2862227726727724, - "loss_toxic": 0.05761362751945853 - }, - { - "loss": 2.903283640742302, - "loss_size": 0.598876278847456, - "loss_pdi": 
0.48118938133120537, - "loss_ee": 0.6039806753396988, - "loss_delivery": 0.8762990292161703, - "loss_biodist": 0.26536008156836033, - "loss_toxic": 0.07757818652316928 - }, - { - "loss": 2.8232152611017227, - "loss_size": 0.4981044437736273, - "loss_pdi": 0.5083095543086529, - "loss_ee": 0.5776484534144402, - "loss_delivery": 0.9345780871808529, - "loss_biodist": 0.2536482270807028, - "loss_toxic": 0.0509264167631045 - }, - { - "loss": 2.719878375530243, - "loss_size": 0.5266173202544451, - "loss_pdi": 0.4750731997191906, - "loss_ee": 0.5943448096513748, - "loss_delivery": 0.8238696120679379, - "loss_biodist": 0.24152903258800507, - "loss_toxic": 0.05844438471831381 - }, - { - "loss": 2.669362887740135, - "loss_size": 0.41083680652081966, - "loss_pdi": 0.47290581837296486, - "loss_ee": 0.5690933167934418, - "loss_delivery": 0.8954417379572988, - "loss_biodist": 0.26001916266977787, - "loss_toxic": 0.061066087102517486 - }, - { - "loss": 2.7296886146068573, - "loss_size": 0.40873375721275806, - "loss_pdi": 0.5100873447954655, - "loss_ee": 0.5963915809988976, - "loss_delivery": 0.8862678501754999, - "loss_biodist": 0.2582801654934883, - "loss_toxic": 0.06992789637297392 - }, - { - "loss": 2.6801713705062866, - "loss_size": 0.4124011751264334, - "loss_pdi": 0.4926849640905857, - "loss_ee": 0.5796162374317646, - "loss_delivery": 0.8792719375342131, - "loss_biodist": 0.2574870977550745, - "loss_toxic": 0.05870996415615082 - }, - { - "loss": 2.5923274010419846, - "loss_size": 0.43998449202626944, - "loss_pdi": 0.4787449426949024, - "loss_ee": 0.5642235241830349, - "loss_delivery": 0.807167736813426, - "loss_biodist": 0.24508678168058395, - "loss_toxic": 0.05711989430710673 - }, - { - "loss": 2.706703007221222, - "loss_size": 0.48683931678533554, - "loss_pdi": 0.46516377106308937, - "loss_ee": 0.5705448277294636, - "loss_delivery": 0.8720798492431641, - "loss_biodist": 0.23792196623981, - "loss_toxic": 0.07415329676587135 - }, - { - "loss": 2.5786157697439194, - 
"loss_size": 0.4590853825211525, - "loss_pdi": 0.4839537553489208, - "loss_ee": 0.5635706633329391, - "loss_delivery": 0.7978988699615002, - "loss_biodist": 0.23025565408170223, - "loss_toxic": 0.04385152133181691 - }, - { - "loss": 2.6873574256896973, - "loss_size": 0.4268788732588291, - "loss_pdi": 0.4642701633274555, - "loss_ee": 0.5801760032773018, - "loss_delivery": 0.913929826579988, - "loss_biodist": 0.23901514150202274, - "loss_toxic": 0.06308741425164044 - }, - { - "loss": 2.685258597135544, - "loss_size": 0.5264910068362951, - "loss_pdi": 0.4676165319979191, - "loss_ee": 0.5818453542888165, - "loss_delivery": 0.8150268085300922, - "loss_biodist": 0.22798488661646843, - "loss_toxic": 0.06629409588640556 - }, - { - "loss": 2.52669258415699, - "loss_size": 0.41255798749625683, - "loss_pdi": 0.4563356712460518, - "loss_ee": 0.5704724453389645, - "loss_delivery": 0.8344292026013136, - "loss_biodist": 0.20569896139204502, - "loss_toxic": 0.04719832225237042 - }, - { - "loss": 2.5535005182027817, - "loss_size": 0.381014097481966, - "loss_pdi": 0.4735785350203514, - "loss_ee": 0.5684775337576866, - "loss_delivery": 0.8227455485612154, - "loss_biodist": 0.23827529326081276, - "loss_toxic": 0.06940951908472925 - }, - { - "loss": 2.5589767545461655, - "loss_size": 0.44938809610903263, - "loss_pdi": 0.45997676625847816, - "loss_ee": 0.5761362612247467, - "loss_delivery": 0.80368347838521, - "loss_biodist": 0.21627510525286198, - "loss_toxic": 0.05351710086688399 - }, - { - "loss": 2.4209747910499573, - "loss_size": 0.4209021870046854, - "loss_pdi": 0.4443623758852482, - "loss_ee": 0.5494826473295689, - "loss_delivery": 0.726479004137218, - "loss_biodist": 0.21563152223825455, - "loss_toxic": 0.06411702185869217 - }, - { - "loss": 2.402697578072548, - "loss_size": 0.3580847531557083, - "loss_pdi": 0.4626186229288578, - "loss_ee": 0.5826950781047344, - "loss_delivery": 0.7483663521707058, - "loss_biodist": 0.20891179516911507, - "loss_toxic": 0.04202097177039832 - }, - 
{ - "loss": 2.402700573205948, - "loss_size": 0.42229970917105675, - "loss_pdi": 0.4451366700232029, - "loss_ee": 0.566058874130249, - "loss_delivery": 0.7094440292567015, - "loss_biodist": 0.21921667829155922, - "loss_toxic": 0.04054457793245092 - }, - { - "loss": 2.447406530380249, - "loss_size": 0.39806674513965845, - "loss_pdi": 0.44775110855698586, - "loss_ee": 0.5492517799139023, - "loss_delivery": 0.7928667366504669, - "loss_biodist": 0.21111027523875237, - "loss_toxic": 0.04835996555630118 - }, - { - "loss": 2.4445116221904755, - "loss_size": 0.3879490252584219, - "loss_pdi": 0.4956270530819893, - "loss_ee": 0.5405614376068115, - "loss_delivery": 0.7789672082290053, - "loss_biodist": 0.20555796474218369, - "loss_toxic": 0.03584893117658794 - }, - { - "loss": 2.3815398663282394, - "loss_size": 0.40303290262818336, - "loss_pdi": 0.44058041274547577, - "loss_ee": 0.5845668762922287, - "loss_delivery": 0.6737456526607275, - "loss_biodist": 0.2069079726934433, - "loss_toxic": 0.07270607736427337 - }, - { - "loss": 2.4094582945108414, - "loss_size": 0.3885020185261965, - "loss_pdi": 0.46220706030726433, - "loss_ee": 0.5269934982061386, - "loss_delivery": 0.7990537015721202, - "loss_biodist": 0.19428402185440063, - "loss_toxic": 0.03841806924901903 - }, - { - "loss": 2.4859633445739746, - "loss_size": 0.41715021431446075, - "loss_pdi": 0.43104546144604683, - "loss_ee": 0.5929308645427227, - "loss_delivery": 0.7805162407457829, - "loss_biodist": 0.21075740829110146, - "loss_toxic": 0.05356312752701342 - }, - { - "loss": 2.3863512128591537, - "loss_size": 0.36496103554964066, - "loss_pdi": 0.4448041245341301, - "loss_ee": 0.5620445422828197, - "loss_delivery": 0.7598662171512842, - "loss_biodist": 0.21108629740774632, - "loss_toxic": 0.04358899069484323 - }, - { - "loss": 2.540374130010605, - "loss_size": 0.421149879693985, - "loss_pdi": 0.4305218234658241, - "loss_ee": 0.5474253967404366, - "loss_delivery": 0.9024681374430656, - "loss_biodist": 0.2017455119639635, 
- "loss_toxic": 0.037063447292894125 - }, - { - "loss": 2.256133273243904, - "loss_size": 0.35868640802800655, - "loss_pdi": 0.4526713415980339, - "loss_ee": 0.5558537244796753, - "loss_delivery": 0.6482365503907204, - "loss_biodist": 0.19766092114150524, - "loss_toxic": 0.04302433942211792 - }, - { - "loss": 2.4118791967630386, - "loss_size": 0.3626519478857517, - "loss_pdi": 0.439793910831213, - "loss_ee": 0.5639018379151821, - "loss_delivery": 0.7862588986754417, - "loss_biodist": 0.2167180608958006, - "loss_toxic": 0.042554557556286454 - }, - { - "loss": 2.421144977211952, - "loss_size": 0.3997096996754408, - "loss_pdi": 0.4268031716346741, - "loss_ee": 0.54468197748065, - "loss_delivery": 0.7987409122288227, - "loss_biodist": 0.1979309469461441, - "loss_toxic": 0.053278335835784674 - }, - { - "loss": 2.484893947839737, - "loss_size": 0.40917484275996685, - "loss_pdi": 0.4523083493113518, - "loss_ee": 0.5565394945442677, - "loss_delivery": 0.7811942044645548, - "loss_biodist": 0.21532400511205196, - "loss_toxic": 0.07035311090294272 - }, - { - "loss": 2.508310005068779, - "loss_size": 0.392028434202075, - "loss_pdi": 0.46563537418842316, - "loss_ee": 0.5610726661980152, - "loss_delivery": 0.868115178309381, - "loss_biodist": 0.19330977648496628, - "loss_toxic": 0.028148552868515253 - }, - { - "loss": 2.2823522984981537, - "loss_size": 0.3725826870650053, - "loss_pdi": 0.4436277002096176, - "loss_ee": 0.5445637330412865, - "loss_delivery": 0.6825304059311748, - "loss_biodist": 0.20507806539535522, - "loss_toxic": 0.03396970289759338 - }, - { - "loss": 2.340677961707115, - "loss_size": 0.3237522132694721, - "loss_pdi": 0.4429183080792427, - "loss_ee": 0.5430723577737808, - "loss_delivery": 0.7789825117215514, - "loss_biodist": 0.1932805050164461, - "loss_toxic": 0.05867206456605345 - }, - { - "loss": 2.43579663336277, - "loss_size": 0.3825421128422022, - "loss_pdi": 0.46084684878587723, - "loss_ee": 0.5356369465589523, - "loss_delivery": 0.8061059219762683, - 
"loss_biodist": 0.18804593943059444, - "loss_toxic": 0.06261878390796483 - }, - { - "loss": 2.375033915042877, - "loss_size": 0.37097565829753876, - "loss_pdi": 0.45299821346998215, - "loss_ee": 0.5483061745762825, - "loss_delivery": 0.7634187545627356, - "loss_biodist": 0.19628103263676167, - "loss_toxic": 0.043054113164544106 - }, - { - "loss": 2.375073567032814, - "loss_size": 0.39385137148201466, - "loss_pdi": 0.43302949890494347, - "loss_ee": 0.5369623377919197, - "loss_delivery": 0.7776657920330763, - "loss_biodist": 0.1930943038314581, - "loss_toxic": 0.04047024482861161 - }, - { - "loss": 2.498472973704338, - "loss_size": 0.457756033167243, - "loss_pdi": 0.48180752620100975, - "loss_ee": 0.5444387346506119, - "loss_delivery": 0.7720571514219046, - "loss_biodist": 0.21357632335275412, - "loss_toxic": 0.028837116580689326 - }, - { - "loss": 2.251427784562111, - "loss_size": 0.31997708044946194, - "loss_pdi": 0.4557594656944275, - "loss_ee": 0.5722625702619553, - "loss_delivery": 0.6752464389428496, - "loss_biodist": 0.1944030299782753, - "loss_toxic": 0.033779169199988246 - }, - { - "loss": 2.3404635787010193, - "loss_size": 0.39547222293913364, - "loss_pdi": 0.48117559030652046, - "loss_ee": 0.5512082874774933, - "loss_delivery": 0.6834059152752161, - "loss_biodist": 0.18757404759526253, - "loss_toxic": 0.04162753582932055 - }, - { - "loss": 2.332062780857086, - "loss_size": 0.40330377221107483, - "loss_pdi": 0.45610927417874336, - "loss_ee": 0.5598709620535374, - "loss_delivery": 0.6991077661514282, - "loss_biodist": 0.1802004612982273, - "loss_toxic": 0.03347053553443402 - }, - { - "loss": 2.5532881915569305, - "loss_size": 0.3296027425676584, - "loss_pdi": 0.4313250593841076, - "loss_ee": 0.5530117489397526, - "loss_delivery": 1.0097075831145048, - "loss_biodist": 0.19308063574135303, - "loss_toxic": 0.036560436128638685 - }, - { - "loss": 2.3191350996494293, - "loss_size": 0.36579276248812675, - "loss_pdi": 0.4509471468627453, - "loss_ee": 
0.5503131933510303, - "loss_delivery": 0.7146945651620626, - "loss_biodist": 0.1930693220347166, - "loss_toxic": 0.044318083906546235 - }, - { - "loss": 2.329015627503395, - "loss_size": 0.3625293876975775, - "loss_pdi": 0.447160255163908, - "loss_ee": 0.5637771524488926, - "loss_delivery": 0.7219901196658611, - "loss_biodist": 0.182708989828825, - "loss_toxic": 0.05084967764560133 - }, - { - "loss": 2.2387402057647705, - "loss_size": 0.37715523317456245, - "loss_pdi": 0.443161316215992, - "loss_ee": 0.532851655036211, - "loss_delivery": 0.637946292757988, - "loss_biodist": 0.19288873113691807, - "loss_toxic": 0.05473693599924445 - }, - { - "loss": 2.3392106890678406, - "loss_size": 0.37016443349421024, - "loss_pdi": 0.43577099218964577, - "loss_ee": 0.5809498429298401, - "loss_delivery": 0.7330322470515966, - "loss_biodist": 0.1861476842314005, - "loss_toxic": 0.03314550075447187 - }, - { - "loss": 2.381509318947792, - "loss_size": 0.45705906488001347, - "loss_pdi": 0.4459034390747547, - "loss_ee": 0.5371449440717697, - "loss_delivery": 0.7081572283059359, - "loss_biodist": 0.1915800031274557, - "loss_toxic": 0.04166470328345895 - }, - { - "loss": 2.2895610630512238, - "loss_size": 0.30787856690585613, - "loss_pdi": 0.43192387744784355, - "loss_ee": 0.5282911062240601, - "loss_delivery": 0.8125729402527213, - "loss_biodist": 0.17857834417372942, - "loss_toxic": 0.03031624120194465 - }, - { - "loss": 2.342302069067955, - "loss_size": 0.35507766902446747, - "loss_pdi": 0.4288671351969242, - "loss_ee": 0.5287903435528278, - "loss_delivery": 0.799028629437089, - "loss_biodist": 0.19781222939491272, - "loss_toxic": 0.03272609505802393 - }, - { - "loss": 2.2220213413238525, - "loss_size": 0.3336242912337184, - "loss_pdi": 0.4279674366116524, - "loss_ee": 0.5243929699063301, - "loss_delivery": 0.7046115137636662, - "loss_biodist": 0.18190770782530308, - "loss_toxic": 0.04951742372941226 - }, - { - "loss": 2.2642442733049393, - "loss_size": 0.3745464440435171, - 
"loss_pdi": 0.4190495200455189, - "loss_ee": 0.5312219671905041, - "loss_delivery": 0.7121039722114801, - "loss_biodist": 0.18861200101673603, - "loss_toxic": 0.03871038107899949 - }, - { - "loss": 2.2668907940387726, - "loss_size": 0.3055088045075536, - "loss_pdi": 0.44202207773923874, - "loss_ee": 0.558069571852684, - "loss_delivery": 0.753497414290905, - "loss_biodist": 0.18022325448691845, - "loss_toxic": 0.027569634490646422 - }, - { - "loss": 2.265577092766762, - "loss_size": 0.3527289964258671, - "loss_pdi": 0.4210794195532799, - "loss_ee": 0.5633058845996857, - "loss_delivery": 0.7165728658437729, - "loss_biodist": 0.18325274251401424, - "loss_toxic": 0.028637277253437787 - }, - { - "loss": 2.3678945302963257, - "loss_size": 0.3269524369388819, - "loss_pdi": 0.4376198649406433, - "loss_ee": 0.5548702776432037, - "loss_delivery": 0.8310786969959736, - "loss_biodist": 0.1766184400767088, - "loss_toxic": 0.040754887741059065 - }, - { - "loss": 2.376565784215927, - "loss_size": 0.3146391473710537, - "loss_pdi": 0.4377659671008587, - "loss_ee": 0.5472971461713314, - "loss_delivery": 0.827107597142458, - "loss_biodist": 0.193230664357543, - "loss_toxic": 0.05652518745046109 - }, - { - "loss": 2.3293388187885284, - "loss_size": 0.313828706741333, - "loss_pdi": 0.4662150889635086, - "loss_ee": 0.5857372991740704, - "loss_delivery": 0.7384721748530865, - "loss_biodist": 0.19804997369647026, - "loss_toxic": 0.02703562140231952 - }, - { - "loss": 2.2240894734859467, - "loss_size": 0.32816210202872753, - "loss_pdi": 0.4228389263153076, - "loss_ee": 0.5168202854692936, - "loss_delivery": 0.7309106634929776, - "loss_biodist": 0.17415916360914707, - "loss_toxic": 0.05119828786700964 - }, - { - "loss": 2.3003778904676437, - "loss_size": 0.3464476577937603, - "loss_pdi": 0.4299623481929302, - "loss_ee": 0.5327660068869591, - "loss_delivery": 0.7682454977184534, - "loss_biodist": 0.17687828838825226, - "loss_toxic": 0.04607809009030461 - }, - { - "loss": 2.4176031351089478, 
- "loss_size": 0.33616673201322556, - "loss_pdi": 0.4266280457377434, - "loss_ee": 0.539965070784092, - "loss_delivery": 0.8959930576384068, - "loss_biodist": 0.1849408494308591, - "loss_toxic": 0.033909388119354844 - }, - { - "loss": 2.220036804676056, - "loss_size": 0.3079969398677349, - "loss_pdi": 0.4420403353869915, - "loss_ee": 0.5316371433436871, - "loss_delivery": 0.7277526371181011, - "loss_biodist": 0.18354396149516106, - "loss_toxic": 0.027065691770985723 - }, - { - "loss": 2.568151220679283, - "loss_size": 0.3081513475626707, - "loss_pdi": 0.4065406657755375, - "loss_ee": 0.5534945167601109, - "loss_delivery": 1.0929491445422173, - "loss_biodist": 0.17446389142423868, - "loss_toxic": 0.032551744021475315 - }, - { - "loss": 2.2266090363264084, - "loss_size": 0.30730851739645004, - "loss_pdi": 0.4372030571103096, - "loss_ee": 0.5150774084031582, - "loss_delivery": 0.7617569454014301, - "loss_biodist": 0.1675817985087633, - "loss_toxic": 0.03768134908750653 + "loss": 20.654017567634583, + "loss_size": 14.923280715942383, + "loss_pdi": 1.3630856722593307, + "loss_ee": 1.0538489520549774, + "loss_delivery": 1.3152384273707867, + "loss_biodist": 1.2500686794519424, + "loss_toxic": 0.7484960630536079 + }, + { + "loss": 8.597736835479736, + "loss_size": 3.7376724034547806, + "loss_pdi": 1.1567248702049255, + "loss_ee": 0.959331214427948, + "loss_delivery": 1.0784161668270826, + "loss_biodist": 1.1304775178432465, + "loss_toxic": 0.5351144559681416 + }, + { + "loss": 5.316226750612259, + "loss_size": 0.4655807036906481, + "loss_pdi": 0.890813983976841, + "loss_ee": 0.9265956059098244, + "loss_delivery": 1.6688463129103184, + "loss_biodist": 1.0243803411722183, + "loss_toxic": 0.3400097191333771 + }, + { + "loss": 4.236876010894775, + "loss_size": 0.2576048579066992, + "loss_pdi": 0.7199259623885155, + "loss_ee": 0.8882468864321709, + "loss_delivery": 1.1865669898688793, + "loss_biodist": 0.9087597280740738, + "loss_toxic": 0.2757715005427599 + }, + { + "loss": 
4.008229672908783, + "loss_size": 0.2828900143504143, + "loss_pdi": 0.6668110117316246, + "loss_ee": 0.8487689569592476, + "loss_delivery": 1.210050592198968, + "loss_biodist": 0.7741772681474686, + "loss_toxic": 0.22553179040551186 + }, + { + "loss": 3.4132821559906006, + "loss_size": 0.28209344763308764, + "loss_pdi": 0.6196199581027031, + "loss_ee": 0.7953124567866325, + "loss_delivery": 0.886458033695817, + "loss_biodist": 0.6592964082956314, + "loss_toxic": 0.17050191573798656 + }, + { + "loss": 3.1554177701473236, + "loss_size": 0.2222631135955453, + "loss_pdi": 0.5906740687787533, + "loss_ee": 0.7774071097373962, + "loss_delivery": 0.8972328305244446, + "loss_biodist": 0.5283852256834507, + "loss_toxic": 0.13945529703050852 + }, + { + "loss": 3.008445233106613, + "loss_size": 0.2672971077263355, + "loss_pdi": 0.5615717135369778, + "loss_ee": 0.7414998263120651, + "loss_delivery": 0.8624669294804335, + "loss_biodist": 0.4568806663155556, + "loss_toxic": 0.11872897483408451 + }, + { + "loss": 2.7489685714244843, + "loss_size": 0.2941299509257078, + "loss_pdi": 0.5309903435409069, + "loss_ee": 0.6796365603804588, + "loss_delivery": 0.7240526992827654, + "loss_biodist": 0.4194994159042835, + "loss_toxic": 0.10065969126299024 + }, + { + "loss": 2.994155704975128, + "loss_size": 0.23459727689623833, + "loss_pdi": 0.5059699863195419, + "loss_ee": 0.6792710162699223, + "loss_delivery": 1.123040821403265, + "loss_biodist": 0.35866043344140053, + "loss_toxic": 0.09261624282225966 + }, + { + "loss": 2.526204437017441, + "loss_size": 0.2318584816530347, + "loss_pdi": 0.5105148889124393, + "loss_ee": 0.6241044960916042, + "loss_delivery": 0.7530241832137108, + "loss_biodist": 0.3161732591688633, + "loss_toxic": 0.09052912541665137 + }, + { + "loss": 2.4040733128786087, + "loss_size": 0.25336731784045696, + "loss_pdi": 0.4627600871026516, + "loss_ee": 0.5792314857244492, + "loss_delivery": 0.7273973049595952, + "loss_biodist": 0.2851091828197241, + "loss_toxic": 
0.09620802523568273 + }, + { + "loss": 2.277775838971138, + "loss_size": 0.23108398634940386, + "loss_pdi": 0.47473394870758057, + "loss_ee": 0.5694540254771709, + "loss_delivery": 0.6473548840731382, + "loss_biodist": 0.2682333216071129, + "loss_toxic": 0.08691561967134476 + }, + { + "loss": 2.1118388026952744, + "loss_size": 0.19272647704929113, + "loss_pdi": 0.4506654106080532, + "loss_ee": 0.5451454631984234, + "loss_delivery": 0.6076795756816864, + "loss_biodist": 0.24520794302225113, + "loss_toxic": 0.07041396829299629 + }, + { + "loss": 2.3364796936511993, + "loss_size": 0.202391910366714, + "loss_pdi": 0.4175494909286499, + "loss_ee": 0.5630357973277569, + "loss_delivery": 0.8727079201489687, + "loss_biodist": 0.2182436492294073, + "loss_toxic": 0.0625509019009769 + }, + { + "loss": 2.121677279472351, + "loss_size": 0.19022844545543194, + "loss_pdi": 0.4154125601053238, + "loss_ee": 0.5368015170097351, + "loss_delivery": 0.6966933384537697, + "loss_biodist": 0.21289130486547947, + "loss_toxic": 0.06965014082379639 + }, + { + "loss": 2.0318090319633484, + "loss_size": 0.21694329474121332, + "loss_pdi": 0.4048071689903736, + "loss_ee": 0.515247642993927, + "loss_delivery": 0.6368602626025677, + "loss_biodist": 0.20683623664081097, + "loss_toxic": 0.05111439700704068 + }, + { + "loss": 1.9118669629096985, + "loss_size": 0.19875546218827367, + "loss_pdi": 0.3908081129193306, + "loss_ee": 0.5096109956502914, + "loss_delivery": 0.5701626418158412, + "loss_biodist": 0.1929602213203907, + "loss_toxic": 0.049569567665457726 + }, + { + "loss": 1.839075192809105, + "loss_size": 0.18530041445046663, + "loss_pdi": 0.3970159702003002, + "loss_ee": 0.46845678985118866, + "loss_delivery": 0.5625405590981245, + "loss_biodist": 0.1886431071907282, + "loss_toxic": 0.037118358886800706 + }, + { + "loss": 1.8664235323667526, + "loss_size": 0.18123449478298426, + "loss_pdi": 0.40222033485770226, + "loss_ee": 0.4552966132760048, + "loss_delivery": 0.6082880217581987, + 
"loss_biodist": 0.18262280710041523, + "loss_toxic": 0.036761312861926854 + }, + { + "loss": 1.716733992099762, + "loss_size": 0.18156101368367672, + "loss_pdi": 0.3858310990035534, + "loss_ee": 0.4562831334769726, + "loss_delivery": 0.49082688614726067, + "loss_biodist": 0.1689397692680359, + "loss_toxic": 0.033292022766545415 + }, + { + "loss": 1.6656295657157898, + "loss_size": 0.15507809165865183, + "loss_pdi": 0.38257285952568054, + "loss_ee": 0.43476733937859535, + "loss_delivery": 0.4871903257444501, + "loss_biodist": 0.1687515852972865, + "loss_toxic": 0.0372693199897185 + }, + { + "loss": 1.655124008655548, + "loss_size": 0.16677627619355917, + "loss_pdi": 0.39169614762067795, + "loss_ee": 0.4506494253873825, + "loss_delivery": 0.4437606744468212, + "loss_biodist": 0.17045665439218283, + "loss_toxic": 0.03178481827490032 + }, + { + "loss": 1.6233856305480003, + "loss_size": 0.1664448268711567, + "loss_pdi": 0.38898324593901634, + "loss_ee": 0.42385032027959824, + "loss_delivery": 0.4545885343104601, + "loss_biodist": 0.1660583931952715, + "loss_toxic": 0.02346029842738062 + }, + { + "loss": 1.5908179432153702, + "loss_size": 0.17132249101996422, + "loss_pdi": 0.34435465931892395, + "loss_ee": 0.4377247728407383, + "loss_delivery": 0.4506380669772625, + "loss_biodist": 0.14942463673651218, + "loss_toxic": 0.03735327522736043 + }, + { + "loss": 1.6518655568361282, + "loss_size": 0.23489851504564285, + "loss_pdi": 0.36253464221954346, + "loss_ee": 0.4624558910727501, + "loss_delivery": 0.40858984366059303, + "loss_biodist": 0.1552291251718998, + "loss_toxic": 0.02815757622011006 + }, + { + "loss": 1.5824971497058868, + "loss_size": 0.17537706904113293, + "loss_pdi": 0.33741601184010506, + "loss_ee": 0.4335809648036957, + "loss_delivery": 0.4561629304662347, + "loss_biodist": 0.15684568136930466, + "loss_toxic": 0.0231145127909258 + }, + { + "loss": 1.4595527052879333, + "loss_size": 0.13256958965212107, + "loss_pdi": 0.33822081610560417, + "loss_ee": 
0.4079545773565769, + "loss_delivery": 0.40720680449157953, + "loss_biodist": 0.1479589343070984, + "loss_toxic": 0.025641972781158984 + }, + { + "loss": 1.5028350502252579, + "loss_size": 0.1535406243056059, + "loss_pdi": 0.3422395624220371, + "loss_ee": 0.45890774950385094, + "loss_delivery": 0.3780768755823374, + "loss_biodist": 0.15004045329988003, + "loss_toxic": 0.020029835402965546 } ], "val": [ { - "loss": 23.08255672454834, - "loss_size": 18.165200233459473, - "loss_pdi": 1.38570237159729, - "loss_ee": 1.0819937586784363, - "loss_delivery": 0.5172057598829269, - "loss_biodist": 1.2911686301231384, - "loss_toxic": 0.6412850320339203, - "acc_pdi": 0.15, - "acc_ee": 0.4, - "acc_toxic": 0.8717948717948718 + "loss": 10.581983089447021, + "loss_size": 6.111684322357178, + "loss_pdi": 1.2079024910926819, + "loss_ee": 1.0208574831485748, + "loss_delivery": 0.4347192794084549, + "loss_biodist": 1.244754672050476, + "loss_toxic": 0.5620639622211456, + "acc_pdi": 0.6166666666666667, + "acc_ee": 0.5833333333333334, + "acc_toxic": 0.9743589743589743 }, { - "loss": 21.95634937286377, - "loss_size": 17.241751670837402, - "loss_pdi": 1.321295142173767, - "loss_ee": 1.065436601638794, - "loss_delivery": 0.49157825112342834, - "loss_biodist": 1.2544752359390259, - "loss_toxic": 0.5818120241165161, - "acc_pdi": 0.4, - "acc_ee": 0.5, - "acc_toxic": 0.9487179487179487 + "loss": 4.2768449783325195, + "loss_size": 0.47457408905029297, + "loss_pdi": 0.9688999950885773, + "loss_ee": 0.9677374958992004, + "loss_delivery": 0.41415125131607056, + "loss_biodist": 1.1090667247772217, + "loss_toxic": 0.342415452003479, + "acc_pdi": 0.75, + "acc_ee": 0.5833333333333334, + "acc_toxic": 0.9743589743589743 }, { - "loss": 20.918076515197754, - "loss_size": 16.435887336730957, - "loss_pdi": 1.2529736161231995, - "loss_ee": 1.0552538633346558, - "loss_delivery": 0.45487193763256073, - "loss_biodist": 1.2014788389205933, - "loss_toxic": 0.5176102221012115, + "loss": 3.4424057006835938, + 
"loss_size": 0.1261746659874916, + "loss_pdi": 0.7631869912147522, + "loss_ee": 0.9518796503543854, + "loss_delivery": 0.41135601699352264, + "loss_biodist": 1.00320166349411, + "loss_toxic": 0.18660662323236465, + "acc_pdi": 0.75, + "acc_ee": 0.5833333333333334, + "acc_toxic": 0.9743589743589743 + }, + { + "loss": 3.1414873600006104, + "loss_size": 0.07491225376725197, + "loss_pdi": 0.6709764003753662, + "loss_ee": 0.9438966512680054, + "loss_delivery": 0.41013580560684204, + "loss_biodist": 0.9091836810112, + "loss_toxic": 0.13238248974084854, + "acc_pdi": 0.75, + "acc_ee": 0.5833333333333334, + "acc_toxic": 0.9743589743589743 + }, + { + "loss": 2.958674430847168, + "loss_size": 0.10566488280892372, + "loss_pdi": 0.6293394267559052, + "loss_ee": 0.9313123822212219, + "loss_delivery": 0.3923226296901703, + "loss_biodist": 0.7916653156280518, + "loss_toxic": 0.10836982727050781, + "acc_pdi": 0.75, + "acc_ee": 0.6, + "acc_toxic": 0.9743589743589743 + }, + { + "loss": 2.7770891189575195, + "loss_size": 0.12206005305051804, + "loss_pdi": 0.6022174060344696, + "loss_ee": 0.9289207756519318, + "loss_delivery": 0.37663495540618896, + "loss_biodist": 0.6687739789485931, + "loss_toxic": 0.07848186790943146, + "acc_pdi": 0.75, + "acc_ee": 0.5333333333333333, + "acc_toxic": 0.9743589743589743 + }, + { + "loss": 2.7035106420516968, + "loss_size": 0.1292480230331421, + "loss_pdi": 0.5812408328056335, + "loss_ee": 0.9108770489692688, + "loss_delivery": 0.4427077919244766, + "loss_biodist": 0.5844394862651825, + "loss_toxic": 0.05499742925167084, + "acc_pdi": 0.75, + "acc_ee": 0.5666666666666667, + "acc_toxic": 1.0 + }, + { + "loss": 2.6182435750961304, + "loss_size": 0.12175065651535988, + "loss_pdi": 0.5671640783548355, + "loss_ee": 0.8962210118770599, + "loss_delivery": 0.45819681882858276, + "loss_biodist": 0.535559669137001, + "loss_toxic": 0.039351195096969604, + "acc_pdi": 0.75, + "acc_ee": 0.6333333333333333, + "acc_toxic": 1.0 + }, + { + "loss": 2.5277271270751953, + 
"loss_size": 0.15728875994682312, + "loss_pdi": 0.5866016000509262, + "loss_ee": 0.8671419620513916, + "loss_delivery": 0.40086667239665985, + "loss_biodist": 0.4854527860879898, + "loss_toxic": 0.030375237576663494, + "acc_pdi": 0.7166666666666667, + "acc_ee": 0.6166666666666667, + "acc_toxic": 1.0 + }, + { + "loss": 2.4565478563308716, + "loss_size": 0.11484631523489952, + "loss_pdi": 0.6012618541717529, + "loss_ee": 0.8616225123405457, + "loss_delivery": 0.4195407032966614, + "loss_biodist": 0.4309428632259369, + "loss_toxic": 0.028333663009107113, "acc_pdi": 0.7, - "acc_ee": 0.5666666666666667, - "acc_toxic": 0.9743589743589743 + "acc_ee": 0.65, + "acc_toxic": 1.0 }, { - "loss": 19.858214378356934, - "loss_size": 15.611548900604248, - "loss_pdi": 1.1820820569992065, - "loss_ee": 1.0452399253845215, - "loss_delivery": 0.4230894297361374, - "loss_biodist": 1.1472065448760986, - "loss_toxic": 0.4490478038787842, - "acc_pdi": 0.75, - "acc_ee": 0.5666666666666667, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 18.437131881713867, - "loss_size": 14.455862998962402, - "loss_pdi": 1.0957686305046082, - "loss_ee": 1.0236665606498718, - "loss_delivery": 0.39902180433273315, - "loss_biodist": 1.0887972712516785, - "loss_toxic": 0.3740147799253464, - "acc_pdi": 0.7666666666666667, - "acc_ee": 0.55, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 16.55983066558838, - "loss_size": 12.830113410949707, - "loss_pdi": 0.9735458791255951, - "loss_ee": 0.9793745875358582, - "loss_delivery": 0.4688366800546646, - "loss_biodist": 1.0075940787792206, - "loss_toxic": 0.3003671020269394, - "acc_pdi": 0.75, - "acc_ee": 0.5666666666666667, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 14.015560150146484, - "loss_size": 10.566654205322266, - "loss_pdi": 0.8625653684139252, - "loss_ee": 0.95384681224823, - "loss_delivery": 0.4795728325843811, - "loss_biodist": 0.9144233465194702, - "loss_toxic": 0.23849742859601974, - "acc_pdi": 0.75, - "acc_ee": 0.5833333333333334, - 
"acc_toxic": 0.9743589743589743 - }, - { - "loss": 11.799624919891357, - "loss_size": 8.538016319274902, - "loss_pdi": 0.787468433380127, - "loss_ee": 0.9436340928077698, - "loss_delivery": 0.4672486186027527, - "loss_biodist": 0.8618913292884827, - "loss_toxic": 0.2013658806681633, - "acc_pdi": 0.75, - "acc_ee": 0.5833333333333334, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 10.430795192718506, - "loss_size": 7.2916247844696045, - "loss_pdi": 0.7509226500988007, - "loss_ee": 0.9349042773246765, - "loss_delivery": 0.4556972235441208, - "loss_biodist": 0.8191786706447601, - "loss_toxic": 0.17846806347370148, - "acc_pdi": 0.75, + "loss": 2.514389753341675, + "loss_size": 0.1847861334681511, + "loss_pdi": 0.600894033908844, + "loss_ee": 0.8522049486637115, + "loss_delivery": 0.44791263341903687, + "loss_biodist": 0.4031175971031189, + "loss_toxic": 0.025474454276263714, + "acc_pdi": 0.7166666666666667, "acc_ee": 0.6166666666666667, - "acc_toxic": 0.9743589743589743 + "acc_toxic": 1.0 }, { - "loss": 9.274777889251709, - "loss_size": 6.23999810218811, - "loss_pdi": 0.7233669757843018, - "loss_ee": 0.9241549670696259, - "loss_delivery": 0.44265152513980865, - "loss_biodist": 0.7824479639530182, - "loss_toxic": 0.16215819492936134, - "acc_pdi": 0.75, - "acc_ee": 0.65, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 8.267348766326904, - "loss_size": 5.320166349411011, - "loss_pdi": 0.7005950510501862, - "loss_ee": 0.9186501204967499, - "loss_delivery": 0.42710645496845245, - "loss_biodist": 0.7505658268928528, - "loss_toxic": 0.15026485174894333, - "acc_pdi": 0.75, - "acc_ee": 0.65, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 7.32521915435791, - "loss_size": 4.463782072067261, - "loss_pdi": 0.6826708316802979, - "loss_ee": 0.9134883880615234, - "loss_delivery": 0.4078119993209839, - "loss_biodist": 0.7188615798950195, - "loss_toxic": 0.138604324311018, - "acc_pdi": 0.75, - "acc_ee": 0.65, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 
6.4078004360198975, - "loss_size": 3.6203211545944214, - "loss_pdi": 0.6657015681266785, - "loss_ee": 0.9120738506317139, - "loss_delivery": 0.3918580412864685, - "loss_biodist": 0.6897956728935242, - "loss_toxic": 0.1280498132109642, - "acc_pdi": 0.75, - "acc_ee": 0.6333333333333333, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 5.603572607040405, - "loss_size": 2.866372585296631, - "loss_pdi": 0.653519868850708, - "loss_ee": 0.9104083180427551, - "loss_delivery": 0.3852090388536453, - "loss_biodist": 0.6665074825286865, - "loss_toxic": 0.1215553842484951, - "acc_pdi": 0.7333333333333333, + "loss": 2.4860599040985107, + "loss_size": 0.12922396510839462, + "loss_pdi": 0.6249651610851288, + "loss_ee": 0.8718289136886597, + "loss_delivery": 0.4517715871334076, + "loss_biodist": 0.38561882078647614, + "loss_toxic": 0.022651473060250282, + "acc_pdi": 0.7166666666666667, "acc_ee": 0.6166666666666667, - "acc_toxic": 0.9743589743589743 + "acc_toxic": 1.0 }, { - "loss": 4.989350318908691, - "loss_size": 2.298922061920166, - "loss_pdi": 0.6447156071662903, - "loss_ee": 0.9073992371559143, - "loss_delivery": 0.38263915479183197, - "loss_biodist": 0.6420116424560547, - "loss_toxic": 0.11366239935159683, + "loss": 2.49249005317688, + "loss_size": 0.14513171464204788, + "loss_pdi": 0.5973644554615021, + "loss_ee": 0.8825545012950897, + "loss_delivery": 0.47353847324848175, + "loss_biodist": 0.3708433359861374, + "loss_toxic": 0.023057437501847744, "acc_pdi": 0.7333333333333333, "acc_ee": 0.6333333333333333, - "acc_toxic": 0.9743589743589743 + "acc_toxic": 1.0 }, { - "loss": 4.424848794937134, - "loss_size": 1.7787790298461914, - "loss_pdi": 0.6333552300930023, - "loss_ee": 0.9078906178474426, - "loss_delivery": 0.37971924245357513, - "loss_biodist": 0.6179626286029816, - "loss_toxic": 0.10714217647910118, + "loss": 2.4548466205596924, + "loss_size": 0.16008514910936356, + "loss_pdi": 0.5971579700708389, + "loss_ee": 0.8976304829120636, + "loss_delivery": 
0.43868468701839447, + "loss_biodist": 0.34443435072898865, + "loss_toxic": 0.016853836365044117, + "acc_pdi": 0.7333333333333333, + "acc_ee": 0.55, + "acc_toxic": 1.0 + }, + { + "loss": 2.521517276763916, + "loss_size": 0.15012626349925995, + "loss_pdi": 0.638213574886322, + "loss_ee": 0.9129332005977631, + "loss_delivery": 0.47218939661979675, + "loss_biodist": 0.33256953954696655, + "loss_toxic": 0.015485215000808239, "acc_pdi": 0.7333333333333333, "acc_ee": 0.6333333333333333, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 3.994532346725464, - "loss_size": 1.378051996231079, - "loss_pdi": 0.6262906491756439, - "loss_ee": 0.9124537408351898, - "loss_delivery": 0.37957052886486053, - "loss_biodist": 0.5964316725730896, - "loss_toxic": 0.1017337292432785, - "acc_pdi": 0.7333333333333333, - "acc_ee": 0.6, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 3.6140177249908447, - "loss_size": 1.0360195338726044, - "loss_pdi": 0.6164572238922119, - "loss_ee": 0.9151040613651276, - "loss_delivery": 0.37897755205631256, - "loss_biodist": 0.5732159912586212, - "loss_toxic": 0.09424328245222569, - "acc_pdi": 0.7333333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 3.328765392303467, - "loss_size": 0.7998744249343872, - "loss_pdi": 0.6098372936248779, - "loss_ee": 0.9119531512260437, - "loss_delivery": 0.3727646470069885, - "loss_biodist": 0.5473044216632843, - "loss_toxic": 0.08703150600194931, - "acc_pdi": 0.7333333333333333, - "acc_ee": 0.55, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 3.185804009437561, - "loss_size": 0.6846985220909119, - "loss_pdi": 0.6066886484622955, - "loss_ee": 0.9120742380619049, - "loss_delivery": 0.3714919835329056, - "loss_biodist": 0.5310576558113098, - "loss_toxic": 0.07979300618171692, - "acc_pdi": 0.7333333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 3.0729531049728394, - "loss_size": 0.6023903787136078, - "loss_pdi": 
0.6019311547279358, - "loss_ee": 0.9207166433334351, - "loss_delivery": 0.3675245940685272, - "loss_biodist": 0.5082270801067352, - "loss_toxic": 0.07216321676969528, - "acc_pdi": 0.7333333333333333, - "acc_ee": 0.55, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 2.986412763595581, - "loss_size": 0.5425005555152893, - "loss_pdi": 0.5986292362213135, - "loss_ee": 0.9168869256973267, - "loss_delivery": 0.3692060112953186, - "loss_biodist": 0.4880067855119705, - "loss_toxic": 0.07118316926062107, - "acc_pdi": 0.75, - "acc_ee": 0.5333333333333333, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 2.8937536478042603, - "loss_size": 0.48279696702957153, - "loss_pdi": 0.5968037247657776, - "loss_ee": 0.9193116426467896, - "loss_delivery": 0.35497498512268066, - "loss_biodist": 0.4733995348215103, - "loss_toxic": 0.06646681018173695, - "acc_pdi": 0.75, - "acc_ee": 0.5333333333333333, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 2.8241339921951294, - "loss_size": 0.4351967126131058, - "loss_pdi": 0.5942685306072235, - "loss_ee": 0.9179154634475708, - "loss_delivery": 0.3557481914758682, - "loss_biodist": 0.45890866219997406, - "loss_toxic": 0.06209634803235531, - "acc_pdi": 0.75, - "acc_ee": 0.5166666666666667, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 2.775130271911621, - "loss_size": 0.41016826033592224, - "loss_pdi": 0.5879503339529037, - "loss_ee": 0.9228730499744415, - "loss_delivery": 0.35271845757961273, - "loss_biodist": 0.4440280497074127, - "loss_toxic": 0.05739211477339268, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5333333333333333, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 2.728317618370056, - "loss_size": 0.3769180178642273, - "loss_pdi": 0.585728108882904, - "loss_ee": 0.9276572167873383, - "loss_delivery": 0.35073578357696533, - "loss_biodist": 0.4340643882751465, - "loss_toxic": 0.05321396887302399, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5333333333333333, - "acc_toxic": 0.9743589743589743 - }, - { - 
"loss": 2.7271993160247803, - "loss_size": 0.3812417685985565, - "loss_pdi": 0.5841663330793381, - "loss_ee": 0.9325368106365204, - "loss_delivery": 0.34711775183677673, - "loss_biodist": 0.43170611560344696, - "loss_toxic": 0.050430385395884514, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5333333333333333, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 2.6966670751571655, - "loss_size": 0.3598712086677551, - "loss_pdi": 0.5865683704614639, - "loss_ee": 0.9341267049312592, - "loss_delivery": 0.34253913164138794, - "loss_biodist": 0.4231300801038742, - "loss_toxic": 0.050431531853973866, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.55, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 2.677440047264099, - "loss_size": 0.34720419347286224, - "loss_pdi": 0.5916665196418762, - "loss_ee": 0.941874772310257, - "loss_delivery": 0.34214016795158386, - "loss_biodist": 0.4104868620634079, - "loss_toxic": 0.04406765662133694, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5166666666666667, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 2.636168122291565, - "loss_size": 0.33124659955501556, - "loss_pdi": 0.5931780338287354, - "loss_ee": 0.9324212372303009, - "loss_delivery": 0.337383434176445, - "loss_biodist": 0.40123969316482544, - "loss_toxic": 0.040698954835534096, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5333333333333333, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 2.6184176206588745, - "loss_size": 0.32667799293994904, - "loss_pdi": 0.5926244109869003, - "loss_ee": 0.9347528517246246, - "loss_delivery": 0.3297251760959625, - "loss_biodist": 0.3973996490240097, - "loss_toxic": 0.037237657234072685, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5333333333333333, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 2.594933271408081, - "loss_size": 0.3120705187320709, - "loss_pdi": 0.5939976274967194, - "loss_ee": 0.9371457397937775, - "loss_delivery": 0.32641829550266266, - "loss_biodist": 0.39165879786014557, - "loss_toxic": 
0.03364230878651142, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5333333333333333, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 2.5623388290405273, - "loss_size": 0.2766270190477371, - "loss_pdi": 0.5902731567621231, - "loss_ee": 0.9450692534446716, - "loss_delivery": 0.32445892691612244, - "loss_biodist": 0.39274075627326965, - "loss_toxic": 0.033169761300086975, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.55, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 2.5489457845687866, - "loss_size": 0.2664923220872879, - "loss_pdi": 0.5925990045070648, - "loss_ee": 0.9470328390598297, - "loss_delivery": 0.32114414870738983, - "loss_biodist": 0.38897469639778137, - "loss_toxic": 0.0327027402818203, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5333333333333333, "acc_toxic": 1.0 }, { - "loss": 2.5473875999450684, - "loss_size": 0.2693425267934799, - "loss_pdi": 0.5898241400718689, - "loss_ee": 0.9507412314414978, - "loss_delivery": 0.32274745404720306, - "loss_biodist": 0.3810035288333893, - "loss_toxic": 0.033728599548339844, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 2.5582650899887085, - "loss_size": 0.28376954793930054, - "loss_pdi": 0.5857751667499542, - "loss_ee": 0.9577462077140808, - "loss_delivery": 0.3179885745048523, - "loss_biodist": 0.38155072927474976, - "loss_toxic": 0.03143497183918953, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5, + "loss": 2.4589271545410156, + "loss_size": 0.12595879659056664, + "loss_pdi": 0.6382911801338196, + "loss_ee": 0.8725089728832245, + "loss_delivery": 0.47598887979984283, + "loss_biodist": 0.3318416327238083, + "loss_toxic": 0.014337773434817791, + "acc_pdi": 0.7166666666666667, + "acc_ee": 0.6333333333333333, "acc_toxic": 1.0 }, { - "loss": 2.5661860704421997, - "loss_size": 0.2933680862188339, - "loss_pdi": 0.5852044820785522, - "loss_ee": 0.9604993760585785, - "loss_delivery": 0.31451259553432465, - "loss_biodist": 
0.3806009888648987, - "loss_toxic": 0.03200071305036545, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5166666666666667, - "acc_toxic": 0.9743589743589743 - }, - { - "loss": 2.541078805923462, - "loss_size": 0.2658995985984802, - "loss_pdi": 0.5860227644443512, - "loss_ee": 0.9658644497394562, - "loss_delivery": 0.3182491958141327, - "loss_biodist": 0.3728503882884979, - "loss_toxic": 0.03219226747751236, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5333333333333333, - "acc_toxic": 1.0 - }, - { - "loss": 2.5170637369155884, - "loss_size": 0.2407098039984703, - "loss_pdi": 0.5888298004865646, - "loss_ee": 0.9680635631084442, - "loss_delivery": 0.3207763731479645, - "loss_biodist": 0.3703601360321045, - "loss_toxic": 0.028324089013040066, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.55, - "acc_toxic": 1.0 - }, - { - "loss": 2.5068060159683228, - "loss_size": 0.23479244858026505, - "loss_pdi": 0.5894846469163895, - "loss_ee": 0.968824714422226, - "loss_delivery": 0.317771315574646, - "loss_biodist": 0.3679642081260681, - "loss_toxic": 0.027968653477728367, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 1.0 - }, - { - "loss": 2.511378288269043, - "loss_size": 0.2382667362689972, - "loss_pdi": 0.5884567648172379, - "loss_ee": 0.9717344343662262, - "loss_delivery": 0.31294092535972595, - "loss_biodist": 0.36750735342502594, - "loss_toxic": 0.03247209172695875, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 1.0 - }, - { - "loss": 2.4968656301498413, - "loss_size": 0.22835177183151245, - "loss_pdi": 0.5894641727209091, - "loss_ee": 0.9705208837985992, - "loss_delivery": 0.3155461549758911, - "loss_biodist": 0.36456646025180817, - "loss_toxic": 0.02841610088944435, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 1.0 - }, - { - "loss": 2.4925342798233032, - "loss_size": 0.2224661186337471, - "loss_pdi": 0.5911581218242645, - "loss_ee": 0.9787282645702362, - "loss_delivery": 
0.3117832839488983, - "loss_biodist": 0.3618563562631607, - "loss_toxic": 0.026542033068835735, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 1.0 - }, - { - "loss": 2.519076108932495, - "loss_size": 0.23777800798416138, - "loss_pdi": 0.5882950127124786, - "loss_ee": 0.9824095964431763, - "loss_delivery": 0.3143518269062042, - "loss_biodist": 0.36608877778053284, - "loss_toxic": 0.0301527613773942, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5833333333333334, - "acc_toxic": 1.0 - }, - { - "loss": 2.5395556688308716, - "loss_size": 0.24812977015972137, - "loss_pdi": 0.5908422470092773, - "loss_ee": 0.9888725280761719, - "loss_delivery": 0.31910496950149536, - "loss_biodist": 0.3644738346338272, - "loss_toxic": 0.028132320381700993, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 1.0 - }, - { - "loss": 2.5324513912200928, - "loss_size": 0.23656832426786423, - "loss_pdi": 0.5896281898021698, - "loss_ee": 0.9954959154129028, - "loss_delivery": 0.3214045614004135, - "loss_biodist": 0.3617252707481384, - "loss_toxic": 0.027629065327346325, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5, - "acc_toxic": 1.0 - }, - { - "loss": 2.4930797815322876, - "loss_size": 0.20964767783880234, - "loss_pdi": 0.5881698727607727, - "loss_ee": 0.9954250454902649, - "loss_delivery": 0.32059434056282043, - "loss_biodist": 0.3531523495912552, - "loss_toxic": 0.026090470142662525, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.55, - "acc_toxic": 1.0 - }, - { - "loss": 2.4868186712265015, - "loss_size": 0.21395514160394669, - "loss_pdi": 0.5830750465393066, - "loss_ee": 0.9925644099712372, - "loss_delivery": 0.3193023353815079, - "loss_biodist": 0.3552027642726898, - "loss_toxic": 0.022718950174748898, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.55, - "acc_toxic": 1.0 - }, - { - "loss": 2.4778655767440796, - "loss_size": 0.20702877640724182, - "loss_pdi": 0.5822963416576385, - "loss_ee": 0.9994721114635468, - 
"loss_delivery": 0.31518860161304474, - "loss_biodist": 0.35139843821525574, - "loss_toxic": 0.022481259889900684, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5333333333333333, - "acc_toxic": 1.0 - }, - { - "loss": 2.467949628829956, - "loss_size": 0.20378626137971878, - "loss_pdi": 0.5813045352697372, - "loss_ee": 0.9965270459651947, - "loss_delivery": 0.3159128874540329, - "loss_biodist": 0.3484266996383667, - "loss_toxic": 0.021992099471390247, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5333333333333333, - "acc_toxic": 1.0 - }, - { - "loss": 2.4578932523727417, - "loss_size": 0.2058413401246071, - "loss_pdi": 0.5786841958761215, - "loss_ee": 0.9927798807621002, - "loss_delivery": 0.3089330494403839, - "loss_biodist": 0.34664781391620636, - "loss_toxic": 0.025006826035678387, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 1.0 - }, - { - "loss": 2.4626444578170776, - "loss_size": 0.20806843042373657, - "loss_pdi": 0.5785287618637085, - "loss_ee": 0.9961800277233124, - "loss_delivery": 0.3105767220258713, - "loss_biodist": 0.34737493097782135, - "loss_toxic": 0.021915599703788757, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 1.0 - }, - { - "loss": 2.4782203435897827, - "loss_size": 0.21168024837970734, - "loss_pdi": 0.5783185958862305, - "loss_ee": 0.9937601685523987, - "loss_delivery": 0.31742599606513977, - "loss_biodist": 0.35029137134552, - "loss_toxic": 0.02674387115985155, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 1.0 - }, - { - "loss": 2.483683466911316, - "loss_size": 0.2182474434375763, - "loss_pdi": 0.5817886888980865, - "loss_ee": 0.9968518018722534, - "loss_delivery": 0.31819912791252136, - "loss_biodist": 0.34571440517902374, - "loss_toxic": 0.022882056422531605, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 1.0 - }, - { - "loss": 2.4820202589035034, - "loss_size": 0.22005712240934372, - "loss_pdi": 
0.5797277837991714, - "loss_ee": 0.9988358616828918, - "loss_delivery": 0.3213921785354614, - "loss_biodist": 0.34191887080669403, - "loss_toxic": 0.02008841000497341, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5833333333333334, - "acc_toxic": 1.0 - }, - { - "loss": 2.475839376449585, - "loss_size": 0.20475323498249054, - "loss_pdi": 0.5745992809534073, - "loss_ee": 1.009226769208908, - "loss_delivery": 0.3221261650323868, - "loss_biodist": 0.34656713902950287, - "loss_toxic": 0.018566792830824852, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.55, - "acc_toxic": 1.0 - }, - { - "loss": 2.47631299495697, - "loss_size": 0.21439334005117416, - "loss_pdi": 0.5734581053256989, - "loss_ee": 1.0064772963523865, - "loss_delivery": 0.3126187026500702, - "loss_biodist": 0.3460565060377121, - "loss_toxic": 0.023308915086090565, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 1.0 - }, - { - "loss": 2.4701257944107056, - "loss_size": 0.2064252272248268, - "loss_pdi": 0.5740563273429871, - "loss_ee": 1.0102430582046509, - "loss_delivery": 0.3147464245557785, - "loss_biodist": 0.3420388251543045, - "loss_toxic": 0.02261580526828766, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5833333333333334, - "acc_toxic": 1.0 - }, - { - "loss": 2.4774882793426514, - "loss_size": 0.20556586235761642, - "loss_pdi": 0.5752835124731064, - "loss_ee": 1.0055716931819916, - "loss_delivery": 0.32358165085315704, - "loss_biodist": 0.3414890617132187, - "loss_toxic": 0.02599663846194744, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 1.0 - }, - { - "loss": 2.470389246940613, - "loss_size": 0.20361963659524918, - "loss_pdi": 0.5763607174158096, - "loss_ee": 1.0087409019470215, - "loss_delivery": 0.3187643438577652, - "loss_biodist": 0.34229810535907745, - "loss_toxic": 0.02060555759817362, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 1.0 - }, - { - "loss": 2.4505800008773804, - "loss_size": 
0.1872227042913437, - "loss_pdi": 0.5772093236446381, - "loss_ee": 1.0082023739814758, - "loss_delivery": 0.3152957409620285, - "loss_biodist": 0.3428477793931961, - "loss_toxic": 0.019802039489150047, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.55, - "acc_toxic": 1.0 - }, - { - "loss": 2.4552417993545532, - "loss_size": 0.18671875447034836, - "loss_pdi": 0.5765772759914398, - "loss_ee": 1.01038059592247, - "loss_delivery": 0.32385362684726715, - "loss_biodist": 0.3395358622074127, - "loss_toxic": 0.018175733741372824, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 1.0 - }, - { - "loss": 2.4580785036087036, - "loss_size": 0.19634269177913666, - "loss_pdi": 0.57661272585392, - "loss_ee": 1.006902813911438, - "loss_delivery": 0.31340549886226654, - "loss_biodist": 0.3417620062828064, - "loss_toxic": 0.023052792996168137, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 1.0 - }, - { - "loss": 2.4515216946601868, - "loss_size": 0.18773971498012543, - "loss_pdi": 0.5795569121837616, - "loss_ee": 1.0060312151908875, - "loss_delivery": 0.3152562975883484, - "loss_biodist": 0.3416850417852402, - "loss_toxic": 0.02125252317637205, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5833333333333334, - "acc_toxic": 1.0 - }, - { - "loss": 2.449348211288452, - "loss_size": 0.18833737820386887, - "loss_pdi": 0.5763599127531052, - "loss_ee": 1.0090400576591492, - "loss_delivery": 0.3151460140943527, - "loss_biodist": 0.3424200266599655, - "loss_toxic": 0.018045054748654366, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5833333333333334, - "acc_toxic": 1.0 - }, - { - "loss": 2.4551481008529663, - "loss_size": 0.18659686297178268, - "loss_pdi": 0.5755659937858582, - "loss_ee": 1.0128966569900513, - "loss_delivery": 0.3214796185493469, - "loss_biodist": 0.3400968313217163, - "loss_toxic": 0.0185121176764369, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5833333333333334, - "acc_toxic": 1.0 - }, - { - "loss": 
2.448776125907898, - "loss_size": 0.1820262148976326, - "loss_pdi": 0.5731459110975266, - "loss_ee": 1.0081259310245514, - "loss_delivery": 0.3281840831041336, - "loss_biodist": 0.3398120403289795, - "loss_toxic": 0.017481757327914238, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 1.0 - }, - { - "loss": 2.448722243309021, - "loss_size": 0.18495317548513412, - "loss_pdi": 0.5750085860490799, - "loss_ee": 1.005994737148285, - "loss_delivery": 0.32928895950317383, - "loss_biodist": 0.33658044040203094, - "loss_toxic": 0.016896324697881937, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 1.0 - }, - { - "loss": 2.443658947944641, - "loss_size": 0.18055323511362076, - "loss_pdi": 0.5750256329774857, - "loss_ee": 1.004571557044983, - "loss_delivery": 0.33198240399360657, - "loss_biodist": 0.33537501096725464, - "loss_toxic": 0.01615101331844926, - "acc_pdi": 0.7833333333333333, + "loss": 2.591088891029358, + "loss_size": 0.214198537170887, + "loss_pdi": 0.6199622452259064, + "loss_ee": 0.8776431977748871, + "loss_delivery": 0.5473216474056244, + "loss_biodist": 0.3174735903739929, + "loss_toxic": 0.014489701949059963, + "acc_pdi": 0.7166666666666667, "acc_ee": 0.6, "acc_toxic": 1.0 }, { - "loss": 2.4464157223701477, - "loss_size": 0.19530382752418518, - "loss_pdi": 0.5750800967216492, - "loss_ee": 1.0032556354999542, - "loss_delivery": 0.32189127802848816, - "loss_biodist": 0.3329106569290161, - "loss_toxic": 0.017974247690290213, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.6, - "acc_toxic": 1.0 - }, - { - "loss": 2.4568694829940796, - "loss_size": 0.19456836581230164, - "loss_pdi": 0.5787238329648972, - "loss_ee": 1.0072905719280243, - "loss_delivery": 0.32528989017009735, - "loss_biodist": 0.33407391607761383, - "loss_toxic": 0.016922770999372005, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5833333333333334, - "acc_toxic": 1.0 - }, - { - "loss": 2.4530590772628784, - "loss_size": 
0.19079380482435226, - "loss_pdi": 0.5766863524913788, - "loss_ee": 1.0125726163387299, - "loss_delivery": 0.3218092918395996, - "loss_biodist": 0.33407746255397797, - "loss_toxic": 0.01711944444105029, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.6, - "acc_toxic": 1.0 - }, - { - "loss": 2.4398056268692017, - "loss_size": 0.18868489563465118, - "loss_pdi": 0.5734779387712479, - "loss_ee": 1.0145431160926819, - "loss_delivery": 0.3121902197599411, - "loss_biodist": 0.3350667208433151, - "loss_toxic": 0.015842758119106293, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5833333333333334, - "acc_toxic": 1.0 - }, - { - "loss": 2.43908154964447, - "loss_size": 0.17981701344251633, - "loss_pdi": 0.5743161141872406, - "loss_ee": 1.010342001914978, - "loss_delivery": 0.324299693107605, - "loss_biodist": 0.33383095264434814, - "loss_toxic": 0.01647581998258829, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5833333333333334, - "acc_toxic": 1.0 - }, - { - "loss": 2.4145443439483643, - "loss_size": 0.1670462265610695, - "loss_pdi": 0.5734947621822357, - "loss_ee": 1.007502168416977, - "loss_delivery": 0.31635917723178864, - "loss_biodist": 0.3343554884195328, - "loss_toxic": 0.015786555130034685, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.6, - "acc_toxic": 1.0 - }, - { - "loss": 2.422114908695221, - "loss_size": 0.1707478016614914, - "loss_pdi": 0.5725287199020386, - "loss_ee": 1.007000058889389, - "loss_delivery": 0.3204977363348007, - "loss_biodist": 0.33519069850444794, - "loss_toxic": 0.016149940434843302, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5833333333333334, - "acc_toxic": 1.0 - }, - { - "loss": 2.4240607023239136, - "loss_size": 0.178876593708992, - "loss_pdi": 0.5702834725379944, - "loss_ee": 1.0034980773925781, - "loss_delivery": 0.32336823642253876, - "loss_biodist": 0.3325919061899185, - "loss_toxic": 0.015442332718521357, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.6, - "acc_toxic": 1.0 - }, - { - "loss": 2.4249237179756165, - "loss_size": 
0.18237261474132538, - "loss_pdi": 0.5713054090738297, - "loss_ee": 1.0015043318271637, - "loss_delivery": 0.3240789622068405, - "loss_biodist": 0.3306470960378647, - "loss_toxic": 0.0150153455324471, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.6, - "acc_toxic": 1.0 - }, - { - "loss": 2.4421173334121704, - "loss_size": 0.18255895376205444, - "loss_pdi": 0.5733452141284943, - "loss_ee": 1.0094293355941772, - "loss_delivery": 0.324260875582695, - "loss_biodist": 0.3330068737268448, - "loss_toxic": 0.01951604150235653, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.6, - "acc_toxic": 1.0 - }, - { - "loss": 2.428235411643982, - "loss_size": 0.18304699659347534, - "loss_pdi": 0.5728906840085983, - "loss_ee": 1.0043113827705383, - "loss_delivery": 0.3211822509765625, - "loss_biodist": 0.33053070306777954, - "loss_toxic": 0.01627331878989935, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.6, - "acc_toxic": 1.0 - }, - { - "loss": 2.423707902431488, - "loss_size": 0.17948835343122482, - "loss_pdi": 0.5713352113962173, - "loss_ee": 1.0125969350337982, - "loss_delivery": 0.31306323409080505, - "loss_biodist": 0.3301372826099396, - "loss_toxic": 0.017086807172745466, - "acc_pdi": 0.7833333333333333, + "loss": 2.514071464538574, + "loss_size": 0.10233941301703453, + "loss_pdi": 0.6488057374954224, + "loss_ee": 0.9082691967487335, + "loss_delivery": 0.5161854103207588, + "loss_biodist": 0.3266214579343796, + "loss_toxic": 0.011850347276777029, + "acc_pdi": 0.7166666666666667, "acc_ee": 0.6166666666666667, "acc_toxic": 1.0 }, { - "loss": 2.4089826345443726, - "loss_size": 0.17556024342775345, - "loss_pdi": 0.5705321878194809, - "loss_ee": 1.0072000622749329, - "loss_delivery": 0.30933138728141785, - "loss_biodist": 0.33006031811237335, - "loss_toxic": 0.01629862328991294, - "acc_pdi": 0.7833333333333333, + "loss": 2.546463966369629, + "loss_size": 0.18290647119283676, + "loss_pdi": 0.6429814398288727, + "loss_ee": 0.9410196840763092, + "loss_delivery": 0.4633557051420212, + 
"loss_biodist": 0.30668824911117554, + "loss_toxic": 0.009512441698461771, + "acc_pdi": 0.7333333333333333, + "acc_ee": 0.5666666666666667, + "acc_toxic": 1.0 + }, + { + "loss": 2.5984102487564087, + "loss_size": 0.18551968783140182, + "loss_pdi": 0.6433025300502777, + "loss_ee": 0.9714332818984985, + "loss_delivery": 0.4929337501525879, + "loss_biodist": 0.29487940669059753, + "loss_toxic": 0.010341563262045383, + "acc_pdi": 0.7333333333333333, + "acc_ee": 0.55, + "acc_toxic": 1.0 + }, + { + "loss": 2.6119550466537476, + "loss_size": 0.18234724551439285, + "loss_pdi": 0.647159218788147, + "loss_ee": 0.9437112212181091, + "loss_delivery": 0.5265502035617828, + "loss_biodist": 0.3019600808620453, + "loss_toxic": 0.010227079968899488, + "acc_pdi": 0.7, + "acc_ee": 0.6, + "acc_toxic": 1.0 + }, + { + "loss": 2.5319557189941406, + "loss_size": 0.1407381109893322, + "loss_pdi": 0.694286435842514, + "loss_ee": 0.9337777495384216, + "loss_delivery": 0.447235107421875, + "loss_biodist": 0.30738565325737, + "loss_toxic": 0.008532558102160692, + "acc_pdi": 0.7, + "acc_ee": 0.6333333333333333, + "acc_toxic": 1.0 + }, + { + "loss": 2.587972640991211, + "loss_size": 0.17371021956205368, + "loss_pdi": 0.691169947385788, + "loss_ee": 0.9433956742286682, + "loss_delivery": 0.4801773577928543, + "loss_biodist": 0.29147079586982727, + "loss_toxic": 0.008048820542171597, + "acc_pdi": 0.7166666666666667, + "acc_ee": 0.6333333333333333, + "acc_toxic": 1.0 + }, + { + "loss": 2.5775197744369507, + "loss_size": 0.15414823964238167, + "loss_pdi": 0.679006963968277, + "loss_ee": 0.9441277086734772, + "loss_delivery": 0.49791012704372406, + "loss_biodist": 0.29377779364585876, + "loss_toxic": 0.008548843208700418, + "acc_pdi": 0.7166666666666667, "acc_ee": 0.6166666666666667, "acc_toxic": 1.0 }, { - "loss": 2.4195178151130676, - "loss_size": 0.18361611664295197, - "loss_pdi": 0.5730448365211487, - "loss_ee": 1.002352625131607, - "loss_delivery": 0.30763909220695496, - "loss_biodist": 
0.33182990550994873, - "loss_toxic": 0.021035287529230118, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5833333333333334, - "acc_toxic": 1.0 - }, - { - "loss": 2.414033889770508, - "loss_size": 0.17641977220773697, - "loss_pdi": 0.5686511248350143, - "loss_ee": 1.0031031966209412, - "loss_delivery": 0.31490200757980347, - "loss_biodist": 0.3313785493373871, - "loss_toxic": 0.01957935281097889, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5833333333333334, - "acc_toxic": 1.0 - }, - { - "loss": 2.4226320385932922, - "loss_size": 0.1758788526058197, - "loss_pdi": 0.5708015114068985, - "loss_ee": 1.0085403621196747, - "loss_delivery": 0.31501661241054535, - "loss_biodist": 0.3336498290300369, - "loss_toxic": 0.018744912929832935, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5833333333333334, - "acc_toxic": 1.0 - }, - { - "loss": 2.415340304374695, - "loss_size": 0.1780794858932495, - "loss_pdi": 0.5729846954345703, - "loss_ee": 1.0020167827606201, - "loss_delivery": 0.313737228512764, - "loss_biodist": 0.3304787129163742, - "loss_toxic": 0.018043378833681345, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5833333333333334, - "acc_toxic": 1.0 - }, - { - "loss": 2.4289786219596863, - "loss_size": 0.17449475824832916, - "loss_pdi": 0.5745786726474762, - "loss_ee": 1.0082752704620361, - "loss_delivery": 0.32470449805259705, - "loss_biodist": 0.3305872231721878, - "loss_toxic": 0.016338232439011335, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5833333333333334, - "acc_toxic": 1.0 - }, - { - "loss": 2.437878131866455, - "loss_size": 0.19016961753368378, - "loss_pdi": 0.5718495547771454, - "loss_ee": 1.0078614950180054, - "loss_delivery": 0.32231585681438446, - "loss_biodist": 0.32886606454849243, - "loss_toxic": 0.016815478913486004, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 1.0 - }, - { - "loss": 2.432196855545044, - "loss_size": 0.18712375313043594, - "loss_pdi": 0.5707677900791168, - "loss_ee": 1.011586219072342, - 
"loss_delivery": 0.32039742171764374, - "loss_biodist": 0.3272295147180557, - "loss_toxic": 0.015092198271304369, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5833333333333334, - "acc_toxic": 1.0 - }, - { - "loss": 2.442120671272278, - "loss_size": 0.1902673915028572, - "loss_pdi": 0.5707383751869202, - "loss_ee": 1.0139081478118896, - "loss_delivery": 0.31597262620925903, - "loss_biodist": 0.33136095106601715, - "loss_toxic": 0.019873091019690037, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 1.0 - }, - { - "loss": 2.426782965660095, - "loss_size": 0.18581604212522507, - "loss_pdi": 0.5696376860141754, - "loss_ee": 1.0093466937541962, - "loss_delivery": 0.3177868574857712, - "loss_biodist": 0.3294990360736847, - "loss_toxic": 0.014696634840220213, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5833333333333334, - "acc_toxic": 1.0 - }, - { - "loss": 2.427416503429413, - "loss_size": 0.1857571080327034, - "loss_pdi": 0.5706104040145874, - "loss_ee": 1.0067935585975647, - "loss_delivery": 0.31886860728263855, - "loss_biodist": 0.33135633170604706, - "loss_toxic": 0.01403044443577528, - "acc_pdi": 0.7833333333333333, + "loss": 2.639680027961731, + "loss_size": 0.18575960397720337, + "loss_pdi": 0.6786225736141205, + "loss_ee": 0.9446882903575897, + "loss_delivery": 0.531510129570961, + "loss_biodist": 0.2905762940645218, + "loss_toxic": 0.008523145224899054, + "acc_pdi": 0.7, "acc_ee": 0.6, "acc_toxic": 1.0 }, { - "loss": 2.4175720810890198, - "loss_size": 0.17682582139968872, - "loss_pdi": 0.5699317306280136, - "loss_ee": 1.008631855249405, - "loss_delivery": 0.31863583624362946, - "loss_biodist": 0.3295620381832123, - "loss_toxic": 0.013984768651425838, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5833333333333334, - "acc_toxic": 1.0 - }, - { - "loss": 2.4234968423843384, - "loss_size": 0.17140312492847443, - "loss_pdi": 0.5711115300655365, - "loss_ee": 1.0081288516521454, - "loss_delivery": 0.3275977224111557, - "loss_biodist": 
0.33037069439888, - "loss_toxic": 0.01488499902188778, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.5666666666666667, - "acc_toxic": 1.0 - }, - { - "loss": 2.4282666444778442, - "loss_size": 0.17177677899599075, - "loss_pdi": 0.573297917842865, - "loss_ee": 1.0125916004180908, - "loss_delivery": 0.327491819858551, - "loss_biodist": 0.3302319347858429, - "loss_toxic": 0.012876724824309349, - "acc_pdi": 0.7833333333333333, + "loss": 2.5382663011550903, + "loss_size": 0.15963290259242058, + "loss_pdi": 0.6552937924861908, + "loss_ee": 0.9279005825519562, + "loss_delivery": 0.4982753098011017, + "loss_biodist": 0.29052674770355225, + "loss_toxic": 0.006636953912675381, + "acc_pdi": 0.7166666666666667, "acc_ee": 0.6, "acc_toxic": 1.0 }, { - "loss": 2.4261932373046875, - "loss_size": 0.16778022795915604, - "loss_pdi": 0.5743443667888641, - "loss_ee": 1.0115399360656738, - "loss_delivery": 0.33071593940258026, - "loss_biodist": 0.32824188470840454, - "loss_toxic": 0.0135708749294281, - "acc_pdi": 0.7833333333333333, + "loss": 2.5613224506378174, + "loss_size": 0.16891378536820412, + "loss_pdi": 0.654723048210144, + "loss_ee": 0.918777197599411, + "loss_delivery": 0.5200476199388504, + "loss_biodist": 0.2925422638654709, + "loss_toxic": 0.0063183787278831005, + "acc_pdi": 0.7, "acc_ee": 0.6, "acc_toxic": 1.0 }, { - "loss": 2.4238767623901367, - "loss_size": 0.17259139567613602, - "loss_pdi": 0.5731173902750015, - "loss_ee": 1.0133283734321594, - "loss_delivery": 0.31513088941574097, - "loss_biodist": 0.32966548204421997, - "loss_toxic": 0.02004324784502387, - "acc_pdi": 0.7833333333333333, - "acc_ee": 0.55, + "loss": 2.605978012084961, + "loss_size": 0.17506848275661469, + "loss_pdi": 0.6693913340568542, + "loss_ee": 0.9241081774234772, + "loss_delivery": 0.5359244644641876, + "loss_biodist": 0.2946065813302994, + "loss_toxic": 0.0068789038341492414, + "acc_pdi": 0.7, + "acc_ee": 0.6166666666666667, + "acc_toxic": 1.0 + }, + { + "loss": 2.6228197813034058, + "loss_size": 
0.1656828299164772, + "loss_pdi": 0.6787506341934204, + "loss_ee": 0.9407162368297577, + "loss_delivery": 0.5315479636192322, + "loss_biodist": 0.29851914942264557, + "loss_toxic": 0.007602998288348317, + "acc_pdi": 0.7, + "acc_ee": 0.6333333333333333, "acc_toxic": 1.0 } ] diff --git a/models/model.pt b/models/model.pt index 936a454..7023800 100644 Binary files a/models/model.pt and b/models/model.pt differ diff --git a/models/pretrain_cv/config.json b/models/pretrain_cv/config.json new file mode 100644 index 0000000..0d92440 --- /dev/null +++ b/models/pretrain_cv/config.json @@ -0,0 +1,21 @@ +{ + "d_model": 256, + "num_heads": 8, + "n_attn_layers": 4, + "fusion_strategy": "attention", + "head_hidden_dim": 128, + "dropout": 0.1, + "use_mpnn": true, + "mpnn_ensemble_paths": [ + "/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/models/mpnn/all_amine_split_for_LiON/cv_0/fold_0/model_0/model.pt", + "/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/models/mpnn/all_amine_split_for_LiON/cv_1/fold_0/model_0/model.pt", + "/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/models/mpnn/all_amine_split_for_LiON/cv_2/fold_0/model_0/model.pt", + "/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/models/mpnn/all_amine_split_for_LiON/cv_3/fold_0/model_0/model.pt", + "/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/models/mpnn/all_amine_split_for_LiON/cv_4/fold_0/model_0/model.pt" + ], + "lr": 0.0001, + "weight_decay": 1e-05, + "batch_size": 64, + "epochs": 50, + "patience": 10 +} \ No newline at end of file diff --git 
a/models/pretrain_cv/fold_0/history.json b/models/pretrain_cv/fold_0/history.json new file mode 100644 index 0000000..2525e3d --- /dev/null +++ b/models/pretrain_cv/fold_0/history.json @@ -0,0 +1,154 @@ +[ + { + "epoch": 1, + "train_loss": 0.7841362829904129, + "val_loss": 0.9230800325498075, + "val_rmse": 0.9607705443260125, + "val_r2": 0.20751899565844845, + "lr": 0.0001 + }, + { + "epoch": 2, + "train_loss": 0.6463805462366321, + "val_loss": 0.9147881584284676, + "val_rmse": 0.9564455868655657, + "val_r2": 0.2146377239324061, + "lr": 0.0001 + }, + { + "epoch": 3, + "train_loss": 0.5832159742587586, + "val_loss": 0.8564491166460212, + "val_rmse": 0.9254453497069719, + "val_r2": 0.2647228727252957, + "lr": 0.0001 + }, + { + "epoch": 4, + "train_loss": 0.5554495169353261, + "val_loss": 0.9521924139543092, + "val_rmse": 0.9758034723689635, + "val_r2": 0.18252548972094385, + "lr": 0.0001 + }, + { + "epoch": 5, + "train_loss": 0.5284749799126208, + "val_loss": 0.9002334602240055, + "val_rmse": 0.9488063310427656, + "val_r2": 0.22713320447964735, + "lr": 0.0001 + }, + { + "epoch": 6, + "train_loss": 0.5008455182478028, + "val_loss": 0.8869625280360692, + "val_rmse": 0.9417868917406855, + "val_r2": 0.23852651728338659, + "lr": 0.0001 + }, + { + "epoch": 7, + "train_loss": 0.4812185961924912, + "val_loss": 0.8656982819227316, + "val_rmse": 0.9304290898141719, + "val_r2": 0.25678226981933505, + "lr": 0.0001 + }, + { + "epoch": 8, + "train_loss": 0.46716415395603483, + "val_loss": 0.8974583646500499, + "val_rmse": 0.9473427898099364, + "val_r2": 0.22951567180333565, + "lr": 0.0001 + }, + { + "epoch": 9, + "train_loss": 0.44227657012969945, + "val_loss": 0.8423525478929029, + "val_rmse": 0.917797665041863, + "val_r2": 0.2768250098825846, + "lr": 0.0001 + }, + { + "epoch": 10, + "train_loss": 0.4192214831283541, + "val_loss": 0.8834660291599854, + "val_rmse": 0.9399287338143215, + "val_r2": 0.24152834743071827, + "lr": 0.0001 + }, + { + "epoch": 11, + "train_loss": 
0.4261854484762579, + "val_loss": 0.8901423802745356, + "val_rmse": 0.9434735654804721, + "val_r2": 0.23579658456782115, + "lr": 0.0001 + }, + { + "epoch": 12, + "train_loss": 0.40934500416213515, + "val_loss": 0.8902992367456848, + "val_rmse": 0.9435566978740579, + "val_r2": 0.2356619059496775, + "lr": 0.0001 + }, + { + "epoch": 13, + "train_loss": 0.3993083212922134, + "val_loss": 0.8952986713233702, + "val_rmse": 0.946202241268488, + "val_r2": 0.23136979638836497, + "lr": 0.0001 + }, + { + "epoch": 14, + "train_loss": 0.38611710345412054, + "val_loss": 0.898042505591281, + "val_rmse": 0.9476510417548493, + "val_r2": 0.22901418082171254, + "lr": 0.0001 + }, + { + "epoch": 15, + "train_loss": 0.3662278820512592, + "val_loss": 0.8985377061208856, + "val_rmse": 0.9479122944969771, + "val_r2": 0.2285890244824833, + "lr": 0.0001 + }, + { + "epoch": 16, + "train_loss": 0.33336883667983097, + "val_loss": 0.8594161927771942, + "val_rmse": 0.9270470292726928, + "val_r2": 0.26217556409941023, + "lr": 5e-05 + }, + { + "epoch": 17, + "train_loss": 0.33174229304183017, + "val_loss": 0.8942263270711439, + "val_rmse": 0.945635409387207, + "val_r2": 0.23229043171319064, + "lr": 5e-05 + }, + { + "epoch": 18, + "train_loss": 0.3200935872264761, + "val_loss": 0.8843722421096231, + "val_rmse": 0.9404106786278892, + "val_r2": 0.24075034122444328, + "lr": 5e-05 + }, + { + "epoch": 19, + "train_loss": 0.3140528901624022, + "val_loss": 0.8707393039336969, + "val_rmse": 0.9331341268749274, + "val_r2": 0.2524544731269053, + "lr": 5e-05 + } +] \ No newline at end of file diff --git a/models/pretrain_cv/fold_0/model.pt b/models/pretrain_cv/fold_0/model.pt new file mode 100644 index 0000000..a829bad Binary files /dev/null and b/models/pretrain_cv/fold_0/model.pt differ diff --git a/models/pretrain_cv/fold_1/history.json b/models/pretrain_cv/fold_1/history.json new file mode 100644 index 0000000..eb58ab9 --- /dev/null +++ b/models/pretrain_cv/fold_1/history.json @@ -0,0 +1,130 @@ +[ + { + 
"epoch": 1, + "train_loss": 0.7586579631889938, + "val_loss": 0.7181247283430661, + "val_rmse": 0.8474223956247073, + "val_r2": 0.27072978816447013, + "lr": 0.0001 + }, + { + "epoch": 2, + "train_loss": 0.6169086206378913, + "val_loss": 0.7122436411609591, + "val_rmse": 0.8439452896464879, + "val_r2": 0.27670212861312427, + "lr": 0.0001 + }, + { + "epoch": 3, + "train_loss": 0.5716009976577638, + "val_loss": 0.724399374180903, + "val_rmse": 0.8511165408178202, + "val_r2": 0.2643577544134178, + "lr": 0.0001 + }, + { + "epoch": 4, + "train_loss": 0.5406081575331337, + "val_loss": 0.8086859301147815, + "val_rmse": 0.8992696715048705, + "val_r2": 0.17876302728821758, + "lr": 0.0001 + }, + { + "epoch": 5, + "train_loss": 0.5032305184360722, + "val_loss": 0.7551608490501026, + "val_rmse": 0.8689999128971181, + "val_r2": 0.23311884509134373, + "lr": 0.0001 + }, + { + "epoch": 6, + "train_loss": 0.49779197957664073, + "val_loss": 0.7009093638175043, + "val_rmse": 0.8372033005770025, + "val_r2": 0.2882123252930564, + "lr": 0.0001 + }, + { + "epoch": 7, + "train_loss": 0.4720957023516007, + "val_loss": 0.7939711763762837, + "val_rmse": 0.8910506016905562, + "val_r2": 0.1937061718828037, + "lr": 0.0001 + }, + { + "epoch": 8, + "train_loss": 0.45880254612337634, + "val_loss": 0.7781202286020521, + "val_rmse": 0.8821112329914439, + "val_r2": 0.209803130396238, + "lr": 0.0001 + }, + { + "epoch": 9, + "train_loss": 0.43299420961863133, + "val_loss": 0.7305147618332145, + "val_rmse": 0.8547015607274253, + "val_r2": 0.25814744997562133, + "lr": 0.0001 + }, + { + "epoch": 10, + "train_loss": 0.4271428178197404, + "val_loss": 0.736324011356838, + "val_rmse": 0.8580932378454483, + "val_r2": 0.25224804192225836, + "lr": 0.0001 + }, + { + "epoch": 11, + "train_loss": 0.4049153788182662, + "val_loss": 0.7328011674777642, + "val_rmse": 0.8560380622891226, + "val_r2": 0.2558255581383374, + "lr": 0.0001 + }, + { + "epoch": 12, + "train_loss": 0.4011265030265379, + "val_loss": 
0.8308254960889787, + "val_rmse": 0.9114962980441127, + "val_r2": 0.15627985591441118, + "lr": 0.0001 + }, + { + "epoch": 13, + "train_loss": 0.36400350090994316, + "val_loss": 0.7314743397036573, + "val_rmse": 0.8552627260066816, + "val_r2": 0.25717298455584947, + "lr": 5e-05 + }, + { + "epoch": 14, + "train_loss": 0.35224626399225445, + "val_loss": 0.7428758944520272, + "val_rmse": 0.8619024879725893, + "val_r2": 0.24559446076977343, + "lr": 5e-05 + }, + { + "epoch": 15, + "train_loss": 0.33604996105846857, + "val_loss": 0.7853316709722159, + "val_rmse": 0.8861894055728078, + "val_r2": 0.20247977173777976, + "lr": 5e-05 + }, + { + "epoch": 16, + "train_loss": 0.3375160908226994, + "val_loss": 0.7890364486366602, + "val_rmse": 0.888277234295553, + "val_r2": 0.19871749006650596, + "lr": 5e-05 + } +] \ No newline at end of file diff --git a/models/pretrain_cv/fold_1/model.pt b/models/pretrain_cv/fold_1/model.pt new file mode 100644 index 0000000..3b5c236 Binary files /dev/null and b/models/pretrain_cv/fold_1/model.pt differ diff --git a/models/pretrain_delivery.pt b/models/pretrain_delivery.pt index ad7a903..3675a40 100644 Binary files a/models/pretrain_delivery.pt and b/models/pretrain_delivery.pt differ diff --git a/models/pretrain_history.json b/models/pretrain_history.json index 81d0129..be21bf2 100644 --- a/models/pretrain_history.json +++ b/models/pretrain_history.json @@ -1,285 +1,309 @@ { "train": [ { - "loss": 0.8151603855680482, + "loss": 0.8244676398801744, "n_samples": 8721 }, { - "loss": 0.6867990841792819, + "loss": 0.6991508170533461, "n_samples": 8721 }, { - "loss": 0.645540308540888, + "loss": 0.6388374940987616, "n_samples": 8721 }, { - "loss": 0.5923176541020599, + "loss": 0.6008581508669937, "n_samples": 8721 }, { - "loss": 0.5720762926262872, + "loss": 0.584832567446085, "n_samples": 8721 }, { - "loss": 0.5477570670417328, + "loss": 0.5481657371815157, "n_samples": 8721 }, { - "loss": 0.5280393017717573, + "loss": 0.5368926340308079, 
"n_samples": 8721 }, { - "loss": 0.5122504676513313, + "loss": 0.5210388793613561, "n_samples": 8721 }, { - "loss": 0.49667307051028314, + "loss": 0.49758357966374045, "n_samples": 8721 }, { - "loss": 0.486139440352648, + "loss": 0.49256294099457043, "n_samples": 8721 }, { - "loss": 0.4749755339466122, + "loss": 0.4697267088016886, "n_samples": 8721 }, { - "loss": 0.4636757543530298, + "loss": 0.45763822707571084, "n_samples": 8721 }, { - "loss": 0.4543497681877452, + "loss": 0.4495221330627172, "n_samples": 8721 }, { - "loss": 0.4408158337956461, + "loss": 0.446159594079631, "n_samples": 8721 }, { - "loss": 0.4419790126221837, + "loss": 0.4327090857889029, "n_samples": 8721 }, { - "loss": 0.42850686623585116, + "loss": 0.4249273364101852, "n_samples": 8721 }, { - "loss": 0.41607048387867007, + "loss": 0.4216959138704459, "n_samples": 8721 }, { - "loss": 0.427172136486513, + "loss": 0.416526201182502, "n_samples": 8721 }, { - "loss": 0.4125568530569382, + "loss": 0.40368679039741573, "n_samples": 8721 }, { - "loss": 0.39480836287767923, + "loss": 0.4051084730032182, "n_samples": 8721 }, { - "loss": 0.3885056775666858, + "loss": 0.38971701020385785, "n_samples": 8721 }, { - "loss": 0.3894976457588827, + "loss": 0.39155546386038786, "n_samples": 8721 }, { - "loss": 0.3890058272899995, + "loss": 0.37976963541784114, "n_samples": 8721 }, { - "loss": 0.3741690826284791, + "loss": 0.36484339719805037, "n_samples": 8721 }, { - "loss": 0.3534914434345719, + "loss": 0.36232607571196496, "n_samples": 8721 }, { - "loss": 0.3349389765134386, + "loss": 0.3345973272380199, "n_samples": 8721 }, { - "loss": 0.32965143874976194, + "loss": 0.31767916518768957, "n_samples": 8721 }, { - "loss": 0.32094062546116675, + "loss": 0.32065429246052457, "n_samples": 8721 }, { - "loss": 0.32526135008251184, + "loss": 0.3171297926146043, "n_samples": 8721 }, { - "loss": 0.31289531808423826, + "loss": 0.3122120894173009, "n_samples": 8721 }, { - "loss": 0.3088379208288558, + "loss": 
0.3135035038404461, "n_samples": 8721 }, { - "loss": 0.2994744991261045, + "loss": 0.2987745178222875, "n_samples": 8721 }, { - "loss": 0.2981521815160671, + "loss": 0.2914867957853393, "n_samples": 8721 }, { - "loss": 0.29143649446979303, + "loss": 0.2983839795507705, "n_samples": 8721 }, { - "loss": 0.29075756723379653, + "loss": 0.2826709597875678, + "n_samples": 8721 + }, + { + "loss": 0.2731766632569382, + "n_samples": 8721 + }, + { + "loss": 0.27726896305742266, + "n_samples": 8721 + }, + { + "loss": 0.27864557847067956, "n_samples": 8721 } ], "val": [ { - "loss": 0.7625711447683281, + "loss": 0.7601077516012517, "n_samples": 969 }, { - "loss": 0.7092331695236781, + "loss": 0.7119935319611901, "n_samples": 969 }, { - "loss": 0.7014068689723995, + "loss": 0.6461842978148269, "n_samples": 969 }, { - "loss": 0.6595172673863646, + "loss": 0.7006978391063226, "n_samples": 969 }, { - "loss": 0.6312279044905191, + "loss": 0.6533874032943979, "n_samples": 969 }, { - "loss": 0.6349272860831151, + "loss": 0.6413641451743611, "n_samples": 969 }, { - "loss": 0.6587623598744133, + "loss": 0.6168395132979742, "n_samples": 969 }, { - "loss": 0.6093261837651732, + "loss": 0.6095251602162025, "n_samples": 969 }, { - "loss": 0.6125607111474924, + "loss": 0.5887809592626905, "n_samples": 969 }, { - "loss": 0.6005943137518024, + "loss": 0.5655298325376368, "n_samples": 969 }, { - "loss": 0.6876292386783289, + "loss": 0.5809201743872788, "n_samples": 969 }, { - "loss": 0.5940848466228036, + "loss": 0.5897585974912033, "n_samples": 969 }, { - "loss": 0.5820883587079644, + "loss": 0.5732012489662573, "n_samples": 969 }, { - "loss": 0.6302792748938035, + "loss": 0.5607388911786094, "n_samples": 969 }, { - "loss": 0.5849901610914275, + "loss": 0.5717580675371414, "n_samples": 969 }, { - "loss": 0.5830434826428553, + "loss": 0.5553950037657291, "n_samples": 969 }, { - "loss": 0.5643168952858116, + "loss": 0.5778171792857049, "n_samples": 969 }, { - "loss": 0.5592790719340829, + 
"loss": 0.5602665468127734, "n_samples": 969 }, { - "loss": 0.600335100686833, + "loss": 0.5475307451359259, "n_samples": 969 }, { - "loss": 0.5646457721097674, + "loss": 0.551515599314827, "n_samples": 969 }, { - "loss": 0.6288956836004376, + "loss": 0.5755438121541243, "n_samples": 969 }, { - "loss": 0.5771863222183704, + "loss": 0.5798238261811381, "n_samples": 969 }, { - "loss": 0.5738056593250687, + "loss": 0.5739961433828923, "n_samples": 969 }, { - "loss": 0.5636531712085593, + "loss": 0.5742599932540312, "n_samples": 969 }, { - "loss": 0.5465074849879163, + "loss": 0.5834948123885382, "n_samples": 969 }, { - "loss": 0.5701294839843508, + "loss": 0.554078846570139, "n_samples": 969 }, { - "loss": 0.5570075802420438, + "loss": 0.5714933996322354, "n_samples": 969 }, { - "loss": 0.5711473401701241, + "loss": 0.5384107524350331, "n_samples": 969 }, { - "loss": 0.5576858864741552, + "loss": 0.570854394451568, "n_samples": 969 }, { - "loss": 0.5624132422716871, + "loss": 0.5767292551642478, "n_samples": 969 }, { - "loss": 0.5655298555506272, + "loss": 0.5660079547556808, "n_samples": 969 }, { - "loss": 0.5568078993151677, + "loss": 0.5608972411514312, "n_samples": 969 }, { - "loss": 0.567752199383958, + "loss": 0.5620947442987263, "n_samples": 969 }, { - "loss": 0.5683093779442603, + "loss": 0.5706970894361305, "n_samples": 969 }, { - "loss": 0.5741443767974497, + "loss": 0.5702376298690974, + "n_samples": 969 + }, + { + "loss": 0.5758474825259579, + "n_samples": 969 + }, + { + "loss": 0.5673816067284844, + "n_samples": 969 + }, + { + "loss": 0.5671441179879925, "n_samples": 969 } ] diff --git a/models/pretrain_test_results.json b/models/pretrain_test_results.json new file mode 100644 index 0000000..21a50df --- /dev/null +++ b/models/pretrain_test_results.json @@ -0,0 +1,18 @@ +{ + "model_path": "/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/models/pretrain_delivery.pt", + 
"val_path": "/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/data/processed/val_pretrain.parquet", + "n_samples": 969, + "metrics": { + "mse": 0.5384107828140259, + "rmse": 0.7337648007461423, + "mae": 0.5158410668373108, + "r2": 0.4877620374678041, + "correlation": 0.7059544816157891 + }, + "statistics": { + "y_true_mean": 0.009126194752752781, + "y_true_std": 1.0252292156219482, + "y_pred_mean": 0.008337541483342648, + "y_pred_std": 0.8293642997741699 + } +} \ No newline at end of file diff --git a/scripts/process_data_cv.py b/scripts/process_data_cv.py new file mode 100644 index 0000000..9f414b1 --- /dev/null +++ b/scripts/process_data_cv.py @@ -0,0 +1,234 @@ +"""处理 cross-validation 数据脚本:将 CV splits 转换为模型所需的 parquet 格式""" + +from pathlib import Path +from typing import Dict, List, Tuple + +import pandas as pd +import typer +from loguru import logger + +from lnp_ml.config import EXTERNAL_DATA_DIR, PROCESSED_DATA_DIR +from lnp_ml.dataset import ( + LNPDatasetConfig, + COMP_COLS, + HELP_COLS, + get_phys_cols, + get_exp_cols, + EXP_ONEHOT_SPECS, + PHYS_ONEHOT_SPECS, +) + + +app = typer.Typer() + + +# CV extra_x 列名到模型列名的映射 +CV_COL_MAPPING = { + # Batch_or_individual_or_barcoded -> Sample_organization_type (for Value_name related) + "Batch_or_individual_or_barcoded_Barcoded": "Batch_or_individual_or_barcoded_Barcoded", + "Batch_or_individual_or_barcoded_Individual": "Batch_or_individual_or_barcoded_Individual", + # Helper_lipid_ID_None 不在模型中使用,忽略 +} + + +def load_cv_split(cv_dir: Path) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """ + 加载单个 CV split 的数据。 + + Args: + cv_dir: CV split 目录,如 cv_0/ + + Returns: + (train_df, valid_df, test_df) 合并后的 DataFrame + """ + splits = {} + for split_name in ["train", "valid", "test"]: + # 加载主数据(smiles, quantified_delivery) + main_path = cv_dir / f"{split_name}.csv" + extra_x_path = cv_dir / f"{split_name}_extra_x.csv" + metadata_path = 
cv_dir / f"{split_name}_metadata.csv" + + if not main_path.exists(): + raise FileNotFoundError(f"Missing {main_path}") + + main_df = pd.read_csv(main_path) + + # 加载 extra_x(已 one-hot 编码的特征) + if extra_x_path.exists(): + extra_x_df = pd.read_csv(extra_x_path) + # 确保行数一致 + assert len(main_df) == len(extra_x_df), f"Row count mismatch: {len(main_df)} vs {len(extra_x_df)}" + # 合并(按行索引) + df = pd.concat([main_df, extra_x_df], axis=1) + else: + df = main_df + logger.warning(f" {split_name}_extra_x.csv not found, using main data only") + + # 可选:从 metadata 获取额外信息 + if metadata_path.exists(): + metadata_df = pd.read_csv(metadata_path) + # 提取需要的列(如 Purity, Mix_type, Value_name 等) + for col in ["Purity", "Mix_type", "Value_name", "Target_or_delivered_gene"]: + if col in metadata_df.columns and col not in df.columns: + df[col] = metadata_df[col] + + splits[split_name] = df + + return splits["train"], splits["valid"], splits["test"] + + +def process_cv_dataframe(df: pd.DataFrame) -> pd.DataFrame: + """ + 处理 CV 数据的 DataFrame,对齐到模型所需的列格式。 + + CV 数据的 extra_x 已经包含大部分 one-hot 编码,但需要: + 1. 添加缺失的 one-hot 列(设为 0) + 2. 从 metadata 中生成 phys token 的 one-hot 列(Purity, Mix_type, Cargo_type, Target_or_delivered_gene) + 3. 生成 Value_name 的 one-hot 列 + """ + df = df.copy() + + # 1. 处理 comp 列 + for col in COMP_COLS: + if col in df.columns: + df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0.0) + else: + df[col] = 0.0 + + # 2. 处理 help 列 + for col in HELP_COLS: + if col not in df.columns: + df[col] = 0.0 + else: + df[col] = df[col].fillna(0.0).astype(float) + + # 3. 处理 phys token 的 one-hot 列 + for col, values in PHYS_ONEHOT_SPECS.items(): + for v in values: + onehot_col = f"{col}_{v}" + if onehot_col not in df.columns: + # 尝试从原始列生成 + if col in df.columns: + df[onehot_col] = (df[col] == v).astype(float) + else: + df[onehot_col] = 0.0 + else: + df[onehot_col] = df[onehot_col].fillna(0.0).astype(float) + + # 4. 
处理 exp token 的 one-hot 列 + for col, values in EXP_ONEHOT_SPECS.items(): + for v in values: + onehot_col = f"{col}_{v}" + if onehot_col not in df.columns: + # 尝试从原始列生成 + if col in df.columns: + df[onehot_col] = (df[col] == v).astype(float) + else: + df[onehot_col] = 0.0 + else: + df[onehot_col] = df[onehot_col].fillna(0.0).astype(float) + + # 5. 处理 quantified_delivery + if "quantified_delivery" in df.columns: + df["quantified_delivery"] = pd.to_numeric(df["quantified_delivery"], errors="coerce") + + return df + + +def get_feature_columns() -> List[str]: + """获取所有特征列名""" + config = LNPDatasetConfig() + return ( + ["smiles"] + + config.comp_cols + + config.phys_cols + + config.help_cols + + config.exp_cols + + ["quantified_delivery"] + ) + + +@app.command() +def main( + data_dir: Path = EXTERNAL_DATA_DIR / "all_amine_split_for_LiON", + output_dir: Path = PROCESSED_DATA_DIR / "cv", + n_folds: int = 5, +): + """ + 处理 cross-validation 数据,生成模型所需的 parquet 文件。 + + 输出结构: + - processed/cv/fold_0/train.parquet + - processed/cv/fold_0/valid.parquet + - processed/cv/fold_0/test.parquet + - processed/cv/fold_1/... 
+ - processed/cv/feature_columns.txt + """ + logger.info(f"Processing CV data from {data_dir}") + + # 获取所有 cv_* 目录 + cv_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("cv_")]) + + if len(cv_dirs) == 0: + logger.error(f"No cv_* directories found in {data_dir}") + raise typer.Exit(1) + + if len(cv_dirs) != n_folds: + logger.warning(f"Expected {n_folds} folds, found {len(cv_dirs)}") + + logger.info(f"Found {len(cv_dirs)} folds: {[d.name for d in cv_dirs]}") + + feature_cols = get_feature_columns() + output_dir.mkdir(parents=True, exist_ok=True) + + for i, cv_dir in enumerate(cv_dirs): + fold_name = f"fold_{i}" + fold_output_dir = output_dir / fold_name + fold_output_dir.mkdir(parents=True, exist_ok=True) + + logger.info(f"Processing {cv_dir.name} -> {fold_name}") + + # 加载数据 + train_df, valid_df, test_df = load_cv_split(cv_dir) + + logger.info(f" Loaded: train={len(train_df)}, valid={len(valid_df)}, test={len(test_df)}") + + # 处理数据 + train_df = process_cv_dataframe(train_df) + valid_df = process_cv_dataframe(valid_df) + test_df = process_cv_dataframe(test_df) + + # 确保所有列存在 + for col in feature_cols: + for df in [train_df, valid_df, test_df]: + if col not in df.columns: + df[col] = 0.0 if col != "smiles" else "" + + # 只保留需要的列 + train_df = train_df[feature_cols] + valid_df = valid_df[feature_cols] + test_df = test_df[feature_cols] + + # 保存 + train_df.to_parquet(fold_output_dir / "train.parquet", index=False) + valid_df.to_parquet(fold_output_dir / "valid.parquet", index=False) + test_df.to_parquet(fold_output_dir / "test.parquet", index=False) + + logger.success(f" Saved to {fold_output_dir}") + + # 保存特征列配置 + cols_path = output_dir / "feature_columns.txt" + with open(cols_path, "w") as f: + f.write("\n".join(feature_cols)) + logger.success(f"Saved feature columns to {cols_path}") + + logger.info("\n" + "=" * 60) + logger.info("CV DATA PROCESSING COMPLETE") + logger.info("=" * 60) + logger.info(f"Output directory: {output_dir}") + 
logger.info(f"Number of folds: {len(cv_dirs)}") + + +if __name__ == "__main__": + app() +