diff --git a/Makefile b/Makefile index e73b512..185712f 100644 --- a/Makefile +++ b/Makefile @@ -78,6 +78,11 @@ data_pretrain: requirements data_pretrain_cv: requirements $(PYTHON_INTERPRETER) scripts/process_external_cv.py +## Process internal data with amine-based CV splitting (interim -> processed/cv) +.PHONY: data_cv +data_cv: requirements + $(PYTHON_INTERPRETER) scripts/process_data_cv.py + # MPNN 支持:使用 USE_MPNN=1 启用 MPNN encoder # 例如:make pretrain USE_MPNN=1 MPNN_FLAG = $(if $(USE_MPNN),--use-mpnn,) @@ -120,6 +125,16 @@ train: requirements finetune: requirements $(PYTHON_INTERPRETER) -m lnp_ml.modeling.train --init-from-pretrain models/pretrain_delivery.pt $(FREEZE_FLAG) $(MPNN_FLAG) $(DEVICE_FLAG) +## Finetune with cross-validation on internal data (5-fold, amine-based split) with pretrained weights +.PHONY: finetune_cv +finetune_cv: requirements + $(PYTHON_INTERPRETER) -m lnp_ml.modeling.train_cv main --init-from-pretrain models/pretrain_delivery.pt $(FREEZE_FLAG) $(MPNN_FLAG) $(DEVICE_FLAG) + +## Evaluate CV finetuned models on test sets (auto-detects MPNN from checkpoint) +.PHONY: test_cv +test_cv: requirements + $(PYTHON_INTERPRETER) -m lnp_ml.modeling.train_cv test $(DEVICE_FLAG) + ## Train with hyperparameter tuning .PHONY: tune tune: requirements diff --git a/data/processed/cv/feature_columns.txt b/data/processed/cv/feature_columns.txt new file mode 100644 index 0000000..eb4a748 --- /dev/null +++ b/data/processed/cv/feature_columns.txt @@ -0,0 +1,89 @@ +# Feature columns configuration + +# SMILES +smiles + +# comp token [5] +Cationic_Lipid_to_mRNA_weight_ratio +Cationic_Lipid_Mol_Ratio +Phospholipid_Mol_Ratio +Cholesterol_Mol_Ratio +PEG_Lipid_Mol_Ratio + +# phys token [12] +Purity_Pure +Purity_Crude +Mix_type_Microfluidic +Mix_type_Pipetting +Cargo_type_mRNA +Cargo_type_pDNA +Cargo_type_siRNA +Target_or_delivered_gene_FFL +Target_or_delivered_gene_Peptide_barcode +Target_or_delivered_gene_hEPO +Target_or_delivered_gene_FVII +Target_or_delivered_gene_GFP + +# help token [4] +Helper_lipid_ID_DOPE +Helper_lipid_ID_DOTAP +Helper_lipid_ID_DSPC +Helper_lipid_ID_MDOA + +# exp token [32] +Model_type_A549 +Model_type_BDMC +Model_type_BMDM +Model_type_HBEC_ALI +Model_type_HEK293T +Model_type_HeLa +Model_type_IGROV1 +Model_type_Mouse +Model_type_RAW264p7 +Delivery_target_body +Delivery_target_dendritic_cell +Delivery_target_generic_cell +Delivery_target_liver +Delivery_target_lung +Delivery_target_lung_epithelium +Delivery_target_macrophage +Delivery_target_muscle +Delivery_target_spleen +Route_of_administration_in_vitro +Route_of_administration_intramuscular +Route_of_administration_intratracheal +Route_of_administration_intravenous +Batch_or_individual_or_barcoded_Barcoded +Batch_or_individual_or_barcoded_Individual +Value_name_log_luminescence +Value_name_luminescence +Value_name_FFL_silencing +Value_name_Peptide_abundance +Value_name_hEPO +Value_name_FVII_silencing +Value_name_GFP_delivery +Value_name_Discretized_luminescence + +# Targets +## Regression +size +quantified_delivery +## PDI classification +PDI_0_0to0_2 +PDI_0_2to0_3 +PDI_0_3to0_4 +PDI_0_4to0_5 +## EE classification +Encapsulation_Efficiency_EE<50 +Encapsulation_Efficiency_50<=EE<80 +Encapsulation_Efficiency_80 List[str]: + """自动查找 MPNN ensemble 的 model.pt 文件。""" + model_paths = sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt")) + if not model_paths: + raise FileNotFoundError(f"No model.pt files found in {base_dir}") + return [str(p) for p in model_paths] + + +app = typer.Typer() + + +def create_model( + d_model: int = 256, + num_heads: int = 8, + n_attn_layers: int = 4, + fusion_strategy: str = "attention", + head_hidden_dim: int = 128, + dropout: float = 0.1, + mpnn_checkpoint: Optional[str] = None, + mpnn_ensemble_paths: Optional[List[str]] = None, + mpnn_device: str = "cpu", +) -> Union[LNPModel, LNPModelWithoutMPNN]: + """创建模型(支持可选的 MPNN encoder)""" + use_mpnn = mpnn_checkpoint is not None or mpnn_ensemble_paths is not None + + if use_mpnn: + return LNPModel( + d_model=d_model, + num_heads=num_heads, + n_attn_layers=n_attn_layers, + fusion_strategy=fusion_strategy, + head_hidden_dim=head_hidden_dim, + dropout=dropout, + mpnn_checkpoint=mpnn_checkpoint, + mpnn_ensemble_paths=mpnn_ensemble_paths, + mpnn_device=mpnn_device, + ) + else: + return LNPModelWithoutMPNN( + d_model=d_model, + num_heads=num_heads, + n_attn_layers=n_attn_layers, + fusion_strategy=fusion_strategy, + head_hidden_dim=head_hidden_dim, + dropout=dropout, + ) + + +def train_fold( + fold_idx: int, + train_loader: DataLoader, + val_loader: DataLoader, + model: nn.Module, + device: torch.device, + output_dir: Path, + lr: float = 1e-4, + weight_decay: float = 1e-5, + epochs: int = 100, + patience: int = 15, + loss_weights: Optional[LossWeights] = None, + config: Optional[Dict] = None, +) -> Dict: + """训练单个 fold""" + logger.info(f"\n{'='*60}") + logger.info(f"Training Fold {fold_idx}") + logger.info(f"{'='*60}") + + model = model.to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=5 + ) + early_stopping = EarlyStopping(patience=patience) + + history = {"train": [], "val": []} + best_val_loss = float("inf") + best_state = None + + for epoch in range(epochs): + # Train + train_metrics = train_epoch(model, train_loader, optimizer, device, loss_weights) + + # Validate + val_metrics = validate(model, val_loader, device, loss_weights) + + current_lr = optimizer.param_groups[0]["lr"] + + # Log + logger.info( + f"Fold {fold_idx} Epoch {epoch+1}/{epochs} | " + f"Train Loss: {train_metrics['loss']:.4f} | " + f"Val Loss: {val_metrics['loss']:.4f} | " + f"LR: {current_lr:.2e}" + ) + + history["train"].append(train_metrics) + history["val"].append(val_metrics) + + # Learning rate scheduling + scheduler.step(val_metrics["loss"]) + + # Save best model + if val_metrics["loss"] < best_val_loss: + best_val_loss = val_metrics["loss"] + best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()} + logger.info(f" -> New best model (val_loss={best_val_loss:.4f})") + + # Early stopping + if early_stopping(val_metrics["loss"]): + logger.info(f"Early stopping at epoch {epoch+1}") + break + + # 保存最佳模型 + fold_output_dir = output_dir / f"fold_{fold_idx}" + fold_output_dir.mkdir(parents=True, exist_ok=True) + + checkpoint_path = fold_output_dir / "model.pt" + torch.save({ + "model_state_dict": best_state, + "config": config, + "best_val_loss": best_val_loss, + "fold_idx": fold_idx, + }, checkpoint_path) + logger.success(f"Saved fold {fold_idx} model to {checkpoint_path}") + + # 保存训练历史 + history_path = fold_output_dir / "history.json" + with open(history_path, "w") as f: + json.dump(history, f, indent=2) + + return { + "fold_idx": fold_idx, + "best_val_loss": best_val_loss, + "epochs_trained": len(history["train"]), + "final_train_loss": history["train"][-1]["loss"] if history["train"] else 0, + } + + +@app.command() +def main( + data_dir: Path = PROCESSED_DATA_DIR / "cv", + output_dir: Path = MODELS_DIR / "finetune_cv", + # 模型参数 + d_model: int = 256, + num_heads: int = 8, + n_attn_layers: int = 4, + fusion_strategy: str = "attention", + head_hidden_dim: int = 128, + dropout: float = 0.1, + # MPNN 参数(可选) + use_mpnn: bool = False, + mpnn_checkpoint: Optional[str] = None, + mpnn_ensemble_paths: Optional[str] = None, + mpnn_device: str = "cpu", + # 训练参数 + batch_size: int = 32, + lr: float = 1e-4, + weight_decay: float = 1e-5, + epochs: int = 100, + patience: int = 15, + # 预训练权重加载 + init_from_pretrain: Optional[Path] = None, + load_delivery_head: bool = True, + freeze_backbone: bool = False, + # 设备 + device: str = "cuda" if torch.cuda.is_available() else "cpu", +): + """ + 基于 Cross-Validation 训练 LNP 模型(多任务)。 + + 在 5-fold 内部数据上训练 5 个模型。 + + 使用 --use-mpnn 启用 MPNN encoder。 + 使用 --init-from-pretrain 从预训练 checkpoint 初始化。 + 使用 --freeze-backbone 冻结 backbone,只训练 heads。 + """ + logger.info(f"Using device: {device}") + device = torch.device(device) + + # 查找所有 fold 目录 + fold_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("fold_")]) + + if not fold_dirs: + logger.error(f"No fold_* directories found in {data_dir}") + logger.info("Please run 'make data_cv' first to process CV data.") + raise typer.Exit(1) + + logger.info(f"Found {len(fold_dirs)} folds: {[d.name for d in fold_dirs]}") + + output_dir.mkdir(parents=True, exist_ok=True) + + # 解析 MPNN 配置 + ensemble_paths_list = None + if mpnn_ensemble_paths: + ensemble_paths_list = mpnn_ensemble_paths.split(",") + elif use_mpnn and mpnn_checkpoint is None: + logger.info(f"Auto-detecting MPNN ensemble from {DEFAULT_MPNN_ENSEMBLE_DIR}") + ensemble_paths_list = find_mpnn_ensemble_paths() + logger.info(f"Found {len(ensemble_paths_list)} MPNN models") + + enable_mpnn = mpnn_checkpoint is not None or ensemble_paths_list is not None + + # 模型配置 + config = { + "d_model": d_model, + "num_heads": num_heads, + "n_attn_layers": n_attn_layers, + "fusion_strategy": fusion_strategy, + "head_hidden_dim": head_hidden_dim, + "dropout": dropout, + "use_mpnn": enable_mpnn, + "lr": lr, + "weight_decay": weight_decay, + "batch_size": batch_size, + "epochs": epochs, + "patience": patience, + "init_from_pretrain": str(init_from_pretrain) if init_from_pretrain else None, + "freeze_backbone": freeze_backbone, + } + + # 保存配置 + config_path = output_dir / "config.json" + with open(config_path, "w") as f: + json.dump(config, f, indent=2) + logger.info(f"Saved config to {config_path}") + + # 加载预训练权重(如果指定) + pretrain_state = None + if init_from_pretrain is not None: + logger.info(f"Loading pretrain weights from {init_from_pretrain}") + checkpoint = torch.load(init_from_pretrain, map_location="cpu") + pretrain_config = checkpoint.get("config", {}) + if pretrain_config.get("d_model") != d_model: + logger.warning( + f"d_model mismatch: pretrain={pretrain_config.get('d_model')}, " + f"current={d_model}. Skipping pretrain loading." + ) + else: + pretrain_state = checkpoint["model_state_dict"] + + # 训练每个 fold + fold_results = [] + + for fold_dir in tqdm(fold_dirs, desc="Training folds"): + fold_idx = int(fold_dir.name.split("_")[1]) + + # 加载数据 + train_df = pd.read_parquet(fold_dir / "train.parquet") + val_df = pd.read_parquet(fold_dir / "val.parquet") + + logger.info(f"\nFold {fold_idx}: train={len(train_df)}, val={len(val_df)}") + + # 创建 Dataset 和 DataLoader + train_dataset = LNPDataset(train_df) + val_dataset = LNPDataset(val_df) + + train_loader = DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn + ) + val_loader = DataLoader( + val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn + ) + + # 创建新模型(每个 fold 独立初始化) + model = create_model( + d_model=d_model, + num_heads=num_heads, + n_attn_layers=n_attn_layers, + fusion_strategy=fusion_strategy, + head_hidden_dim=head_hidden_dim, + dropout=dropout, + mpnn_checkpoint=mpnn_checkpoint, + mpnn_ensemble_paths=ensemble_paths_list, + mpnn_device=device.type, + ) + + # 加载预训练权重 + if pretrain_state is not None: + model.load_pretrain_weights( + pretrain_state_dict=pretrain_state, + load_delivery_head=load_delivery_head, + strict=False, + ) + logger.info(f"Loaded pretrain weights (backbone + delivery_head={load_delivery_head})") + + # 冻结 backbone(如果指定) + if freeze_backbone: + frozen_count = 0 + for name, param in model.named_parameters(): + if name.startswith(("token_projector.", "cross_attention.", "fusion.")): + param.requires_grad = False + frozen_count += 1 + logger.info(f"Frozen {frozen_count} parameter tensors") + + # 打印模型信息(仅第一个 fold) + if fold_idx == 0: + n_params_total = sum(p.numel() for p in model.parameters()) + n_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + logger.info(f"Model parameters: {n_params_total:,} total, {n_params_trainable:,} trainable") + + # 训练 + result = train_fold( + fold_idx=fold_idx, + train_loader=train_loader, + val_loader=val_loader, + model=model, + device=device, + output_dir=output_dir, + lr=lr, + weight_decay=weight_decay, + epochs=epochs, + patience=patience, + config=config, + ) + fold_results.append(result) + + # 汇总结果 + logger.info("\n" + "=" * 60) + logger.info("CROSS-VALIDATION TRAINING COMPLETE") + logger.info("=" * 60) + + val_losses = [r["best_val_loss"] for r in fold_results] + + logger.info(f"\n[Per-Fold Results]") + for r in fold_results: + logger.info( + f" Fold {r['fold_idx']}: " + f"Val Loss={r['best_val_loss']:.4f}, " + f"Epochs={r['epochs_trained']}" + ) + + logger.info(f"\n[Summary Statistics]") + logger.info(f" Val Loss: {np.mean(val_losses):.4f} ± {np.std(val_losses):.4f}") + + # 保存 CV 结果 + cv_results = { + "fold_results": fold_results, + "summary": { + "val_loss_mean": float(np.mean(val_losses)), + "val_loss_std": float(np.std(val_losses)), + }, + "config": config, + } + + results_path = output_dir / "cv_results.json" + with open(results_path, "w") as f: + json.dump(cv_results, f, indent=2) + logger.success(f"Saved CV results to {results_path}") + + +@app.command() +def test( + data_dir: Path = PROCESSED_DATA_DIR / "cv", + model_dir: Path = MODELS_DIR / "finetune_cv", + output_path: Path = MODELS_DIR / "finetune_cv" / "test_results.json", + batch_size: int = 64, + device: str = "cuda" if torch.cuda.is_available() else "cpu", +): + """ + 在测试集上评估 CV 训练的模型。 + + 使用每个 fold 的模型在对应的测试集上评估,然后汇总结果。 + """ + from sklearn.metrics import ( + mean_squared_error, + mean_absolute_error, + r2_score, + accuracy_score, + ) + + logger.info(f"Using device: {device}") + device = torch.device(device) + + # 查找所有 fold 目录 + fold_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("fold_")]) + + if not fold_dirs: + logger.error(f"No fold_* directories found in {data_dir}") + raise typer.Exit(1) + + logger.info(f"Found {len(fold_dirs)} folds") + + fold_results = [] + # 用于汇总所有 fold 的预测 + all_preds = { + "size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [] + } + all_targets = { + "size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [] + } + + for fold_dir in tqdm(fold_dirs, desc="Evaluating folds"): + fold_idx = int(fold_dir.name.split("_")[1]) + model_path = model_dir / f"fold_{fold_idx}" / "model.pt" + test_path = fold_dir / "test.parquet" + + if not model_path.exists(): + logger.warning(f"Fold {fold_idx}: model not found at {model_path}, skipping") + continue + + if not test_path.exists(): + logger.warning(f"Fold {fold_idx}: test data not found at {test_path}, skipping") + continue + + # 加载模型 + checkpoint = torch.load(model_path, map_location=device) + config = checkpoint["config"] + + use_mpnn = config.get("use_mpnn", False) + + # 总是重新查找 MPNN 路径 + if use_mpnn: + mpnn_paths = find_mpnn_ensemble_paths() + else: + mpnn_paths = None + + model = create_model( + d_model=config["d_model"], + num_heads=config["num_heads"], + n_attn_layers=config["n_attn_layers"], + fusion_strategy=config["fusion_strategy"], + head_hidden_dim=config["head_hidden_dim"], + dropout=config["dropout"], + mpnn_ensemble_paths=mpnn_paths, + mpnn_device=device.type, + ) + model.load_state_dict(checkpoint["model_state_dict"]) + model = model.to(device) + model.eval() + + # 加载测试数据 + test_df = pd.read_parquet(test_path) + test_dataset = LNPDataset(test_df) + test_loader = DataLoader( + test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn + ) + + # 收集当前 fold 的预测 + fold_preds = {k: [] for k in all_preds.keys()} + fold_targets = {k: [] for k in all_targets.keys()} + + with torch.no_grad(): + pbar = tqdm(test_loader, desc=f"Fold {fold_idx} [Test]", leave=False) + for batch in pbar: + smiles = batch["smiles"] + tabular = {k: v.to(device) for k, v in batch["tabular"].items()} + targets = batch["targets"] + masks = batch["mask"] + + outputs = model(smiles, tabular) + + # Size + if "size" in masks and masks["size"].any(): + mask = masks["size"] + fold_preds["size"].extend( + outputs["size"].squeeze(-1)[mask].cpu().numpy().tolist() + ) + fold_targets["size"].extend( + targets["size"][mask].cpu().numpy().tolist() + ) + + # Delivery + if "delivery" in masks and masks["delivery"].any(): + mask = masks["delivery"] + fold_preds["delivery"].extend( + outputs["delivery"].squeeze(-1)[mask].cpu().numpy().tolist() + ) + fold_targets["delivery"].extend( + targets["delivery"][mask].cpu().numpy().tolist() + ) + + # PDI (classification) + if "pdi" in masks and masks["pdi"].any(): + mask = masks["pdi"] + pdi_preds = outputs["pdi"][mask].argmax(dim=-1).cpu().numpy() + pdi_targets = targets["pdi"][mask].cpu().numpy() + fold_preds["pdi"].extend(pdi_preds.tolist()) + fold_targets["pdi"].extend(pdi_targets.tolist()) + + # EE (classification) + if "ee" in masks and masks["ee"].any(): + mask = masks["ee"] + ee_preds = outputs["ee"][mask].argmax(dim=-1).cpu().numpy() + ee_targets = targets["ee"][mask].cpu().numpy() + fold_preds["ee"].extend(ee_preds.tolist()) + fold_targets["ee"].extend(ee_targets.tolist()) + + # Toxic (classification) + if "toxic" in masks and masks["toxic"].any(): + mask = masks["toxic"] + toxic_preds = outputs["toxic"][mask].argmax(dim=-1).cpu().numpy() + toxic_targets = targets["toxic"][mask].cpu().numpy().astype(int) + fold_preds["toxic"].extend(toxic_preds.tolist()) + fold_targets["toxic"].extend(toxic_targets.tolist()) + + # 计算当前 fold 的指标 + fold_metrics = {"fold_idx": fold_idx, "n_samples": len(test_df)} + + # 回归任务指标 + for task in ["size", "delivery"]: + if fold_preds[task]: + p = np.array(fold_preds[task]) + t = np.array(fold_targets[task]) + fold_metrics[task] = { + "n": len(p), + "rmse": float(np.sqrt(mean_squared_error(t, p))), + "mae": float(mean_absolute_error(t, p)), + "r2": float(r2_score(t, p)), + } + + # 分类任务指标 + for task in ["pdi", "ee", "toxic"]: + if fold_preds[task]: + p = np.array(fold_preds[task]) + t = np.array(fold_targets[task]) + fold_metrics[task] = { + "n": len(p), + "accuracy": float(accuracy_score(t, p)), + } + + fold_results.append(fold_metrics) + + # 汇总到全局 + for task in all_preds.keys(): + all_preds[task].extend(fold_preds[task]) + all_targets[task].extend(fold_targets[task]) + + # 打印当前 fold 结果 + log_parts = [f"Fold {fold_idx}: n={len(test_df)}"] + for task in ["delivery", "size"]: + if task in fold_metrics and isinstance(fold_metrics[task], dict): + log_parts.append(f"{task}_RMSE={fold_metrics[task]['rmse']:.4f}") + log_parts.append(f"{task}_R²={fold_metrics[task]['r2']:.4f}") + for task in ["pdi", "ee", "toxic"]: + if task in fold_metrics and isinstance(fold_metrics[task], dict): + log_parts.append(f"{task}_acc={fold_metrics[task]['accuracy']:.4f}") + logger.info(", ".join(log_parts)) + + # 计算跨 fold 汇总统计 + summary_stats = {} + for task in ["size", "delivery"]: + rmses = [r[task]["rmse"] for r in fold_results if task in r and isinstance(r[task], dict)] + r2s = [r[task]["r2"] for r in fold_results if task in r and isinstance(r[task], dict)] + if rmses: + summary_stats[task] = { + "rmse_mean": float(np.mean(rmses)), + "rmse_std": float(np.std(rmses)), + "r2_mean": float(np.mean(r2s)), + "r2_std": float(np.std(r2s)), + } + + for task in ["pdi", "ee", "toxic"]: + accs = [r[task]["accuracy"] for r in fold_results if task in r and isinstance(r[task], dict)] + if accs: + summary_stats[task] = { + "accuracy_mean": float(np.mean(accs)), + "accuracy_std": float(np.std(accs)), + } + + # 计算整体 pooled 指标 + overall = {} + for task in ["size", "delivery"]: + if all_preds[task]: + p = np.array(all_preds[task]) + t = np.array(all_targets[task]) + overall[task] = { + "n_samples": len(p), + "mse": float(mean_squared_error(t, p)), + "rmse": float(np.sqrt(mean_squared_error(t, p))), + "mae": float(mean_absolute_error(t, p)), + "r2": float(r2_score(t, p)), + } + + for task in ["pdi", "ee", "toxic"]: + if all_preds[task]: + p = np.array(all_preds[task]) + t = np.array(all_targets[task]) + overall[task] = { + "n_samples": len(p), + "accuracy": float(accuracy_score(t, p)), + } + + # 打印汇总结果 + logger.info("\n" + "=" * 60) + logger.info("CV TEST EVALUATION RESULTS") + logger.info("=" * 60) + + logger.info(f"\n[Summary Statistics (across {len(fold_results)} folds)]") + for task, stats in summary_stats.items(): + if "rmse_mean" in stats: + logger.info(f" {task}: RMSE={stats['rmse_mean']:.4f}±{stats['rmse_std']:.4f}, R²={stats['r2_mean']:.4f}±{stats['r2_std']:.4f}") + else: + logger.info(f" {task}: Accuracy={stats['accuracy_mean']:.4f}±{stats['accuracy_std']:.4f}") + + logger.info(f"\n[Overall (all samples pooled)]") + for task, metrics in overall.items(): + if "rmse" in metrics: + logger.info(f" {task} (n={metrics['n_samples']}): RMSE={metrics['rmse']:.4f}, MAE={metrics['mae']:.4f}, R²={metrics['r2']:.4f}") + else: + logger.info(f" {task} (n={metrics['n_samples']}): Accuracy={metrics['accuracy']:.4f}") + + # 保存结果 + results = { + "fold_results": fold_results, + "summary_stats": summary_stats, + "overall": overall, + } + + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w") as f: + json.dump(results, f, indent=2) + logger.success(f"\nSaved test results to {output_path}") + + +if __name__ == "__main__": + app() + diff --git a/models/finetune_cv/config.json b/models/finetune_cv/config.json new file mode 100644 index 0000000..ac79f18 --- /dev/null +++ b/models/finetune_cv/config.json @@ -0,0 +1,16 @@ +{ + "d_model": 256, + "num_heads": 8, + "n_attn_layers": 4, + "fusion_strategy": "attention", + "head_hidden_dim": 128, + "dropout": 0.1, + "use_mpnn": true, + "lr": 0.0001, + "weight_decay": 1e-05, + "batch_size": 32, + "epochs": 100, + "patience": 15, + "init_from_pretrain": "models/pretrain_delivery.pt", + "freeze_backbone": false +} \ No newline at end of file diff --git a/models/finetune_cv/cv_results.json b/models/finetune_cv/cv_results.json new file mode 100644 index 0000000..ebedc3e --- /dev/null +++ b/models/finetune_cv/cv_results.json @@ -0,0 +1,54 @@ +{ + "fold_results": [ + { + "fold_idx": 0, + "best_val_loss": 6.144314289093018, + "epochs_trained": 25, + "final_train_loss": 1.4692220211029052 + }, + { + "fold_idx": 1, + "best_val_loss": 8.569346030553183, + "epochs_trained": 20, + "final_train_loss": 1.5929443359375 + }, + { + "fold_idx": 2, + "best_val_loss": 3.7409281730651855, + "epochs_trained": 22, + "final_train_loss": 1.9401288827260335 + }, + { + "fold_idx": 3, + "best_val_loss": 3.47284197807312, + "epochs_trained": 27, + "final_train_loss": 1.8295514345169068 + }, + { + "fold_idx": 4, + "best_val_loss": 2.756531000137329, + "epochs_trained": 19, + "final_train_loss": 1.9399811571294612 + } + ], + "summary": { + "val_loss_mean": 4.936792294184367, + "val_loss_std": 2.1438440638412697 + }, + "config": { + "d_model": 256, + "num_heads": 8, + "n_attn_layers": 4, + "fusion_strategy": "attention", + "head_hidden_dim": 128, + "dropout": 0.1, + "use_mpnn": true, + "lr": 0.0001, + "weight_decay": 1e-05, + "batch_size": 32, + "epochs": 100, + "patience": 15, + "init_from_pretrain": "models/pretrain_delivery.pt", + "freeze_backbone": false + } +} \ No newline at end of file diff --git a/models/finetune_cv/fold_0/history.json b/models/finetune_cv/fold_0/history.json new file mode 100644 index 0000000..6b58a67 --- /dev/null +++ b/models/finetune_cv/fold_0/history.json @@ -0,0 +1,531 @@ +{ + "train": [ + { + "loss": 19.238220310211183, + "loss_size": 14.334759521484376, + "loss_pdi": 1.275341796875, + "loss_ee": 1.078886091709137, + "loss_delivery": 0.639056247472763, + "loss_biodist": 1.314099133014679, + "loss_toxic": 0.5960778951644897 + }, + { + "loss": 7.8008105754852295, + "loss_size": 3.835961139202118, + "loss_pdi": 1.0304630517959594, + "loss_ee": 1.002296370267868, + "loss_delivery": 0.527982234954834, + "loss_biodist": 1.0791441202163696, + "loss_toxic": 0.3249638095498085 + }, + { + "loss": 3.952784705162048, + "loss_size": 0.5930101618170738, + "loss_pdi": 0.7961886763572693, + "loss_ee": 0.9416749358177186, + "loss_delivery": 0.5073600560426712, + "loss_biodist": 0.9171513140201568, + "loss_toxic": 0.19739954844117164 + }, + { + "loss": 3.218132185935974, + "loss_size": 0.20842453986406326, + "loss_pdi": 0.688220864534378, + "loss_ee": 0.904784232378006, + "loss_delivery": 0.4910900041460991, + "loss_biodist": 0.792213362455368, + "loss_toxic": 0.1333992186933756 + }, + { + "loss": 2.930907416343689, + "loss_size": 0.21291286423802375, + "loss_pdi": 0.6122969090938568, + "loss_ee": 0.8774014472961426, + "loss_delivery": 0.4451231583952904, + "loss_biodist": 0.6634868443012237, + "loss_toxic": 0.11968618221580982 + }, + { + "loss": 2.7193881273269653, + "loss_size": 0.213371854275465, + "loss_pdi": 0.5864351749420166, + "loss_ee": 0.8563571333885193, + "loss_delivery": 0.3963193610310555, + "loss_biodist": 0.5644777715206146, + "loss_toxic": 0.10242685079574584 + }, + { + "loss": 2.5172106266021728, + "loss_size": 0.23241330087184905, + "loss_pdi": 0.5564007371664047, + "loss_ee": 0.8135433554649353, + "loss_delivery": 0.39191135168075564, + "loss_biodist": 0.4503189116716385, + "loss_toxic": 0.07262293715029955 + }, + { + "loss": 2.3014568209648134, + "loss_size": 0.1924597330391407, + "loss_pdi": 0.5315678030252456, + "loss_ee": 0.8107137799263, + "loss_delivery": 0.33130097687244414, + "loss_biodist": 0.3714979439973831, + "loss_toxic": 0.06391655802726745 + }, + { + "loss": 2.1527106881141664, + "loss_size": 0.19257416054606438, + "loss_pdi": 0.5191590428352356, + "loss_ee": 0.783897054195404, + "loss_delivery": 0.29573799669742584, + "loss_biodist": 0.3145490542054176, + "loss_toxic": 0.046793402079492806 + }, + { + "loss": 2.0622685074806215, + "loss_size": 0.2051038146018982, + "loss_pdi": 0.49145313203334806, + "loss_ee": 0.7647641122341156, + "loss_delivery": 0.2876307189464569, + "loss_biodist": 0.27712231278419497, + "loss_toxic": 0.03619444826617837 + }, + { + "loss": 1.9519578456878661, + "loss_size": 0.17994399815797807, + "loss_pdi": 0.4814375311136246, + "loss_ee": 0.733842009305954, + "loss_delivery": 0.28253656476736067, + "loss_biodist": 0.24782671630382538, + "loss_toxic": 0.026371066551655532 + }, + { + "loss": 1.935675847530365, + "loss_size": 0.1704096481204033, + "loss_pdi": 0.47338791787624357, + "loss_ee": 0.7182988226413727, + "loss_delivery": 0.3093330509960651, + "loss_biodist": 0.2340244770050049, + "loss_toxic": 0.030221952823922038 + }, + { + "loss": 1.888454306125641, + "loss_size": 0.17727438509464263, + "loss_pdi": 0.46344051957130433, + "loss_ee": 0.7103636503219605, + "loss_delivery": 0.30027762055397034, + "loss_biodist": 0.2190815806388855, + "loss_toxic": 0.018016549991443753 + }, + { + "loss": 1.8231052160263062, + "loss_size": 0.1548917345702648, + "loss_pdi": 0.4576862633228302, + "loss_ee": 0.7034903109073639, + "loss_delivery": 0.29063438922166823, + "loss_biodist": 0.19972888082265855, + "loss_toxic": 0.016673638485372066 + }, + { + "loss": 1.756770372390747, + "loss_size": 0.15216425359249114, + "loss_pdi": 0.429460334777832, + "loss_ee": 0.6776757568120957, + "loss_delivery": 0.2867794781923294, + "loss_biodist": 0.19385820478200913, + "loss_toxic": 0.01683232020586729 + }, + { + "loss": 1.7031532883644105, + "loss_size": 0.15495768785476685, + "loss_pdi": 0.43094243109226227, + "loss_ee": 0.6677232623100281, + "loss_delivery": 0.26152765452861787, + "loss_biodist": 0.17694738060235976, + "loss_toxic": 0.011054831324145198 + }, + { + "loss": 1.679247748851776, + "loss_size": 0.15252191424369813, + "loss_pdi": 0.40398688316345216, + "loss_ee": 0.6563315153121948, + "loss_delivery": 0.2827941685914993, + "loss_biodist": 0.17421896755695343, + "loss_toxic": 0.009394321008585393 + }, + { + "loss": 1.6231786251068114, + "loss_size": 0.14685654938220977, + "loss_pdi": 0.3987069964408875, + "loss_ee": 0.6459777146577835, + "loss_delivery": 0.2517095260322094, + "loss_biodist": 0.1695146732032299, + "loss_toxic": 0.010413196869194508 + }, + { + "loss": 1.5647669196128846, + "loss_size": 0.12480136081576347, + "loss_pdi": 0.40768158435821533, + "loss_ee": 0.6228045016527176, + "loss_delivery": 0.23313914462924004, + "loss_biodist": 0.1658004455268383, + "loss_toxic": 0.0105398821644485 + }, + { + "loss": 1.543732750415802, + "loss_size": 0.12116749435663224, + "loss_pdi": 0.3942162901163101, + "loss_ee": 0.6175289869308471, + "loss_delivery": 0.2506958607584238, + "loss_biodist": 0.15138662829995156, + "loss_toxic": 0.008737476798705757 + }, + { + "loss": 1.534558892250061, + "loss_size": 0.11713396161794662, + "loss_pdi": 0.39247085303068163, + "loss_ee": 0.608239871263504, + "loss_delivery": 0.24818528145551683, + "loss_biodist": 0.15680191665887833, + "loss_toxic": 0.011727035511285067 + }, + { + "loss": 1.508529245853424, + "loss_size": 0.12370343580842018, + "loss_pdi": 0.37536335587501524, + "loss_ee": 0.5983373761177063, + "loss_delivery": 0.2428302437067032, + "loss_biodist": 0.16056786775588988, + "loss_toxic": 0.0077269634697586295 + }, + { + "loss": 1.495661163330078, + "loss_size": 0.12558167055249214, + "loss_pdi": 0.3853023111820221, + "loss_ee": 0.5909300297498703, + "loss_delivery": 0.22781415805220603, + "loss_biodist": 0.15893371179699897, + "loss_toxic": 0.007099285908043385 + }, + { + "loss": 1.4543131768703461, + "loss_size": 0.11276061460375786, + "loss_pdi": 0.3703819438815117, + "loss_ee": 0.5739975512027741, + "loss_delivery": 0.24190589040517807, + "loss_biodist": 0.14822293519973756, + "loss_toxic": 0.007044258969835937 + }, + { + "loss": 1.4692220211029052, + "loss_size": 0.12677552737295628, + "loss_pdi": 0.37003436386585237, + "loss_ee": 0.578819090127945, + "loss_delivery": 0.2372341684997082, + "loss_biodist": 0.1498504839837551, + "loss_toxic": 0.0065083671128377315 + } + ], + "val": [ + { + "loss": 16.045034408569336, + "loss_size": 7.810531139373779, + "loss_pdi": 1.0906392335891724, + "loss_ee": 1.0121623277664185, + "loss_delivery": 4.599453926086426, + "loss_biodist": 1.121813416481018, + "loss_toxic": 0.4104346036911011, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.6296296296296297, + "acc_toxic": 1.0 + }, + { + "loss": 8.368906021118164, + "loss_size": 1.3932932615280151, + "loss_pdi": 0.7484031915664673, + "loss_ee": 0.9992449283599854, + "loss_delivery": 3.9427361488342285, + "loss_biodist": 1.0493215322494507, + "loss_toxic": 0.2359071671962738, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.5925925925925926, + "acc_toxic": 1.0 + }, + { + "loss": 6.77295446395874, + "loss_size": 0.07529407739639282, + "loss_pdi": 0.5075266361236572, + "loss_ee": 0.9870877265930176, + "loss_delivery": 4.065003395080566, + "loss_biodist": 1.0452064275741577, + "loss_toxic": 0.09283570945262909, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.5925925925925926, + "acc_toxic": 1.0 + }, + { + "loss": 7.080923080444336, + "loss_size": 0.06525272130966187, + "loss_pdi": 0.4343451261520386, + "loss_ee": 1.0161277055740356, + "loss_delivery": 4.559023380279541, + "loss_biodist": 0.9523851871490479, + "loss_toxic": 0.05378875508904457, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.5925925925925926, + "acc_toxic": 1.0 + }, + { + "loss": 6.916051864624023, + "loss_size": 0.07201710343360901, + "loss_pdi": 0.4837420880794525, + "loss_ee": 1.0075831413269043, + "loss_delivery": 4.492982864379883, + "loss_biodist": 0.7963473796844482, + "loss_toxic": 0.0633791983127594, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.5925925925925926, + "acc_toxic": 1.0 + }, + { + "loss": 6.480402946472168, + "loss_size": 0.056911345571279526, + "loss_pdi": 0.4535364508628845, + "loss_ee": 1.00966477394104, + "loss_delivery": 4.184553146362305, + "loss_biodist": 0.6998977661132812, + "loss_toxic": 0.07583901286125183, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.5925925925925926, + "acc_toxic": 1.0 + }, + { + "loss": 6.857071399688721, + "loss_size": 0.04192754626274109, + "loss_pdi": 0.44999659061431885, + "loss_ee": 1.0004281997680664, + "loss_delivery": 4.647970676422119, + "loss_biodist": 0.6309637427330017, + "loss_toxic": 0.0857851505279541, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.5925925925925926, + "acc_toxic": 1.0 + }, + { + "loss": 6.3087968826293945, + "loss_size": 0.09944896399974823, + "loss_pdi": 0.40266019105911255, + "loss_ee": 0.9705824255943298, + "loss_delivery": 4.246408939361572, + "loss_biodist": 0.552162230014801, + "loss_toxic": 0.03753397986292839, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.5925925925925926, + "acc_toxic": 1.0 + }, + { + "loss": 6.9212870597839355, + "loss_size": 0.05665857717394829, + "loss_pdi": 0.4618251323699951, + "loss_ee": 0.9701217412948608, + "loss_delivery": 4.835049152374268, + "loss_biodist": 0.5396986603736877, + "loss_toxic": 0.057933975011110306, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.5925925925925926, + "acc_toxic": 1.0 + }, + { + "loss": 6.144314289093018, + "loss_size": 0.16736657917499542, + "loss_pdi": 0.43004411458969116, + "loss_ee": 0.9552963972091675, + "loss_delivery": 4.0627570152282715, + "loss_biodist": 0.5059604048728943, + "loss_toxic": 0.02288975566625595, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.5925925925925926, + "acc_toxic": 1.0 + }, + { + "loss": 7.03376579284668, + "loss_size": 0.11119232326745987, + "loss_pdi": 0.4463132619857788, + "loss_ee": 0.9212661385536194, + "loss_delivery": 5.010645866394043, + "loss_biodist": 0.5135870575904846, + "loss_toxic": 0.03076130710542202, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.5925925925925926, + "acc_toxic": 1.0 + }, + { + "loss": 6.251324653625488, + "loss_size": 0.16327497363090515, + "loss_pdi": 0.4076344966888428, + "loss_ee": 0.9357188940048218, + "loss_delivery": 4.216032981872559, + "loss_biodist": 0.5158465504646301, + "loss_toxic": 0.012816602364182472, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.6296296296296297, + "acc_toxic": 1.0 + }, + { + "loss": 6.909743309020996, + "loss_size": 0.10885120183229446, + "loss_pdi": 0.40938591957092285, + "loss_ee": 0.8893271684646606, + "loss_delivery": 4.9792799949646, + "loss_biodist": 0.49506261944770813, + "loss_toxic": 0.027835864573717117, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.6296296296296297, + "acc_toxic": 1.0 + }, + { + "loss": 6.724283695220947, + "loss_size": 0.10787779092788696, + "loss_pdi": 0.45569828152656555, + "loss_ee": 0.979951798915863, + "loss_delivery": 4.641026020050049, + "loss_biodist": 0.5102843642234802, + "loss_toxic": 0.029445137828588486, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.5555555555555556, + "acc_toxic": 1.0 + }, + { + "loss": 6.498813629150391, + "loss_size": 0.14617349207401276, + "loss_pdi": 0.3836137652397156, + "loss_ee": 0.8910271525382996, + "loss_delivery": 4.627362251281738, + "loss_biodist": 0.43996545672416687, + "loss_toxic": 0.010672268457710743, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.5925925925925926, + "acc_toxic": 1.0 + }, + { + "loss": 6.863435745239258, + "loss_size": 0.1379111111164093, + "loss_pdi": 0.5173991322517395, + "loss_ee": 1.013482689857483, + "loss_delivery": 4.7425923347473145, + "loss_biodist": 0.4353649318218231, + "loss_toxic": 0.01668536849319935, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.48148148148148145, + "acc_toxic": 1.0 + }, + { + "loss": 6.909954071044922, + "loss_size": 0.35534632205963135, + "loss_pdi": 0.3974299728870392, + "loss_ee": 0.8724083304405212, + "loss_delivery": 4.841543674468994, + "loss_biodist": 0.42888545989990234, + "loss_toxic": 0.014340322464704514, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.6296296296296297, + "acc_toxic": 1.0 + }, + { + "loss": 6.718777179718018, + "loss_size": 0.10662573575973511, + "loss_pdi": 0.4000368118286133, + "loss_ee": 0.925907552242279, + "loss_delivery": 4.855783939361572, + "loss_biodist": 0.4174629747867584, + "loss_toxic": 0.012960433959960938, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.5925925925925926, + "acc_toxic": 1.0 + }, + { + "loss": 6.800571918487549, + "loss_size": 0.20526628196239471, + "loss_pdi": 0.42215225100517273, + "loss_ee": 0.9293471574783325, + "loss_delivery": 4.7998528480529785, + "loss_biodist": 0.42879122495651245, + "loss_toxic": 0.015162378549575806, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.5925925925925926, + "acc_toxic": 1.0 + }, + { + "loss": 6.873208522796631, + "loss_size": 0.1934625804424286, + "loss_pdi": 0.45854493975639343, + "loss_ee": 0.9394232630729675, + "loss_delivery": 4.830544471740723, + "loss_biodist": 0.43078526854515076, + "loss_toxic": 0.020448585972189903, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.6296296296296297, + "acc_toxic": 1.0 + }, + { + "loss": 6.790691375732422, + "loss_size": 0.18806225061416626, + "loss_pdi": 0.43740272521972656, + "loss_ee": 0.9507966637611389, + "loss_delivery": 4.783614158630371, + "loss_biodist": 0.4198465943336487, + "loss_toxic": 0.010969163849949837, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.5925925925925926, + "acc_toxic": 1.0 + }, + { + "loss": 6.847814083099365, + "loss_size": 0.22717446088790894, + "loss_pdi": 0.42288023233413696, + "loss_ee": 0.9286114573478699, + "loss_delivery": 4.81552791595459, + "loss_biodist": 0.43471890687942505, + "loss_toxic": 0.01890140399336815, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.5925925925925926, + "acc_toxic": 1.0 + }, + { + "loss": 7.032270431518555, + "loss_size": 0.24087895452976227, + "loss_pdi": 0.46614596247673035, + "loss_ee": 0.9626513123512268, + "loss_delivery": 4.897539138793945, + "loss_biodist": 0.44117382168769836, + "loss_toxic": 0.023880867287516594, + "acc_pdi": 0.6666666666666666, + "acc_ee": 0.6296296296296297, + "acc_toxic": 1.0 + }, + { + "loss": 6.840210914611816, + "loss_size": 0.18546245992183685, + "loss_pdi": 0.41406312584877014, + "loss_ee": 0.9308509230613708, + "loss_delivery": 4.877482891082764, + "loss_biodist": 0.41710007190704346, + "loss_toxic": 0.015251458622515202, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.6296296296296297, + "acc_toxic": 1.0 + }, + { + "loss": 6.980130672454834, + "loss_size": 0.21011953055858612, + "loss_pdi": 0.4370857775211334, + "loss_ee": 0.9217703938484192, + "loss_delivery": 4.956975936889648, + "loss_biodist": 0.4334105849266052, + "loss_toxic": 0.020767945796251297, + "acc_pdi": 0.8888888888888888, + "acc_ee": 0.6296296296296297, + "acc_toxic": 1.0 + } + ] +} \ No newline at end of file diff --git a/models/finetune_cv/fold_0/model.pt b/models/finetune_cv/fold_0/model.pt new file mode 100644 index 0000000..9e8d77d Binary files /dev/null and b/models/finetune_cv/fold_0/model.pt differ diff --git a/models/finetune_cv/fold_1/history.json b/models/finetune_cv/fold_1/history.json new file mode 100644 index 0000000..0873c61 --- /dev/null +++ b/models/finetune_cv/fold_1/history.json @@ -0,0 +1,426 @@ +{ + "train": [ + { + "loss": 22.032494735717773, + "loss_size": 16.743953514099122, + "loss_pdi": 1.2239572763442994, + "loss_ee": 1.0474397182464599, + "loss_delivery": 1.2612280070781707, + "loss_biodist": 1.1997641563415526, + "loss_toxic": 0.5561524748802185 + }, + { + "loss": 12.498378562927247, + "loss_size": 8.05031099319458, + "loss_pdi": 1.0179082751274109, + "loss_ee": 0.9407736420631408, + "loss_delivery": 0.8306634247303009, + "loss_biodist": 1.160559320449829, + "loss_toxic": 0.4981633126735687 + }, + { + "loss": 6.852249622344971, + "loss_size": 2.8376791954040526, + "loss_pdi": 0.8515164494514466, + "loss_ee": 0.8604451298713685, + "loss_delivery": 0.7656278729438781, + "loss_biodist": 1.1427257061004639, + "loss_toxic": 0.39425510764122007 + }, + { + "loss": 4.40604076385498, + "loss_size": 0.6938439428806304, + "loss_pdi": 0.6752870678901672, + "loss_ee": 0.8080484986305236, + "loss_delivery": 0.8313614726066589, + "loss_biodist": 1.0828368663787842, + "loss_toxic": 0.31466284990310667 + }, + { + "loss": 3.4253986835479737, + "loss_size": 0.16804716736078262, + "loss_pdi": 0.5639804005622864, + "loss_ee": 0.6901269078254699, + "loss_delivery": 0.6912416338920593, + "loss_biodist": 1.0242016911506653, + "loss_toxic": 0.28780081272125246 + }, + { + "loss": 3.2362718105316164, + "loss_size": 0.1483217939734459, + "loss_pdi": 0.6221840143203735, + "loss_ee": 0.6781063556671143, + "loss_delivery": 0.6013748198747635, + "loss_biodist": 0.9614210963249207, + "loss_toxic": 0.22486359924077987 + }, + { + "loss": 3.1696151733398437, + "loss_size": 0.1722302332520485, + "loss_pdi": 0.48496800661087036, + "loss_ee": 0.6616590738296508, + "loss_delivery": 0.7679106175899506, + "loss_biodist": 0.9077507495880127, + "loss_toxic": 0.17509644329547883 + }, + { + "loss": 2.6617531299591066, + "loss_size": 0.16581893265247344, + "loss_pdi": 0.4746619284152985, + "loss_ee": 0.630395919084549, + "loss_delivery": 0.3917909190058708, + "loss_biodist": 0.852727723121643, + "loss_toxic": 0.1463577665388584 + }, + { + "loss": 2.4909090995788574, + "loss_size": 0.11310702562332153, + "loss_pdi": 0.43855146765708924, + "loss_ee": 0.5929172158241272, + "loss_delivery": 0.42325166761875155, + "loss_biodist": 0.7738024115562439, + "loss_toxic": 0.1492793083190918 + }, + { + "loss": 2.3516653537750245, + "loss_size": 0.21800988018512726, + "loss_pdi": 0.43570560216903687, + "loss_ee": 0.5719826459884644, + "loss_delivery": 0.3004884377121925, + "loss_biodist": 0.6908805131912231, + "loss_toxic": 0.1345983102917671 + }, + { + "loss": 2.1030978202819823, + "loss_size": 0.10880238711833953, + "loss_pdi": 0.40215051770210264, + "loss_ee": 0.5412053823471069, + "loss_delivery": 0.2829424023628235, + "loss_biodist": 0.651984566450119, + "loss_toxic": 0.11601254418492317 + }, + { + "loss": 1.994719886779785, + "loss_size": 0.08978431597352028, + "loss_pdi": 0.4142456531524658, + "loss_ee": 0.532235836982727, + "loss_delivery": 0.2747663021087646, + "loss_biodist": 0.5863585829734802, + "loss_toxic": 0.09732922576367856 + }, + { + "loss": 1.9875550031661988, + "loss_size": 0.13663371056318283, + "loss_pdi": 0.3811588704586029, + "loss_ee": 0.5030429780483245, + "loss_delivery": 0.2990179345011711, + "loss_biodist": 0.5580029547214508, + "loss_toxic": 0.10969849154353142 + }, + { + "loss": 1.9142925500869752, + "loss_size": 0.12668041437864302, + "loss_pdi": 0.4271237254142761, + "loss_ee": 0.5021511077880859, + "loss_delivery": 0.2249750167131424, + "loss_biodist": 0.5295105099678039, + "loss_toxic": 0.1038517564535141 + }, + { + "loss": 1.81255943775177, + "loss_size": 0.1435435712337494, + "loss_pdi": 0.3587616056203842, + "loss_ee": 0.5099679112434388, + "loss_delivery": 0.24120242595672609, + "loss_biodist": 0.48327294588088987, + "loss_toxic": 0.07581093870103359 + }, + { + "loss": 1.7264988660812377, + "loss_size": 0.1290398582816124, + "loss_pdi": 0.3462284058332443, + "loss_ee": 0.48116588592529297, + "loss_delivery": 0.22381204068660737, + "loss_biodist": 0.47818992137908933, + "loss_toxic": 0.06806278452277184 + }, + { + "loss": 1.6520193338394165, + "loss_size": 0.11937985867261887, + "loss_pdi": 0.34835702180862427, + "loss_ee": 0.4374507278203964, + "loss_delivery": 0.24701516777276994, + "loss_biodist": 0.43337869048118594, + "loss_toxic": 0.06643788442015648 + }, + { + "loss": 1.5883747339248657, + "loss_size": 0.09780682176351548, + "loss_pdi": 0.36214256286621094, + "loss_ee": 0.43477209806442263, + "loss_delivery": 0.23339744359254838, + "loss_biodist": 0.4091412305831909, + "loss_toxic": 0.051114612445235255 + }, + { + "loss": 1.6899895429611207, + "loss_size": 0.10551446527242661, + "loss_pdi": 0.3761424541473389, + "loss_ee": 0.46849397420883176, + "loss_delivery": 0.27737232744693757, + "loss_biodist": 0.4076103329658508, + "loss_toxic": 0.05485602542757988 + }, + { + "loss": 1.5929443359375, + "loss_size": 0.11571613550186158, + "loss_pdi": 0.352296257019043, + "loss_ee": 0.4432097375392914, + "loss_delivery": 0.2575030043721199, + "loss_biodist": 0.37044936418533325, + "loss_toxic": 0.05376977995038033 + } + ], + "val": [ + { + "loss": 22.81237284342448, + "loss_size": 13.806465148925781, + "loss_pdi": 1.229181448618571, + "loss_ee": 1.027739703655243, + "loss_delivery": 5.068911023437977, + "loss_biodist": 1.1279457012812297, + "loss_toxic": 0.5521289308865865, + "acc_pdi": 0.6526315789473685, + "acc_ee": 0.6947368421052632, + "acc_toxic": 0.8939393939393939 + }, + { + "loss": 15.498573303222656, + "loss_size": 6.912943522135417, + "loss_pdi": 1.151400883992513, + "loss_ee": 0.9848727782567342, + "loss_delivery": 4.852191311617692, + "loss_biodist": 1.106157898902893, + "loss_toxic": 0.4910069803396861, + "acc_pdi": 0.6105263157894737, + "acc_ee": 0.6736842105263158, + "acc_toxic": 0.8939393939393939 + }, + { + "loss": 10.799206574757894, + "loss_size": 2.5314422051111856, + "loss_pdi": 1.0411948959032695, + "loss_ee": 0.9279763499895731, + "loss_delivery": 4.807806923985481, + "loss_biodist": 1.0742538372675579, + "loss_toxic": 0.4165322283903758, + "acc_pdi": 0.6105263157894737, + "acc_ee": 0.6736842105263158, + "acc_toxic": 0.8939393939393939 + }, + { + "loss": 8.901862303415934, + "loss_size": 0.7716620067755381, + "loss_pdi": 1.01361749569575, + "loss_ee": 0.8802629311879476, + "loss_delivery": 4.857094804445903, + "loss_biodist": 1.0347116986910503, + "loss_toxic": 0.34451337655385333, + "acc_pdi": 0.6105263157894737, + "acc_ee": 0.6736842105263158, + "acc_toxic": 0.8939393939393939 + }, + { + "loss": 8.569346030553183, + "loss_size": 0.41188229247927666, + "loss_pdi": 1.0502839088439941, + "loss_ee": 0.8726372222105662, + "loss_delivery": 4.940298028290272, + "loss_biodist": 1.0011842250823975, + "loss_toxic": 0.29306065539518994, + "acc_pdi": 0.6105263157894737, + "acc_ee": 0.6736842105263158, + "acc_toxic": 0.8939393939393939 + }, + { + "loss": 9.03537893295288, + "loss_size": 0.5666823834180832, + "loss_pdi": 1.1037296056747437, + "loss_ee": 0.9048542082309723, + "loss_delivery": 5.218828019996484, + "loss_biodist": 0.9743464191754659, + "loss_toxic": 0.2669379909833272, + "acc_pdi": 0.6105263157894737, + "acc_ee": 0.6736842105263158, + "acc_toxic": 0.8939393939393939 + }, + { + "loss": 9.051881631215414, + "loss_size": 0.5067511014640331, + "loss_pdi": 1.0797988673051198, + "loss_ee": 0.8918277323246002, + "loss_delivery": 5.376374647021294, + "loss_biodist": 0.9581284721692404, + "loss_toxic": 0.23900071531534195, + "acc_pdi": 0.6105263157894737, + "acc_ee": 0.6736842105263158, + "acc_toxic": 0.8939393939393939 + }, + { + "loss": 8.965935707092285, + "loss_size": 0.4464708169301351, + "loss_pdi": 1.03824187318484, + "loss_ee": 0.8565650085608164, + "loss_delivery": 5.463352290292581, + "loss_biodist": 0.9425446192423502, + "loss_toxic": 0.21876097718874613, + "acc_pdi": 0.6105263157894737, + "acc_ee": 0.6736842105263158, + "acc_toxic": 0.8939393939393939 + }, + { + "loss": 9.351192077000936, + "loss_size": 0.578731312106053, + "loss_pdi": 1.0086682240168254, + "loss_ee": 0.8167769958575567, + "loss_delivery": 5.8357385993003845, + "loss_biodist": 0.8966569105784098, + "loss_toxic": 0.21462025741736093, + "acc_pdi": 0.6105263157894737, + "acc_ee": 0.6736842105263158, + "acc_toxic": 0.8939393939393939 + }, + { + "loss": 9.685971339543661, + "loss_size": 0.5851275982956091, + "loss_pdi": 0.9895619451999664, + "loss_ee": 0.8124870856602987, + "loss_delivery": 6.198981747031212, + "loss_biodist": 0.8602876861890157, + "loss_toxic": 0.2395249493420124, + "acc_pdi": 0.6105263157894737, + "acc_ee": 0.6736842105263158, + "acc_toxic": 0.8939393939393939 + }, + { + "loss": 9.098868290583292, + "loss_size": 0.4912531226873398, + "loss_pdi": 0.9610133767127991, + "loss_ee": 0.8015020589033762, + "loss_delivery": 5.819371705253919, + "loss_biodist": 0.8127925594647726, + "loss_toxic": 0.21293580221633115, + "acc_pdi": 0.6105263157894737, + "acc_ee": 0.6736842105263158, + "acc_toxic": 0.8939393939393939 + }, + { + "loss": 9.217247327168783, + "loss_size": 0.5547616928815842, + "loss_pdi": 0.9466870129108429, + "loss_ee": 0.80243648091952, + "loss_delivery": 5.907406737407048, + "loss_biodist": 0.7983624935150146, + "loss_toxic": 0.20759239544471106, + "acc_pdi": 0.6105263157894737, + "acc_ee": 0.6736842105263158, + "acc_toxic": 0.8939393939393939 + }, + { + "loss": 9.307894190152487, + "loss_size": 0.6155262216925621, + "loss_pdi": 0.9505135516325632, + "loss_ee": 0.8020251393318176, + "loss_delivery": 5.944891105095546, + "loss_biodist": 0.7798070808251699, + "loss_toxic": 0.21513095125555992, + "acc_pdi": 0.6105263157894737, + "acc_ee": 0.6736842105263158, + "acc_toxic": 0.8939393939393939 + }, + { + "loss": 9.434777816136679, + "loss_size": 0.5672734752297401, + "loss_pdi": 0.9817352195580801, + "loss_ee": 0.822968602180481, + "loss_delivery": 6.080470234155655, + "loss_biodist": 0.745076318581899, + "loss_toxic": 0.23725438863039017, + "acc_pdi": 0.6105263157894737, + "acc_ee": 0.6736842105263158, + "acc_toxic": 0.8939393939393939 + }, + { + "loss": 9.355290253957113, + "loss_size": 0.5713960801561674, + "loss_pdi": 0.9858902891476949, + "loss_ee": 0.8192337850729624, + "loss_delivery": 6.011797075470288, + "loss_biodist": 0.724964420000712, + "loss_toxic": 0.24200843647122383, + "acc_pdi": 0.6105263157894737, + "acc_ee": 0.6736842105263158, + "acc_toxic": 0.8939393939393939 + }, + { + "loss": 9.306842883427938, + "loss_size": 0.6075545425216357, + "loss_pdi": 0.9776454170544943, + "loss_ee": 0.7862655818462372, + "loss_delivery": 5.985511064529419, + "loss_biodist": 0.7076093653837839, + "loss_toxic": 0.24225737899541855, + "acc_pdi": 0.6105263157894737, + "acc_ee": 0.6736842105263158, + "acc_toxic": 0.8939393939393939 + }, + { + "loss": 9.252086718877157, + "loss_size": 0.6051755348841349, + "loss_pdi": 0.9714606404304504, + "loss_ee": 0.7621750583251318, + "loss_delivery": 5.975689520438512, + "loss_biodist": 0.6939358512560526, + "loss_toxic": 0.24364992106954256, + "acc_pdi": 0.6105263157894737, + "acc_ee": 0.6736842105263158, + "acc_toxic": 0.8939393939393939 + }, + { + "loss": 9.275835235913595, + "loss_size": 0.5751028036077818, + "loss_pdi": 0.9899951120217642, + "loss_ee": 0.768888125816981, + "loss_delivery": 5.988193516929944, + "loss_biodist": 0.6993262469768524, + "loss_toxic": 0.25432955101132393, + "acc_pdi": 0.6105263157894737, + "acc_ee": 0.6842105263157895, + "acc_toxic": 0.8939393939393939 + }, + { + "loss": 9.215376615524292, + "loss_size": 0.5498512660463651, + "loss_pdi": 0.9988046189149221, + "loss_ee": 0.7566283419728279, + "loss_delivery": 5.962520445386569, + "loss_biodist": 0.6932495137055715, + "loss_toxic": 0.2543224884817998, + "acc_pdi": 0.6105263157894737, + "acc_ee": 0.6947368421052632, + "acc_toxic": 0.8939393939393939 + }, + { + "loss": 9.17027481396993, + "loss_size": 0.5183901910980543, + "loss_pdi": 1.0107523500919342, + "loss_ee": 0.7627487083276113, + "loss_delivery": 5.940715471903483, + "loss_biodist": 0.6853441596031189, + "loss_toxic": 0.25232408692439395, + "acc_pdi": 0.6105263157894737, + "acc_ee": 0.6947368421052632, + "acc_toxic": 0.8939393939393939 + } + ] +} \ No newline at end of file diff --git a/models/finetune_cv/fold_1/model.pt b/models/finetune_cv/fold_1/model.pt new file mode 100644 index 0000000..9e3e48d Binary files /dev/null and b/models/finetune_cv/fold_1/model.pt differ diff --git a/models/finetune_cv/fold_2/history.json b/models/finetune_cv/fold_2/history.json new file mode 100644 index 0000000..a3ad4a4 --- /dev/null +++ b/models/finetune_cv/fold_2/history.json @@ -0,0 +1,468 @@ +{ + "train": [ + { + "loss": 23.856885592142742, + "loss_size": 17.954859733581543, + "loss_pdi": 1.3681265513102214, + "loss_ee": 1.068708558877309, + "loss_delivery": 1.6815399676561356, + "loss_biodist": 1.1385973294576008, + "loss_toxic": 0.6450536052385966 + }, + { + "loss": 15.18941100438436, + "loss_size": 10.012138843536377, + "loss_pdi": 1.2400256196657817, + "loss_ee": 0.9122016330560049, + "loss_delivery": 1.4215861509243648, + "loss_biodist": 1.0704677999019623, + "loss_toxic": 0.5329908380905787 + }, + { + "loss": 8.939830541610718, + "loss_size": 4.306200265884399, + "loss_pdi": 1.0988461375236511, + "loss_ee": 0.7973864078521729, + "loss_delivery": 1.312243824203809, + "loss_biodist": 0.9968815743923187, + "loss_toxic": 0.4282720486323039 + }, + { + "loss": 5.871116876602173, + "loss_size": 1.4981385171413422, + "loss_pdi": 0.9578391710917155, + "loss_ee": 0.7877674202124277, + "loss_delivery": 1.332635521888733, + "loss_biodist": 0.9970230559508005, + "loss_toxic": 0.29771314313014346 + }, + { + "loss": 4.818921804428101, + "loss_size": 0.6080186615387598, + "loss_pdi": 0.8167077898979187, + "loss_ee": 0.7845198512077332, + "loss_delivery": 1.421961595614751, + "loss_biodist": 0.9678046603997549, + "loss_toxic": 0.21990922341744104 + }, + { + "loss": 4.394984285036723, + "loss_size": 0.3149821311235428, + "loss_pdi": 0.7380032440026602, + "loss_ee": 0.7671190500259399, + "loss_delivery": 1.447837049762408, + "loss_biodist": 0.9290279944737753, + "loss_toxic": 0.19801472003261247 + }, + { + "loss": 3.884276866912842, + "loss_size": 0.26142628739277524, + "loss_pdi": 0.6927057504653931, + "loss_ee": 0.7489989002545675, + "loss_delivery": 1.1156622817118962, + "loss_biodist": 0.8776024182637533, + "loss_toxic": 0.1878812573850155 + }, + { + "loss": 3.709937810897827, + "loss_size": 0.37482741723457974, + "loss_pdi": 0.6356837352116903, + "loss_ee": 0.7201081812381744, + "loss_delivery": 1.0413622185587883, + "loss_biodist": 0.769838293393453, + "loss_toxic": 0.16811783549686274 + }, + { + "loss": 3.3223368724187217, + "loss_size": 0.2915526789923509, + "loss_pdi": 0.5975265900293986, + "loss_ee": 0.657735288143158, + "loss_delivery": 0.915939765671889, + "loss_biodist": 0.6836796899636587, + "loss_toxic": 0.17590284595886865 + }, + { + "loss": 3.22677751382192, + "loss_size": 0.30531836052735645, + "loss_pdi": 0.5498206416765848, + "loss_ee": 0.6144644021987915, + "loss_delivery": 1.0515401139855385, + "loss_biodist": 0.5715167572100958, + "loss_toxic": 0.13411726988852024 + }, + { + "loss": 2.898585855960846, + "loss_size": 0.26108303914467496, + "loss_pdi": 0.5470793843269348, + "loss_ee": 0.5787178675333658, + "loss_delivery": 0.9020876735448837, + "loss_biodist": 0.47204887370268506, + "loss_toxic": 0.13756892209251723 + }, + { + "loss": 2.6754438877105713, + "loss_size": 0.27526700248320896, + "loss_pdi": 0.5229905943075815, + "loss_ee": 0.5479903370141983, + "loss_delivery": 0.7761634774506092, + "loss_biodist": 0.44562476873397827, + "loss_toxic": 0.10740765929222107 + }, + { + "loss": 2.4977574348449707, + "loss_size": 0.2336499529580275, + "loss_pdi": 0.4902036637067795, + "loss_ee": 0.5169308086236318, + "loss_delivery": 0.7469637009004751, + "loss_biodist": 0.40839699904123944, + "loss_toxic": 0.10161229533453782 + }, + { + "loss": 2.384280482927958, + "loss_size": 0.2612900485595067, + "loss_pdi": 0.48895320296287537, + "loss_ee": 0.49676452577114105, + "loss_delivery": 0.6831860815485319, + "loss_biodist": 0.36347728967666626, + "loss_toxic": 0.090609318887194 + }, + { + "loss": 2.3147188226381936, + "loss_size": 0.25488365814089775, + "loss_pdi": 0.4683869779109955, + "loss_ee": 0.4875288059314092, + "loss_delivery": 0.6712647005915642, + "loss_biodist": 0.35311167935530346, + "loss_toxic": 0.0795429985349377 + }, + { + "loss": 2.2636735240618386, + "loss_size": 0.25234917054573697, + "loss_pdi": 0.48236118257045746, + "loss_ee": 0.4705241521199544, + "loss_delivery": 0.6308227330446243, + "loss_biodist": 0.3382392128308614, + "loss_toxic": 0.0893770344555378 + }, + { + "loss": 2.116434315840403, + "loss_size": 0.24041289339462915, + "loss_pdi": 0.45973017315069836, + "loss_ee": 0.4558122158050537, + "loss_delivery": 0.5561455090840658, + "loss_biodist": 0.3245606869459152, + "loss_toxic": 0.07977284273753564 + }, + { + "loss": 2.116472323735555, + "loss_size": 0.22036718018352985, + "loss_pdi": 0.4676543176174164, + "loss_ee": 0.43863776326179504, + "loss_delivery": 0.6074383656183878, + "loss_biodist": 0.3068434993426005, + "loss_toxic": 0.07553133244315784 + }, + { + "loss": 2.0553239782651267, + "loss_size": 0.2092201883594195, + "loss_pdi": 0.45897047221660614, + "loss_ee": 0.44677118460337323, + "loss_delivery": 0.559788204729557, + "loss_biodist": 0.31095271309216815, + "loss_toxic": 0.06962121867885192 + }, + { + "loss": 1.961161474386851, + "loss_size": 0.20881337051590285, + "loss_pdi": 0.45118602613608044, + "loss_ee": 0.42851509153842926, + "loss_delivery": 0.49867390592892963, + "loss_biodist": 0.3025246188044548, + "loss_toxic": 0.07144851268579562 + }, + { + "loss": 1.959239919980367, + "loss_size": 0.2156102918088436, + "loss_pdi": 0.4368931899468104, + "loss_ee": 0.42987839380900067, + "loss_delivery": 0.5220988343159357, + "loss_biodist": 0.2822929248213768, + "loss_toxic": 0.07246625237166882 + }, + { + "loss": 1.9401288827260335, + "loss_size": 0.21518264586726824, + "loss_pdi": 0.45129939913749695, + "loss_ee": 0.4190533608198166, + "loss_delivery": 0.5097754697004954, + "loss_biodist": 0.27901960412661236, + "loss_toxic": 0.06579839282979567 + } + ], + "val": [ + { + "loss": 17.15616934640067, + "loss_size": 12.466235705784388, + "loss_pdi": 1.2282596485955375, + "loss_ee": 1.1713778802326746, + "loss_delivery": 0.45601305684873034, + "loss_biodist": 1.3372918111937386, + "loss_toxic": 0.49699102129255024, + "acc_pdi": 0.6410256410256411, + "acc_ee": 0.4205128205128205, + "acc_toxic": 1.0 + }, + { + "loss": 9.976628576006208, + "loss_size": 5.5366509301321845, + "loss_pdi": 1.1110572814941406, + "loss_ee": 1.2255418130329676, + "loss_delivery": 0.4425822538988931, + "loss_biodist": 1.2637801681246077, + "loss_toxic": 0.3970160186290741, + "acc_pdi": 0.7076923076923077, + "acc_ee": 0.4205128205128205, + "acc_toxic": 1.0 + }, + { + "loss": 6.142949717385428, + "loss_size": 2.012831313269479, + "loss_pdi": 0.9936460341726031, + "loss_ee": 1.2630535023553031, + "loss_delivery": 0.45085248563970837, + "loss_biodist": 1.1773428661482674, + "loss_toxic": 0.24522356688976288, + "acc_pdi": 0.7076923076923077, + "acc_ee": 0.4205128205128205, + "acc_toxic": 1.0 + }, + { + "loss": 4.523837975093296, + "loss_size": 0.6790421732834407, + "loss_pdi": 0.8449332543781826, + "loss_ee": 1.2828129359654017, + "loss_delivery": 0.47022694775036405, + "loss_biodist": 1.0964353680610657, + "loss_toxic": 0.15038717218807765, + "acc_pdi": 0.7076923076923077, + "acc_ee": 0.4205128205128205, + "acc_toxic": 1.0 + }, + { + "loss": 3.914861168180193, + "loss_size": 0.23712052990283286, + "loss_pdi": 0.739460038287299, + "loss_ee": 1.307591336114066, + "loss_delivery": 0.48760053728307995, + "loss_biodist": 1.0441895723342896, + "loss_toxic": 0.09889917501381465, + "acc_pdi": 0.7076923076923077, + "acc_ee": 0.4205128205128205, + "acc_toxic": 1.0 + }, + { + "loss": 3.827185903276716, + "loss_size": 0.19252057054213115, + "loss_pdi": 0.6754980896200452, + "loss_ee": 1.3178976603916712, + "loss_delivery": 0.52304607629776, + "loss_biodist": 1.0516272272382463, + "loss_toxic": 0.06659629621676036, + "acc_pdi": 0.7076923076923077, + "acc_ee": 0.4205128205128205, + "acc_toxic": 1.0 + }, + { + "loss": 3.7409281730651855, + "loss_size": 0.19949970713683537, + "loss_pdi": 0.63478513274874, + "loss_ee": 1.3695382390703474, + "loss_delivery": 0.5511368151221957, + "loss_biodist": 0.9442958320890155, + "loss_toxic": 0.04167233567152705, + "acc_pdi": 0.7076923076923077, + "acc_ee": 0.4205128205128205, + "acc_toxic": 1.0 + }, + { + "loss": 4.04153094972883, + "loss_size": 0.18925889155694417, + "loss_pdi": 0.6154808104038239, + "loss_ee": 1.4497678790773665, + "loss_delivery": 0.6580333965165275, + "loss_biodist": 1.0802616902760096, + "loss_toxic": 0.04872834203498704, + "acc_pdi": 0.7076923076923077, + "acc_ee": 0.4205128205128205, + "acc_toxic": 1.0 + }, + { + "loss": 4.0117951801845, + "loss_size": 0.261259101331234, + "loss_pdi": 0.6064834722450801, + "loss_ee": 1.4382908344268799, + "loss_delivery": 0.6485261533941541, + "loss_biodist": 1.013489259140832, + "loss_toxic": 0.04374648204871586, + "acc_pdi": 0.7076923076923077, + "acc_ee": 0.4205128205128205, + "acc_toxic": 1.0 + }, + { + "loss": 4.021411214556013, + "loss_size": 0.25430602048124584, + "loss_pdi": 0.6047579922846386, + "loss_ee": 1.4375121252877372, + "loss_delivery": 0.794183360678809, + "loss_biodist": 0.8880592542035239, + "loss_toxic": 0.04259242117404938, + "acc_pdi": 0.7076923076923077, + "acc_ee": 0.41025641025641024, + "acc_toxic": 1.0 + }, + { + "loss": 4.389863218579974, + "loss_size": 0.47541433785642895, + "loss_pdi": 0.6053804946797234, + "loss_ee": 1.5225199971880232, + "loss_delivery": 0.864760024206979, + "loss_biodist": 0.8813395031860897, + "loss_toxic": 0.04044891867254462, + "acc_pdi": 0.7076923076923077, + "acc_ee": 0.4205128205128205, + "acc_toxic": 1.0 + }, + { + "loss": 4.2736298356737406, + "loss_size": 0.4230478884918349, + "loss_pdi": 0.6034957447222301, + "loss_ee": 1.4814096518925257, + "loss_delivery": 0.9466098589556557, + "loss_biodist": 0.7808983538831983, + "loss_toxic": 0.03816834863807474, + "acc_pdi": 0.7076923076923077, + "acc_ee": 0.4, + "acc_toxic": 1.0 + }, + { + "loss": 4.585977554321289, + "loss_size": 0.4787849709391594, + "loss_pdi": 0.6088685957448823, + "loss_ee": 1.5371792827333723, + "loss_delivery": 1.1909210256167821, + "loss_biodist": 0.7510551980563572, + "loss_toxic": 0.01916840286659343, + "acc_pdi": 0.7076923076923077, + "acc_ee": 0.4205128205128205, + "acc_toxic": 1.0 + }, + { + "loss": 4.317904029573713, + "loss_size": 0.46495475407157627, + "loss_pdi": 0.6024806403688022, + "loss_ee": 1.5152796847479684, + "loss_delivery": 1.0274277882916587, + "loss_biodist": 0.6894632875919342, + "loss_toxic": 0.018297837914100716, + "acc_pdi": 0.7076923076923077, + "acc_ee": 0.4205128205128205, + "acc_toxic": 1.0 + }, + { + "loss": 4.287515469959804, + "loss_size": 0.5376717469521931, + "loss_pdi": 0.5981708924685206, + "loss_ee": 1.4651154450007848, + "loss_delivery": 1.0260179340839386, + "loss_biodist": 0.6434041815144675, + "loss_toxic": 0.01713526714593172, + "acc_pdi": 0.7076923076923077, + "acc_ee": 0.38974358974358975, + "acc_toxic": 1.0 + }, + { + "loss": 4.683320045471191, + "loss_size": 0.5271316319704056, + "loss_pdi": 0.6275609007903508, + "loss_ee": 1.5593642677579607, + "loss_delivery": 1.2529506555625372, + "loss_biodist": 0.7023919905935015, + "loss_toxic": 0.013920623875622238, + "acc_pdi": 0.7076923076923077, + "acc_ee": 0.4205128205128205, + "acc_toxic": 1.0 + }, + { + "loss": 4.397426230566842, + "loss_size": 0.5106346564633506, + "loss_pdi": 0.6181366369128227, + "loss_ee": 1.5676180635179793, + "loss_delivery": 1.0354454517364502, + "loss_biodist": 0.6522124835423061, + "loss_toxic": 0.013378978440804141, + "acc_pdi": 0.7076923076923077, + "acc_ee": 0.41025641025641024, + "acc_toxic": 1.0 + }, + { + "loss": 4.412310838699341, + "loss_size": 0.5362906115395683, + "loss_pdi": 0.6130111419728824, + "loss_ee": 1.5929286650248937, + "loss_delivery": 1.001524874142238, + "loss_biodist": 0.6563895855631147, + "loss_toxic": 0.012165952806494065, + "acc_pdi": 0.7076923076923077, + "acc_ee": 0.4205128205128205, + "acc_toxic": 1.0 + }, + { + "loss": 4.78386572429112, + "loss_size": 0.5134438020842416, + "loss_pdi": 0.6321895952735629, + "loss_ee": 1.5953963484082903, + "loss_delivery": 1.3276494145393372, + "loss_biodist": 0.7046042127268655, + "loss_toxic": 0.010582369419613056, + "acc_pdi": 0.7076923076923077, + "acc_ee": 0.4205128205128205, + "acc_toxic": 1.0 + }, + { + "loss": 4.835996287209647, + "loss_size": 0.4968718098742621, + "loss_pdi": 0.6234481834939548, + "loss_ee": 1.5689009257725306, + "loss_delivery": 1.4453607542174203, + "loss_biodist": 0.6902376328195844, + "loss_toxic": 0.011176949766065394, + "acc_pdi": 0.7076923076923077, + "acc_ee": 0.41025641025641024, + "acc_toxic": 1.0 + }, + { + "loss": 4.646918603352138, + "loss_size": 0.49231292733124327, + "loss_pdi": 0.6135136048708644, + "loss_ee": 1.5504994562694006, + "loss_delivery": 1.3212553603308541, + "loss_biodist": 0.6578856153147561, + "loss_toxic": 0.01145164788301502, + "acc_pdi": 0.7076923076923077, + "acc_ee": 0.40512820512820513, + "acc_toxic": 1.0 + }, + { + "loss": 4.580311230250767, + "loss_size": 0.529256243790899, + "loss_pdi": 0.6140108300106866, + "loss_ee": 1.5573444025857108, + "loss_delivery": 1.2449579962662287, + "loss_biodist": 0.6246597383703504, + "loss_toxic": 0.01008199481293559, + "acc_pdi": 0.7076923076923077, + "acc_ee": 0.41025641025641024, + "acc_toxic": 1.0 + } + ] +} \ No newline at end of file diff --git a/models/finetune_cv/fold_2/model.pt b/models/finetune_cv/fold_2/model.pt new file mode 100644 index 0000000..8b04371 Binary files /dev/null and b/models/finetune_cv/fold_2/model.pt differ diff --git a/models/finetune_cv/fold_3/history.json b/models/finetune_cv/fold_3/history.json new file mode 100644 index 0000000..514aa3e --- /dev/null +++ b/models/finetune_cv/fold_3/history.json @@ -0,0 +1,573 @@ +{ + "train": [ + { + "loss": 21.41742353439331, + "loss_size": 15.701433181762695, + "loss_pdi": 1.3632826924324035, + "loss_ee": 1.0266466200351716, + "loss_delivery": 1.418760311603546, + "loss_biodist": 1.2302423000335694, + "loss_toxic": 0.6770588457584381 + }, + { + "loss": 9.57167477607727, + "loss_size": 4.903112936019897, + "loss_pdi": 1.140204393863678, + "loss_ee": 1.0074746072292329, + "loss_delivery": 1.1362140402197838, + "loss_biodist": 1.0133467674255372, + "loss_toxic": 0.3713219165802002 + }, + { + "loss": 5.485435938835144, + "loss_size": 1.4459647357463836, + "loss_pdi": 0.9086057424545289, + "loss_ee": 0.9852841377258301, + "loss_delivery": 1.0618216753005982, + "loss_biodist": 0.8613051772117615, + "loss_toxic": 0.22245449870824813 + }, + { + "loss": 4.065439391136169, + "loss_size": 0.45657781660556795, + "loss_pdi": 0.7768359065055848, + "loss_ee": 0.953633052110672, + "loss_delivery": 0.9996832102537155, + "loss_biodist": 0.7116856783628464, + "loss_toxic": 0.16702365390956403 + }, + { + "loss": 3.687756299972534, + "loss_size": 0.30417054146528244, + "loss_pdi": 0.7158299326896668, + "loss_ee": 0.9035742998123169, + "loss_delivery": 1.060212180018425, + "loss_biodist": 0.5556510210037231, + "loss_toxic": 0.1483182568103075 + }, + { + "loss": 3.536502242088318, + "loss_size": 0.3545849896967411, + "loss_pdi": 0.6672206580638885, + "loss_ee": 0.8711013793945312, + "loss_delivery": 1.0780200093984604, + "loss_biodist": 0.4400379478931427, + "loss_toxic": 0.12553719095885754 + }, + { + "loss": 3.289654517173767, + "loss_size": 0.3215350516140461, + "loss_pdi": 0.6445666253566742, + "loss_ee": 0.8839016199111939, + "loss_delivery": 0.9652104169130326, + "loss_biodist": 0.35911705493927004, + "loss_toxic": 0.11532373651862145 + }, + { + "loss": 3.199695384502411, + "loss_size": 0.29793725311756136, + "loss_pdi": 0.6349938422441482, + "loss_ee": 0.8508742034435273, + "loss_delivery": 0.9878654226660728, + "loss_biodist": 0.3253925606608391, + "loss_toxic": 0.10263208523392678 + }, + { + "loss": 2.981464409828186, + "loss_size": 0.2926298946142197, + "loss_pdi": 0.6041542202234268, + "loss_ee": 0.8202637135982513, + "loss_delivery": 0.9215006068348884, + "loss_biodist": 0.2569382354617119, + "loss_toxic": 0.08597766645252705 + }, + { + "loss": 2.7703840017318724, + "loss_size": 0.30560431331396104, + "loss_pdi": 0.5837061107158661, + "loss_ee": 0.7886367738246918, + "loss_delivery": 0.763620425760746, + "loss_biodist": 0.23922762870788575, + "loss_toxic": 0.08958881739526987 + }, + { + "loss": 2.6885447978973387, + "loss_size": 0.28905282765626905, + "loss_pdi": 0.5664507508277893, + "loss_ee": 0.7639070689678192, + "loss_delivery": 0.7704478114843368, + "loss_biodist": 0.21979134678840637, + "loss_toxic": 0.07889490202069283 + }, + { + "loss": 2.591668117046356, + "loss_size": 0.2613472960889339, + "loss_pdi": 0.5540884166955948, + "loss_ee": 0.7465697586536407, + "loss_delivery": 0.7534924671053886, + "loss_biodist": 0.20271824076771736, + "loss_toxic": 0.07345191687345505 + }, + { + "loss": 2.482152557373047, + "loss_size": 0.2550207316875458, + "loss_pdi": 0.5377364099025727, + "loss_ee": 0.7093640863895416, + "loss_delivery": 0.715790644288063, + "loss_biodist": 0.1968037411570549, + "loss_toxic": 0.0674369728192687 + }, + { + "loss": 2.4853516697883604, + "loss_size": 0.25445577800273894, + "loss_pdi": 0.5285622417926789, + "loss_ee": 0.7045736134052276, + "loss_delivery": 0.765293450653553, + "loss_biodist": 0.16704678535461426, + "loss_toxic": 0.06541981641203165 + }, + { + "loss": 2.325911021232605, + "loss_size": 0.22311154529452323, + "loss_pdi": 0.5131498754024506, + "loss_ee": 0.6778043508529663, + "loss_delivery": 0.6851547978818416, + "loss_biodist": 0.16465196907520294, + "loss_toxic": 0.06203848272562027 + }, + { + "loss": 2.213776695728302, + "loss_size": 0.24956582188606263, + "loss_pdi": 0.48536362051963805, + "loss_ee": 0.6791646689176559, + "loss_delivery": 0.5861033886671067, + "loss_biodist": 0.15450086519122125, + "loss_toxic": 0.059078306704759595 + }, + { + "loss": 2.3086095094680785, + "loss_size": 0.20203103460371494, + "loss_pdi": 0.4988987535238266, + "loss_ee": 0.6785910665988922, + "loss_delivery": 0.7291694968938828, + "loss_biodist": 0.1462075024843216, + "loss_toxic": 0.05371163971722126 + }, + { + "loss": 2.0882336378097532, + "loss_size": 0.2287739872932434, + "loss_pdi": 0.5072675496339798, + "loss_ee": 0.6701794564723969, + "loss_delivery": 0.4832353606820107, + "loss_biodist": 0.1421804867684841, + "loss_toxic": 0.05659680655226111 + }, + { + "loss": 2.0332674741744996, + "loss_size": 0.22676872164011003, + "loss_pdi": 0.4669553279876709, + "loss_ee": 0.6482236534357071, + "loss_delivery": 0.5079896807670593, + "loss_biodist": 0.13781072497367858, + "loss_toxic": 0.04551935400813818 + }, + { + "loss": 1.9842296838760376, + "loss_size": 0.2007827118039131, + "loss_pdi": 0.4655294865369797, + "loss_ee": 0.6295293152332306, + "loss_delivery": 0.5032630048692226, + "loss_biodist": 0.12692373394966125, + "loss_toxic": 0.05820149295032025 + }, + { + "loss": 1.9574703454971314, + "loss_size": 0.20575413331389428, + "loss_pdi": 0.4662397414445877, + "loss_ee": 0.6207350552082062, + "loss_delivery": 0.4868774816393852, + "loss_biodist": 0.1334032252430916, + "loss_toxic": 0.04446075968444348 + }, + { + "loss": 1.8693695425987245, + "loss_size": 0.2003278151154518, + "loss_pdi": 0.4481669098138809, + "loss_ee": 0.6228156566619873, + "loss_delivery": 0.4377666234970093, + "loss_biodist": 0.1222815040498972, + "loss_toxic": 0.03801098903641105 + }, + { + "loss": 1.9101393103599549, + "loss_size": 0.22527543231844901, + "loss_pdi": 0.4501156389713287, + "loss_ee": 0.5992490768432617, + "loss_delivery": 0.4722588837146759, + "loss_biodist": 0.12016024515032768, + "loss_toxic": 0.04308005180209875 + }, + { + "loss": 1.8186616897583008, + "loss_size": 0.20228504091501237, + "loss_pdi": 0.4353774756193161, + "loss_ee": 0.595091101527214, + "loss_delivery": 0.42390944324433805, + "loss_biodist": 0.12379681393504142, + "loss_toxic": 0.03820184739306569 + }, + { + "loss": 1.753528320789337, + "loss_size": 0.168972497433424, + "loss_pdi": 0.4384701639413834, + "loss_ee": 0.5804536670446396, + "loss_delivery": 0.4066555552184582, + "loss_biodist": 0.12041406258940697, + "loss_toxic": 0.038562366552650926 + }, + { + "loss": 1.752294898033142, + "loss_size": 0.16247735619544984, + "loss_pdi": 0.4280344098806381, + "loss_ee": 0.5781133621931076, + "loss_delivery": 0.4378485083580017, + "loss_biodist": 0.1110315527766943, + "loss_toxic": 0.03478973265737295 + }, + { + "loss": 1.8295514345169068, + "loss_size": 0.1906499370932579, + "loss_pdi": 0.42695977091789244, + "loss_ee": 0.6036634147167206, + "loss_delivery": 0.4585780970752239, + "loss_biodist": 0.10937272906303405, + "loss_toxic": 0.04032747393939644 + } + ], + "val": [ + { + "loss": 12.90866756439209, + "loss_size": 8.614875793457031, + "loss_pdi": 1.3393473625183105, + "loss_ee": 0.7915782928466797, + "loss_delivery": 0.35056664049625397, + "loss_biodist": 1.2413995265960693, + "loss_toxic": 0.5708996057510376, + "acc_pdi": 0.09803921568627451, + "acc_ee": 0.8431372549019608, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 6.046682596206665, + "loss_size": 2.181392192840576, + "loss_pdi": 1.010448306798935, + "loss_ee": 0.7943653464317322, + "loss_delivery": 0.3778253495693207, + "loss_biodist": 1.1818514466285706, + "loss_toxic": 0.5007997751235962, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.8431372549019608, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 4.130645751953125, + "loss_size": 0.4065478593111038, + "loss_pdi": 0.7734719514846802, + "loss_ee": 0.7856209874153137, + "loss_delivery": 0.4249718487262726, + "loss_biodist": 1.1932182908058167, + "loss_toxic": 0.5468149781227112, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.8431372549019608, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 3.74097740650177, + "loss_size": 0.05154523253440857, + "loss_pdi": 0.6415763646364212, + "loss_ee": 0.7410732209682465, + "loss_delivery": 0.38766030967235565, + "loss_biodist": 1.2818186283111572, + "loss_toxic": 0.6373037025332451, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.8431372549019608, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 3.6484163999557495, + "loss_size": 0.05830332264304161, + "loss_pdi": 0.5795184522867203, + "loss_ee": 0.6541054546833038, + "loss_delivery": 0.40608713030815125, + "loss_biodist": 1.2502482235431671, + "loss_toxic": 0.7001538649201393, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.8431372549019608, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 3.8285857439041138, + "loss_size": 0.07835924997925758, + "loss_pdi": 0.5752626657485962, + "loss_ee": 0.7419937252998352, + "loss_delivery": 0.382533997297287, + "loss_biodist": 1.3310475945472717, + "loss_toxic": 0.7193883210420609, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.8431372549019608, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 3.6414456367492676, + "loss_size": 0.07554847374558449, + "loss_pdi": 0.6094752848148346, + "loss_ee": 0.7212981283664703, + "loss_delivery": 0.4330318123102188, + "loss_biodist": 1.1887046694755554, + "loss_toxic": 0.6133871823549271, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.8431372549019608, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 3.537736415863037, + "loss_size": 0.10301502048969269, + "loss_pdi": 0.5892050117254257, + "loss_ee": 0.6979120671749115, + "loss_delivery": 0.5159651935100555, + "loss_biodist": 1.064111590385437, + "loss_toxic": 0.5675273537635803, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.8431372549019608, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 3.6027194261550903, + "loss_size": 0.16463171318173409, + "loss_pdi": 0.5593049973249435, + "loss_ee": 0.6619580686092377, + "loss_delivery": 0.538649171590805, + "loss_biodist": 1.0742176473140717, + "loss_toxic": 0.6039578504860401, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.8431372549019608, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 3.762178421020508, + "loss_size": 0.22026687115430832, + "loss_pdi": 0.5431422591209412, + "loss_ee": 0.7102161943912506, + "loss_delivery": 0.44816476106643677, + "loss_biodist": 1.172889918088913, + "loss_toxic": 0.6674983687698841, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.8431372549019608, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 3.7515294551849365, + "loss_size": 0.2188272960484028, + "loss_pdi": 0.5569919049739838, + "loss_ee": 0.6500461399555206, + "loss_delivery": 0.4802835136651993, + "loss_biodist": 1.140358328819275, + "loss_toxic": 0.7050221972167492, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.8431372549019608, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 3.47284197807312, + "loss_size": 0.19455180317163467, + "loss_pdi": 0.5266519337892532, + "loss_ee": 0.6175068914890289, + "loss_delivery": 0.5526535362005234, + "loss_biodist": 0.9517507255077362, + "loss_toxic": 0.6297270879149437, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.8431372549019608, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 3.9374399185180664, + "loss_size": 0.32401855289936066, + "loss_pdi": 0.5553185790777206, + "loss_ee": 0.6632668077945709, + "loss_delivery": 0.48028646409511566, + "loss_biodist": 1.1518925726413727, + "loss_toxic": 0.7626568526029587, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.8431372549019608, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 3.7636899948120117, + "loss_size": 0.2452084794640541, + "loss_pdi": 0.5189685672521591, + "loss_ee": 0.6509725153446198, + "loss_delivery": 0.38138364255428314, + "loss_biodist": 1.1877531707286835, + "loss_toxic": 0.7794036716222763, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.8235294117647058, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 3.837575674057007, + "loss_size": 0.3500683009624481, + "loss_pdi": 0.5655115246772766, + "loss_ee": 0.6301239728927612, + "loss_delivery": 0.47451916337013245, + "loss_biodist": 1.069214403629303, + "loss_toxic": 0.7481381297111511, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.8431372549019608, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 3.834155559539795, + "loss_size": 0.27750423550605774, + "loss_pdi": 0.513676255941391, + "loss_ee": 0.6291456520557404, + "loss_delivery": 0.40621335804462433, + "loss_biodist": 1.1680629253387451, + "loss_toxic": 0.8395530804991722, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.8235294117647058, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 3.6893343925476074, + "loss_size": 0.34515636414289474, + "loss_pdi": 0.5478167533874512, + "loss_ee": 0.6345725357532501, + "loss_delivery": 0.42669548094272614, + "loss_biodist": 1.0210922062397003, + "loss_toxic": 0.7140011191368103, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.803921568627451, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 4.00758171081543, + "loss_size": 0.3066108226776123, + "loss_pdi": 0.5529182702302933, + "loss_ee": 0.6956813335418701, + "loss_delivery": 0.4755028784275055, + "loss_biodist": 1.152823954820633, + "loss_toxic": 0.8240445479750633, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.803921568627451, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 3.974515676498413, + "loss_size": 0.3346070274710655, + "loss_pdi": 0.5447590947151184, + "loss_ee": 0.6450685858726501, + "loss_delivery": 0.4816073626279831, + "loss_biodist": 1.1390976309776306, + "loss_toxic": 0.8293759152293205, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.803921568627451, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 3.880821943283081, + "loss_size": 0.282451294362545, + "loss_pdi": 0.5469185262918472, + "loss_ee": 0.6417804062366486, + "loss_delivery": 0.46756358444690704, + "loss_biodist": 1.127053290605545, + "loss_toxic": 0.8150547966361046, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.803921568627451, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 3.8784565925598145, + "loss_size": 0.2880236804485321, + "loss_pdi": 0.5311962813138962, + "loss_ee": 0.6260144710540771, + "loss_delivery": 0.4471806138753891, + "loss_biodist": 1.1587725579738617, + "loss_toxic": 0.8272688835859299, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.803921568627451, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 3.9609978199005127, + "loss_size": 0.3277027904987335, + "loss_pdi": 0.5816705077886581, + "loss_ee": 0.6382596492767334, + "loss_delivery": 0.42079347372055054, + "loss_biodist": 1.1735960245132446, + "loss_toxic": 0.8189755231142044, + "acc_pdi": 0.9019607843137255, + "acc_ee": 0.803921568627451, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 3.8955435752868652, + "loss_size": 0.28625721484422684, + "loss_pdi": 0.5423711687326431, + "loss_ee": 0.6477845013141632, + "loss_delivery": 0.43728773295879364, + "loss_biodist": 1.1806198358535767, + "loss_toxic": 0.801223024725914, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.803921568627451, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 4.007414698600769, + "loss_size": 0.3294255882501602, + "loss_pdi": 0.5471038520336151, + "loss_ee": 0.6502591967582703, + "loss_delivery": 0.4436507970094681, + "loss_biodist": 1.202763706445694, + "loss_toxic": 0.8342117220163345, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.803921568627451, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 4.050978183746338, + "loss_size": 0.39942242205142975, + "loss_pdi": 0.5590354651212692, + "loss_ee": 0.6409733295440674, + "loss_delivery": 0.45262810587882996, + "loss_biodist": 1.1648951172828674, + "loss_toxic": 0.8340234756469727, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.803921568627451, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 3.86995267868042, + "loss_size": 0.2950260043144226, + "loss_pdi": 0.5328812152147293, + "loss_ee": 0.6263198852539062, + "loss_delivery": 0.45980168879032135, + "loss_biodist": 1.1475371420383453, + "loss_toxic": 0.808386467397213, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.803921568627451, + "acc_toxic": 0.851063829787234 + }, + { + "loss": 3.8100762367248535, + "loss_size": 0.2826509475708008, + "loss_pdi": 0.5344351083040237, + "loss_ee": 0.6322177648544312, + "loss_delivery": 0.4711516797542572, + "loss_biodist": 1.1184114813804626, + "loss_toxic": 0.7712092772126198, + "acc_pdi": 0.8823529411764706, + "acc_ee": 0.803921568627451, + "acc_toxic": 0.851063829787234 + } + ] +} \ No newline at end of file diff --git a/models/finetune_cv/fold_3/model.pt b/models/finetune_cv/fold_3/model.pt new file mode 100644 index 0000000..a55c2e7 Binary files /dev/null and b/models/finetune_cv/fold_3/model.pt differ diff --git a/models/finetune_cv/fold_4/history.json b/models/finetune_cv/fold_4/history.json new file mode 100644 index 0000000..a4cd136 --- /dev/null +++ b/models/finetune_cv/fold_4/history.json @@ -0,0 +1,405 @@ +{ + "train": [ + { + "loss": 17.82338151064786, + "loss_size": 12.698930350216953, + "loss_pdi": 1.2952325452457776, + "loss_ee": 1.044789281758395, + "loss_delivery": 0.894258591261777, + "loss_biodist": 1.3406665433536877, + "loss_toxic": 0.5495036569508639 + }, + { + "loss": 6.6719701940363105, + "loss_size": 2.361687348647551, + "loss_pdi": 0.9808245301246643, + "loss_ee": 0.9910268241708929, + "loss_delivery": 0.8210993910377676, + "loss_biodist": 1.184872735630382, + "loss_toxic": 0.3324593657797033 + }, + { + "loss": 4.197543209249323, + "loss_size": 0.3305926668373021, + "loss_pdi": 0.8021349581805143, + "loss_ee": 0.9533808014609597, + "loss_delivery": 0.85791384361007, + "loss_biodist": 1.0029216788031838, + "loss_toxic": 0.2505992312322963 + }, + { + "loss": 3.817467537793246, + "loss_size": 0.3109576465053992, + "loss_pdi": 0.7077692205255682, + "loss_ee": 0.9259536266326904, + "loss_delivery": 0.8590928573500026, + "loss_biodist": 0.7878239534117959, + "loss_toxic": 0.22587016597390175 + }, + { + "loss": 3.426318342035467, + "loss_size": 0.30332070047205145, + "loss_pdi": 0.6738776402039961, + "loss_ee": 0.8633853034539656, + "loss_delivery": 0.810848053206097, + "loss_biodist": 0.6030096492984078, + "loss_toxic": 0.17187697711316022 + }, + { + "loss": 3.243085037578236, + "loss_size": 0.2643032656474547, + "loss_pdi": 0.680456202138554, + "loss_ee": 0.8441235260529951, + "loss_delivery": 0.790125925432552, + "loss_biodist": 0.4977728453549472, + "loss_toxic": 0.16630326211452484 + }, + { + "loss": 3.0106391689994116, + "loss_size": 0.2923726256598126, + "loss_pdi": 0.6462291912599043, + "loss_ee": 0.8259676423939791, + "loss_delivery": 0.7123563428494063, + "loss_biodist": 0.4204860031604767, + "loss_toxic": 0.11322732371362773 + }, + { + "loss": 2.6575805924155493, + "loss_size": 0.23369696668603204, + "loss_pdi": 0.5998414809053595, + "loss_ee": 0.7827209342609752, + "loss_delivery": 0.6251701191067696, + "loss_biodist": 0.324639000675895, + "loss_toxic": 0.09151209010319276 + }, + { + "loss": 2.6034729263999243, + "loss_size": 0.20161977207118814, + "loss_pdi": 0.5999934835867449, + "loss_ee": 0.7699762636964972, + "loss_delivery": 0.6784080802039667, + "loss_biodist": 0.2691292708570307, + "loss_toxic": 0.08434606292708353 + }, + { + "loss": 2.481573982672258, + "loss_size": 0.2506237829273397, + "loss_pdi": 0.5553325793959878, + "loss_ee": 0.702888163653287, + "loss_delivery": 0.6528384658423337, + "loss_biodist": 0.25332851166074927, + "loss_toxic": 0.06656241179867224 + }, + { + "loss": 2.328354239463806, + "loss_size": 0.20882186158136887, + "loss_pdi": 0.5448922569101508, + "loss_ee": 0.7051796858960931, + "loss_delivery": 0.5663092088970271, + "loss_biodist": 0.23630927367643875, + "loss_toxic": 0.06684201947328719 + }, + { + "loss": 2.1841621182181616, + "loss_size": 0.20124886184930801, + "loss_pdi": 0.5195452462543141, + "loss_ee": 0.6754790571602908, + "loss_delivery": 0.5004782859574665, + "loss_biodist": 0.2269973118196834, + "loss_toxic": 0.06041339209133929 + }, + { + "loss": 2.218748081814159, + "loss_size": 0.1988528309897943, + "loss_pdi": 0.5231576589020815, + "loss_ee": 0.6777869246222756, + "loss_delivery": 0.5498508228497072, + "loss_biodist": 0.20798503336581317, + "loss_toxic": 0.06111483716151931 + }, + { + "loss": 2.1877676248550415, + "loss_size": 0.20004721256819638, + "loss_pdi": 0.5174875286492434, + "loss_ee": 0.6804409514773976, + "loss_delivery": 0.5348355831070379, + "loss_biodist": 0.2000916667959907, + "loss_toxic": 0.054864687167785385 + }, + { + "loss": 2.0138474811207163, + "loss_size": 0.19752372259443457, + "loss_pdi": 0.4868703836744482, + "loss_ee": 0.6450781768018549, + "loss_delivery": 0.44826274839314545, + "loss_biodist": 0.18033211339603772, + "loss_toxic": 0.05578036225316199 + }, + { + "loss": 2.011049357327548, + "loss_size": 0.18426605652679096, + "loss_pdi": 0.4944283068180084, + "loss_ee": 0.6483827070756392, + "loss_delivery": 0.4382695244117217, + "loss_biodist": 0.19008470394394614, + "loss_toxic": 0.05561802129853855 + }, + { + "loss": 2.0496722134676846, + "loss_size": 0.18194164742122998, + "loss_pdi": 0.4917480132796548, + "loss_ee": 0.6405326967889612, + "loss_delivery": 0.5158521959727461, + "loss_biodist": 0.1772592243823138, + "loss_toxic": 0.042338407937098636 + }, + { + "loss": 1.9639968113465742, + "loss_size": 0.1595074331218546, + "loss_pdi": 0.48567668145353143, + "loss_ee": 0.630230272358114, + "loss_delivery": 0.4618720757690343, + "loss_biodist": 0.17742999304424634, + "loss_toxic": 0.0492804107171568 + }, + { + "loss": 1.9399811571294612, + "loss_size": 0.18978103656660428, + "loss_pdi": 0.48150675947015936, + "loss_ee": 0.6243780959736217, + "loss_delivery": 0.4384633018211885, + "loss_biodist": 0.16038654270497235, + "loss_toxic": 0.045465379022061825 + } + ], + "val": [ + { + "loss": 8.991046905517578, + "loss_size": 4.490478515625, + "loss_pdi": 0.9734575947125753, + "loss_ee": 0.9441032012303671, + "loss_delivery": 1.2459152142206829, + "loss_biodist": 0.9957353870073954, + "loss_toxic": 0.34135735034942627, + "acc_pdi": 0.8484848484848485, + "acc_ee": 0.8333333333333334, + "acc_toxic": 1.0 + }, + { + "loss": 3.4952890078226724, + "loss_size": 0.27914075056711835, + "loss_pdi": 0.6866898337999979, + "loss_ee": 0.7263152599334717, + "loss_delivery": 0.7405761281649271, + "loss_biodist": 0.9172844588756561, + "loss_toxic": 0.14528251190980276, + "acc_pdi": 0.8484848484848485, + "acc_ee": 0.8333333333333334, + "acc_toxic": 1.0 + }, + { + "loss": 2.8096923430760703, + "loss_size": 0.04479753350218137, + "loss_pdi": 0.48064420620600384, + "loss_ee": 0.6685001452763876, + "loss_delivery": 0.7262534300486246, + "loss_biodist": 0.827881266673406, + "loss_toxic": 0.06161577875415484, + "acc_pdi": 0.8484848484848485, + "acc_ee": 0.8333333333333334, + "acc_toxic": 1.0 + }, + { + "loss": 2.756531000137329, + "loss_size": 0.07622841248909633, + "loss_pdi": 0.463245431582133, + "loss_ee": 0.669317622979482, + "loss_delivery": 0.6688057581583658, + "loss_biodist": 0.8275523781776428, + "loss_toxic": 0.05138145387172699, + "acc_pdi": 0.8484848484848485, + "acc_ee": 0.8333333333333334, + "acc_toxic": 1.0 + }, + { + "loss": 2.9385202328364053, + "loss_size": 0.1038126324613889, + "loss_pdi": 0.46456684668858844, + "loss_ee": 0.705346941947937, + "loss_delivery": 0.7891722718874613, + "loss_biodist": 0.82027334968249, + "loss_toxic": 0.055348185201485954, + "acc_pdi": 0.8484848484848485, + "acc_ee": 0.8333333333333334, + "acc_toxic": 1.0 + }, + { + "loss": 2.8632373809814453, + "loss_size": 0.056236049781243004, + "loss_pdi": 0.44324543078740436, + "loss_ee": 0.6018350621064504, + "loss_delivery": 0.9420756896336874, + "loss_biodist": 0.755872001250585, + "loss_toxic": 0.06397318094968796, + "acc_pdi": 0.8484848484848485, + "acc_ee": 0.8333333333333334, + "acc_toxic": 1.0 + }, + { + "loss": 2.993934949239095, + "loss_size": 0.08202650646368663, + "loss_pdi": 0.508717954158783, + "loss_ee": 0.7680216828982035, + "loss_delivery": 0.8612345655759176, + "loss_biodist": 0.7118468681971232, + "loss_toxic": 0.0620873523876071, + "acc_pdi": 0.8484848484848485, + "acc_ee": 0.48484848484848486, + "acc_toxic": 1.0 + }, + { + "loss": 2.8506224155426025, + "loss_size": 0.03146359619374076, + "loss_pdi": 0.5297882954279581, + "loss_ee": 0.7625004649162292, + "loss_delivery": 0.802927960952123, + "loss_biodist": 0.6721820036570231, + "loss_toxic": 0.05176017774889866, + "acc_pdi": 0.8333333333333334, + "acc_ee": 0.4696969696969697, + "acc_toxic": 1.0 + }, + { + "loss": 2.8424178759256997, + "loss_size": 0.08800932268301646, + "loss_pdi": 0.540622721115748, + "loss_ee": 0.839806874593099, + "loss_delivery": 0.6362804472446442, + "loss_biodist": 0.6938393215338389, + "loss_toxic": 0.043859192014982305, + "acc_pdi": 0.8181818181818182, + "acc_ee": 0.4696969696969697, + "acc_toxic": 1.0 + }, + { + "loss": 2.9927186171213784, + "loss_size": 0.04866213048808277, + "loss_pdi": 0.5090581774711609, + "loss_ee": 0.8101900815963745, + "loss_delivery": 0.8959137797355652, + "loss_biodist": 0.697561984260877, + "loss_toxic": 0.031332316963622965, + "acc_pdi": 0.8333333333333334, + "acc_ee": 0.3787878787878788, + "acc_toxic": 1.0 + }, + { + "loss": 2.9314780632654824, + "loss_size": 0.06920161470770836, + "loss_pdi": 0.5283850431442261, + "loss_ee": 0.8220112522443136, + "loss_delivery": 0.7353424032529196, + "loss_biodist": 0.7471547623475393, + "loss_toxic": 0.029382963587219518, + "acc_pdi": 0.8181818181818182, + "acc_ee": 0.3787878787878788, + "acc_toxic": 1.0 + }, + { + "loss": 2.9559961954752603, + "loss_size": 0.05500282160937786, + "loss_pdi": 0.516402949889501, + "loss_ee": 0.8015920122464498, + "loss_delivery": 0.7959124445915222, + "loss_biodist": 0.7596821735302607, + "loss_toxic": 0.0274037744384259, + "acc_pdi": 0.8181818181818182, + "acc_ee": 0.3787878787878788, + "acc_toxic": 1.0 + }, + { + "loss": 2.9268057346343994, + "loss_size": 0.07737081746260326, + "loss_pdi": 0.5241717199484507, + "loss_ee": 0.835136612256368, + "loss_delivery": 0.7045041720072428, + "loss_biodist": 0.7396295169989268, + "loss_toxic": 0.0459929151305308, + "acc_pdi": 0.8181818181818182, + "acc_ee": 0.3787878787878788, + "acc_toxic": 1.0 + }, + { + "loss": 2.9870657920837402, + "loss_size": 0.08099805439511935, + "loss_pdi": 0.5164442459742228, + "loss_ee": 0.8418577512105306, + "loss_delivery": 0.7609673937161764, + "loss_biodist": 0.7428288658459982, + "loss_toxic": 0.04396952743021151, + "acc_pdi": 0.8181818181818182, + "acc_ee": 0.3787878787878788, + "acc_toxic": 1.0 + }, + { + "loss": 2.979916493097941, + "loss_size": 0.09317472080389659, + "loss_pdi": 0.514099915822347, + "loss_ee": 0.8726487557093302, + "loss_delivery": 0.6872047583262125, + "loss_biodist": 0.7711265037457148, + "loss_toxic": 0.04166170318300525, + "acc_pdi": 0.8181818181818182, + "acc_ee": 0.36363636363636365, + "acc_toxic": 1.0 + }, + { + "loss": 2.937371532122294, + "loss_size": 0.06330848174790542, + "loss_pdi": 0.5085045297940572, + "loss_ee": 0.8129515051841736, + "loss_delivery": 0.7455949584643046, + "loss_biodist": 0.7728735754887263, + "loss_toxic": 0.03413854034927984, + "acc_pdi": 0.8181818181818182, + "acc_ee": 0.3787878787878788, + "acc_toxic": 1.0 + }, + { + "loss": 3.0136941274007163, + "loss_size": 0.0849883034825325, + "loss_pdi": 0.51878755291303, + "loss_ee": 0.785295327504476, + "loss_delivery": 0.8036739627520243, + "loss_biodist": 0.7789742400248846, + "loss_toxic": 0.041974871419370174, + "acc_pdi": 0.8181818181818182, + "acc_ee": 0.3787878787878788, + "acc_toxic": 1.0 + }, + { + "loss": 2.9988585313161216, + "loss_size": 0.053379556785027184, + "loss_pdi": 0.5192934771378835, + "loss_ee": 0.838801383972168, + "loss_delivery": 0.7629345854123434, + "loss_biodist": 0.7728350758552551, + "loss_toxic": 0.05161439681736132, + "acc_pdi": 0.8181818181818182, + "acc_ee": 0.36363636363636365, + "acc_toxic": 1.0 + }, + { + "loss": 2.9807325998942056, + "loss_size": 0.055528260146578155, + "loss_pdi": 0.5009754200776418, + "loss_ee": 0.8306655287742615, + "loss_delivery": 0.7751240134239197, + "loss_biodist": 0.7733635157346725, + "loss_toxic": 0.045075801705631115, + "acc_pdi": 0.8181818181818182, + "acc_ee": 0.36363636363636365, + "acc_toxic": 1.0 + } + ] +} \ No newline at end of file diff --git a/models/finetune_cv/fold_4/model.pt b/models/finetune_cv/fold_4/model.pt new file mode 100644 index 0000000..a8c8de6 Binary files /dev/null and b/models/finetune_cv/fold_4/model.pt differ diff --git a/models/finetune_cv/test_results.json b/models/finetune_cv/test_results.json new file mode 100644 index 0000000..aa9bc22 --- /dev/null +++ b/models/finetune_cv/test_results.json @@ -0,0 +1,198 @@ +{ + "fold_results": [ + { + "fold_idx": 0, + "n_samples": 95, + "size": { + "n": 95, + "rmse": 0.632240295394253, + "mae": 0.4674149764211554, + "r2": -0.13819694158244777 + }, + "delivery": { + "n": 66, + "rmse": 1.3553486689823926, + "mae": 0.48114709207562334, + "r2": -0.008234200686471516 + }, + "pdi": { + "n": 95, + "accuracy": 0.6105263157894737 + }, + "ee": { + "n": 95, + "accuracy": 0.6631578947368421 + }, + "toxic": { + "n": 66, + "accuracy": 0.8939393939393939 + } + }, + { + "fold_idx": 1, + "n_samples": 195, + "size": { + "n": 193, + "rmse": 0.42622752144538556, + "mae": 0.24566329575573226, + "r2": 0.0482086242002292 + }, + "delivery": { + "n": 123, + "rmse": 0.742899240869997, + "mae": 0.5315999669170507, + "r2": -0.03140039086191759 + }, + "pdi": { + "n": 195, + "accuracy": 0.7076923076923077 + }, + "ee": { + "n": 195, + "accuracy": 0.4205128205128205 + }, + "toxic": { + "n": 123, + "accuracy": 1.0 + } + }, + { + "fold_idx": 2, + "n_samples": 51, + "size": { + "n": 51, + "rmse": 0.241909571406037, + "mae": 0.20043573192521638, + "r2": -0.43487628292073 + }, + "delivery": { + "n": 44, + "rmse": 0.7564153649581582, + "mae": 0.6047130756302398, + "r2": -0.4226486727361405 + }, + "pdi": { + "n": 51, + "accuracy": 0.8823529411764706 + }, + "ee": { + "n": 51, + "accuracy": 0.8431372549019608 + }, + "toxic": { + "n": 47, + "accuracy": 0.851063829787234 + } + }, + { + "fold_idx": 3, + "n_samples": 66, + "size": { + "n": 66, + "rmse": 0.2857872773679936, + "mae": 0.22075237649859805, + "r2": -0.5674047032859011 + }, + "delivery": { + "n": 62, + "rmse": 1.0291312965402932, + "mae": 0.7422042032328224, + "r2": -0.7148264932933832 + }, + "pdi": { + "n": 66, + "accuracy": 0.8484848484848485 + }, + "ee": { + "n": 66, + "accuracy": 0.18181818181818182 + }, + "toxic": { + "n": 62, + "accuracy": 1.0 + } + }, + { + "fold_idx": 4, + "n_samples": 27, + "size": { + "n": 27, + "rmse": 0.2271495001169846, + "mae": 0.18753767013549805, + "r2": -0.19441156195074893 + }, + "delivery": { + "n": 15, + "rmse": 1.993006453768918, + "mae": 1.3779302000999452, + "r2": -0.3411461507368889 + }, + "pdi": { + "n": 27, + "accuracy": 0.8888888888888888 + }, + "ee": { + "n": 27, + "accuracy": 0.5925925925925926 + }, + "toxic": { + "n": 15, + "accuracy": 1.0 + } + } + ], + "summary_stats": { + "size": { + "rmse_mean": 0.36266283314613074, + "rmse_std": 0.15203127472757474, + "r2_mean": -0.2573361731079197, + "r2_std": 0.2187118059634264 + }, + "delivery": { + "rmse_mean": 1.1753602050239518, + "rmse_std": 0.46580283242073095, + "r2_mean": -0.30365118166296035, + "r2_std": 0.2630677092396549 + }, + "pdi": { + "accuracy_mean": 0.7875890604063979, + "accuracy_std": 0.11016791908756088 + }, + "ee": { + "accuracy_mean": 0.5402437489124795, + "accuracy_std": 0.22467627690136344 + }, + "toxic": { + "accuracy_mean": 0.9490006447453256, + "accuracy_std": 0.06391582554207781 + } + }, + "overall": { + "size": { + "n_samples": 432, + "mse": 0.19167728610863985, + "rmse": 0.43780964597486866, + "mae": 0.2816500812768936, + "r2": -0.04410163027802061 + }, + "delivery": { + "n_samples": 310, + "mse": 1.095306046771274, + "rmse": 1.0465687014101244, + "mae": 0.6143080417337197, + "r2": -0.1024184074306409 + }, + "pdi": { + "n_samples": 434, + "accuracy": 0.7396313364055299 + }, + "ee": { + "n_samples": 434, + "accuracy": 0.4976958525345622 + }, + "toxic": { + "n_samples": 313, + "accuracy": 0.9552715654952076 + } + } +} \ No newline at end of file diff --git a/scripts/process_data_cv.py b/scripts/process_data_cv.py new file mode 100644 index 0000000..7be689e --- /dev/null +++ b/scripts/process_data_cv.py @@ -0,0 +1,226 @@ +"""内部数据 Cross-Validation 划分脚本:基于 Amine 的分组划分""" + +from pathlib import Path +from typing import List + +import numpy as np +import pandas as pd +import typer +from loguru import logger + +from lnp_ml.config import INTERIM_DATA_DIR, PROCESSED_DATA_DIR +from lnp_ml.dataset import ( + process_dataframe, + SMILES_COL, + COMP_COLS, + HELP_COLS, + TARGET_REGRESSION, + TARGET_CLASSIFICATION_PDI, + TARGET_CLASSIFICATION_EE, + TARGET_TOXIC, + TARGET_BIODIST, + get_phys_cols, + get_exp_cols, +) + +app = typer.Typer() + + +def amine_based_cv_split( + df: pd.DataFrame, + n_folds: int = 5, + seed: int = 42, + amine_col: str = "Amine", +) -> List[dict]: + """ + 基于 Amine 列进行 Cross-Validation 划分。 + + 步骤: + 1. 按 amine_col 分组 + 2. 打乱分组顺序 + 3. 将分组 round-robin 分配到 n_folds 个容器 + 4. 对于每个 fold i: + - validation = container[i] + - test = container[(i+1) % n_folds] + - train = 其余所有 + + Args: + df: 输入 DataFrame + n_folds: 折数 + seed: 随机种子 + amine_col: 用于分组的列名 + + Returns: + List of dicts,每个 dict 包含 train_df, val_df, test_df + """ + # 获取唯一的 amine 并打乱 + unique_amines = df[amine_col].unique() + rng = np.random.RandomState(seed) + rng.shuffle(unique_amines) + + logger.info(f"Found {len(unique_amines)} unique amines") + + # Round-robin 分配到 n_folds 个容器 + containers = [[] for _ in range(n_folds)] + for i, amine in enumerate(unique_amines): + containers[i % n_folds].append(amine) + + # 打印每个容器的大小 + for i, container in enumerate(containers): + container_samples = df[df[amine_col].isin(container)] + logger.info(f" Container {i}: {len(container)} amines, {len(container_samples)} samples") + + # 生成每个 fold 的数据 + fold_splits = [] + for i in range(n_folds): + val_amines = set(containers[i]) + test_amines = set(containers[(i + 1) % n_folds]) + train_amines = set() + for j in range(n_folds): + if j != i and j != (i + 1) % n_folds: + train_amines.update(containers[j]) + + train_df = df[df[amine_col].isin(train_amines)].reset_index(drop=True) + val_df = df[df[amine_col].isin(val_amines)].reset_index(drop=True) + test_df = df[df[amine_col].isin(test_amines)].reset_index(drop=True) + + fold_splits.append({ + "train": train_df, + "val": val_df, + "test": test_df, + }) + + logger.info( + f"Fold {i}: train={len(train_df)} ({len(train_amines)} amines), " + f"val={len(val_df)} ({len(val_amines)} amines), " + f"test={len(test_df)} ({len(test_amines)} amines)" + ) + + return fold_splits + + +@app.command() +def main( + input_path: Path = INTERIM_DATA_DIR / "internal_corrected.csv", + output_dir: Path = PROCESSED_DATA_DIR / "cv", + n_folds: int = 5, + seed: int = 42, + amine_col: str = "Amine", +): + """ + 基于 Amine 分组进行 Cross-Validation 数据划分。 + + 采用类似 scaffold splitting 的思路,将相同 Amine 的数据放在同一组, + 确保训练集和测试集之间没有 Amine 泄露。 + + 划分比例约为 train:val:test ≈ 3:1:1 + + 输出结构: + - processed/cv/fold_0/train.parquet + - processed/cv/fold_0/val.parquet + - processed/cv/fold_0/test.parquet + - processed/cv/fold_1/... + - processed/cv/feature_columns.txt + """ + logger.info(f"Loading data from {input_path}") + df = pd.read_csv(input_path) + logger.info(f"Loaded {len(df)} samples") + + # 检查 amine 列是否存在 + if amine_col not in df.columns: + logger.error(f"Column '{amine_col}' not found in data. Available columns: {list(df.columns)}") + raise typer.Exit(1) + + # 处理数据(列对齐、one-hot 生成等) + logger.info("Processing dataframe...") + df = process_dataframe(df) + + # 确保 Amine 列仍然存在(process_dataframe 可能不会保留它) + # 重新加载原始数据获取 Amine 列 + original_df = pd.read_csv(input_path) + if amine_col in original_df.columns and amine_col not in df.columns: + df[amine_col] = original_df[amine_col].values + + # 定义要保留的列 + phys_cols = get_phys_cols() + exp_cols = get_exp_cols() + + keep_cols = ( + [SMILES_COL] + + COMP_COLS + + phys_cols + + HELP_COLS + + exp_cols + + TARGET_REGRESSION + + TARGET_CLASSIFICATION_PDI + + TARGET_CLASSIFICATION_EE + + [TARGET_TOXIC] + + TARGET_BIODIST + ) + + # 只保留存在的列 + keep_cols = [c for c in keep_cols if c in df.columns] + + # 进行 CV 划分 + logger.info(f"\nPerforming {n_folds}-fold amine-based CV split (seed={seed})...") + fold_splits = amine_based_cv_split(df, n_folds=n_folds, seed=seed, amine_col=amine_col) + + # 保存每个 fold + output_dir.mkdir(parents=True, exist_ok=True) + + for i, split in enumerate(fold_splits): + fold_dir = output_dir / f"fold_{i}" + fold_dir.mkdir(parents=True, exist_ok=True) + + # 只保留需要的列 + train_df = split["train"][keep_cols].reset_index(drop=True) + val_df = split["val"][keep_cols].reset_index(drop=True) + test_df = split["test"][keep_cols].reset_index(drop=True) + + # 保存 + train_df.to_parquet(fold_dir / "train.parquet", index=False) + val_df.to_parquet(fold_dir / "val.parquet", index=False) + test_df.to_parquet(fold_dir / "test.parquet", index=False) + + logger.success(f"Saved fold {i} to {fold_dir}") + + # 保存列名配置 + config_path = output_dir / "feature_columns.txt" + with open(config_path, "w") as f: + f.write("# Feature columns configuration\n\n") + f.write(f"# SMILES\n{SMILES_COL}\n\n") + f.write(f"# comp token [{len(COMP_COLS)}]\n") + f.write("\n".join(COMP_COLS) + "\n\n") + f.write(f"# phys token [{len(phys_cols)}]\n") + f.write("\n".join(phys_cols) + "\n\n") + f.write(f"# help token [{len(HELP_COLS)}]\n") + f.write("\n".join(HELP_COLS) + "\n\n") + f.write(f"# exp token [{len(exp_cols)}]\n") + f.write("\n".join(exp_cols) + "\n\n") + f.write("# Targets\n") + f.write("## Regression\n") + f.write("\n".join(TARGET_REGRESSION) + "\n") + f.write("## PDI classification\n") + f.write("\n".join(TARGET_CLASSIFICATION_PDI) + "\n") + f.write("## EE classification\n") + f.write("\n".join(TARGET_CLASSIFICATION_EE) + "\n") + f.write("## Toxic\n") + f.write(f"{TARGET_TOXIC}\n") + f.write("## Biodistribution\n") + f.write("\n".join(TARGET_BIODIST) + "\n") + + logger.success(f"Saved feature config to {config_path}") + + # 打印汇总 + logger.info("\n" + "=" * 60) + logger.info("CV DATA PROCESSING COMPLETE") + logger.info("=" * 60) + logger.info(f"Output directory: {output_dir}") + logger.info(f"Number of folds: {n_folds}") + logger.info(f"Splitting method: Amine-based (column: {amine_col})") + logger.info(f"Random seed: {seed}") + + +if __name__ == "__main__": + app() + diff --git a/scripts/process_external_cv.py b/scripts/process_external_cv.py index 9f414b1..e0cfb14 100644 --- a/scripts/process_external_cv.py +++ b/scripts/process_external_cv.py @@ -151,18 +151,18 @@ def get_feature_columns() -> List[str]: @app.command() def main( data_dir: Path = EXTERNAL_DATA_DIR / "all_amine_split_for_LiON", - output_dir: Path = PROCESSED_DATA_DIR / "cv", + output_dir: Path = PROCESSED_DATA_DIR / "pretrain_cv", n_folds: int = 5, ): """ 处理 cross-validation 数据,生成模型所需的 parquet 文件。 输出结构: - - processed/cv/fold_0/train.parquet - - processed/cv/fold_0/valid.parquet - - processed/cv/fold_0/test.parquet - - processed/cv/fold_1/... - - processed/cv/feature_columns.txt + - processed/pretrain_cv/fold_0/train.parquet + - processed/pretrain_cv/fold_0/valid.parquet + - processed/pretrain_cv/fold_0/test.parquet + - processed/pretrain_cv/fold_1/... + - processed/pretrain_cv/feature_columns.txt """ logger.info(f"Processing CV data from {data_dir}")