mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 09:36:32 +08:00
Add CV
This commit is contained in:
parent
6773929ea2
commit
a2bfb26dfc
15
Makefile
15
Makefile
@ -73,6 +73,11 @@ data: requirements
|
|||||||
data_pretrain: requirements
|
data_pretrain: requirements
|
||||||
$(PYTHON_INTERPRETER) scripts/process_external.py
|
$(PYTHON_INTERPRETER) scripts/process_external.py
|
||||||
|
|
||||||
|
## Process CV data for cross-validation pretrain (external/all_amine_split_for_LiON -> processed/cv)
|
||||||
|
.PHONY: data_cv
|
||||||
|
data_cv: requirements
|
||||||
|
$(PYTHON_INTERPRETER) scripts/process_data_cv.py
|
||||||
|
|
||||||
# MPNN 支持:使用 USE_MPNN=1 启用 MPNN encoder
|
# MPNN 支持:使用 USE_MPNN=1 启用 MPNN encoder
|
||||||
# 例如:make pretrain USE_MPNN=1
|
# 例如:make pretrain USE_MPNN=1
|
||||||
MPNN_FLAG = $(if $(USE_MPNN),--use-mpnn,)
|
MPNN_FLAG = $(if $(USE_MPNN),--use-mpnn,)
|
||||||
@ -91,6 +96,16 @@ pretrain: requirements
|
|||||||
test_pretrain: requirements
|
test_pretrain: requirements
|
||||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain test $(MPNN_FLAG)
|
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain test $(MPNN_FLAG)
|
||||||
|
|
||||||
|
## Pretrain with cross-validation (5-fold)
|
||||||
|
.PHONY: pretrain_cv
|
||||||
|
pretrain_cv: requirements
|
||||||
|
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv main $(MPNN_FLAG)
|
||||||
|
|
||||||
|
## Evaluate CV pretrain models on test sets
|
||||||
|
.PHONY: test_cv
|
||||||
|
test_cv: requirements
|
||||||
|
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.pretrain_cv test $(MPNN_FLAG)
|
||||||
|
|
||||||
## Train model (multi-task, from scratch)
|
## Train model (multi-task, from scratch)
|
||||||
.PHONY: train
|
.PHONY: train
|
||||||
train: requirements
|
train: requirements
|
||||||
|
|||||||
55
data/processed/cv/feature_columns.txt
Normal file
55
data/processed/cv/feature_columns.txt
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
smiles
|
||||||
|
Cationic_Lipid_to_mRNA_weight_ratio
|
||||||
|
Cationic_Lipid_Mol_Ratio
|
||||||
|
Phospholipid_Mol_Ratio
|
||||||
|
Cholesterol_Mol_Ratio
|
||||||
|
PEG_Lipid_Mol_Ratio
|
||||||
|
Purity_Pure
|
||||||
|
Purity_Crude
|
||||||
|
Mix_type_Microfluidic
|
||||||
|
Mix_type_Pipetting
|
||||||
|
Cargo_type_mRNA
|
||||||
|
Cargo_type_pDNA
|
||||||
|
Cargo_type_siRNA
|
||||||
|
Target_or_delivered_gene_FFL
|
||||||
|
Target_or_delivered_gene_Peptide_barcode
|
||||||
|
Target_or_delivered_gene_hEPO
|
||||||
|
Target_or_delivered_gene_FVII
|
||||||
|
Target_or_delivered_gene_GFP
|
||||||
|
Helper_lipid_ID_DOPE
|
||||||
|
Helper_lipid_ID_DOTAP
|
||||||
|
Helper_lipid_ID_DSPC
|
||||||
|
Helper_lipid_ID_MDOA
|
||||||
|
Model_type_A549
|
||||||
|
Model_type_BDMC
|
||||||
|
Model_type_BMDM
|
||||||
|
Model_type_HBEC_ALI
|
||||||
|
Model_type_HEK293T
|
||||||
|
Model_type_HeLa
|
||||||
|
Model_type_IGROV1
|
||||||
|
Model_type_Mouse
|
||||||
|
Model_type_RAW264p7
|
||||||
|
Delivery_target_body
|
||||||
|
Delivery_target_dendritic_cell
|
||||||
|
Delivery_target_generic_cell
|
||||||
|
Delivery_target_liver
|
||||||
|
Delivery_target_lung
|
||||||
|
Delivery_target_lung_epithelium
|
||||||
|
Delivery_target_macrophage
|
||||||
|
Delivery_target_muscle
|
||||||
|
Delivery_target_spleen
|
||||||
|
Route_of_administration_in_vitro
|
||||||
|
Route_of_administration_intramuscular
|
||||||
|
Route_of_administration_intratracheal
|
||||||
|
Route_of_administration_intravenous
|
||||||
|
Batch_or_individual_or_barcoded_Barcoded
|
||||||
|
Batch_or_individual_or_barcoded_Individual
|
||||||
|
Value_name_log_luminescence
|
||||||
|
Value_name_luminescence
|
||||||
|
Value_name_FFL_silencing
|
||||||
|
Value_name_Peptide_abundance
|
||||||
|
Value_name_hEPO
|
||||||
|
Value_name_FVII_silencing
|
||||||
|
Value_name_GFP_delivery
|
||||||
|
Value_name_Discretized_luminescence
|
||||||
|
quantified_delivery
|
||||||
BIN
data/processed/cv/fold_0/test.parquet
Normal file
BIN
data/processed/cv/fold_0/test.parquet
Normal file
Binary file not shown.
BIN
data/processed/cv/fold_0/train.parquet
Normal file
BIN
data/processed/cv/fold_0/train.parquet
Normal file
Binary file not shown.
BIN
data/processed/cv/fold_0/valid.parquet
Normal file
BIN
data/processed/cv/fold_0/valid.parquet
Normal file
Binary file not shown.
BIN
data/processed/cv/fold_1/test.parquet
Normal file
BIN
data/processed/cv/fold_1/test.parquet
Normal file
Binary file not shown.
BIN
data/processed/cv/fold_1/train.parquet
Normal file
BIN
data/processed/cv/fold_1/train.parquet
Normal file
Binary file not shown.
BIN
data/processed/cv/fold_1/valid.parquet
Normal file
BIN
data/processed/cv/fold_1/valid.parquet
Normal file
Binary file not shown.
BIN
data/processed/cv/fold_2/test.parquet
Normal file
BIN
data/processed/cv/fold_2/test.parquet
Normal file
Binary file not shown.
BIN
data/processed/cv/fold_2/train.parquet
Normal file
BIN
data/processed/cv/fold_2/train.parquet
Normal file
Binary file not shown.
BIN
data/processed/cv/fold_2/valid.parquet
Normal file
BIN
data/processed/cv/fold_2/valid.parquet
Normal file
Binary file not shown.
BIN
data/processed/cv/fold_3/test.parquet
Normal file
BIN
data/processed/cv/fold_3/test.parquet
Normal file
Binary file not shown.
BIN
data/processed/cv/fold_3/train.parquet
Normal file
BIN
data/processed/cv/fold_3/train.parquet
Normal file
Binary file not shown.
BIN
data/processed/cv/fold_3/valid.parquet
Normal file
BIN
data/processed/cv/fold_3/valid.parquet
Normal file
Binary file not shown.
BIN
data/processed/cv/fold_4/test.parquet
Normal file
BIN
data/processed/cv/fold_4/test.parquet
Normal file
Binary file not shown.
BIN
data/processed/cv/fold_4/train.parquet
Normal file
BIN
data/processed/cv/fold_4/train.parquet
Normal file
Binary file not shown.
BIN
data/processed/cv/fold_4/valid.parquet
Normal file
BIN
data/processed/cv/fold_4/valid.parquet
Normal file
Binary file not shown.
639
lnp_ml/modeling/pretrain_cv.py
Normal file
639
lnp_ml/modeling/pretrain_cv.py
Normal file
@ -0,0 +1,639 @@
|
|||||||
|
"""基于 Cross-Validation 的预训练脚本"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from loguru import logger
|
||||||
|
from tqdm import tqdm
|
||||||
|
from sklearn.metrics import mean_squared_error, r2_score
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from lnp_ml.config import MODELS_DIR, PROCESSED_DATA_DIR
|
||||||
|
from lnp_ml.dataset import ExternalDeliveryDataset, collate_fn
|
||||||
|
|
||||||
|
|
||||||
|
# Default directory searched for the MPNN ensemble checkpoints.
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
|
||||||
|
|
||||||
|
|
||||||
|
def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
    """Locate the ``model.pt`` checkpoints of an MPNN ensemble.

    Globs ``base_dir`` for files matching ``cv_*/fold_*/model_*/model.pt``.

    Args:
        base_dir: Root directory of the ensemble.

    Returns:
        The matching checkpoint paths as strings, in sorted path order.

    Raises:
        FileNotFoundError: If nothing under ``base_dir`` matches the pattern.
    """
    checkpoints = [
        str(path) for path in sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt"))
    ]
    if checkpoints:
        return checkpoints
    raise FileNotFoundError(f"No model.pt files found in {base_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
|
||||||
|
|
||||||
|
|
||||||
|
# Typer application; the functions decorated with @app.command() below are its sub-commands.
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
class EarlyStopping:
    """Patience-based early-stopping tracker for a validation loss.

    Call the instance once per epoch with the current validation loss. It
    returns ``True`` once ``patience`` consecutive calls have failed to beat
    the best loss seen so far by more than ``min_delta``.
    """

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        # Number of consecutive non-improving calls tolerated before stopping.
        self.patience = patience
        # Minimum decrease of the loss that counts as an improvement.
        self.min_delta = min_delta
        # Consecutive non-improving calls since the last improvement.
        self.counter = 0
        # Lowest loss observed so far.
        self.best_loss = float("inf")

    def __call__(self, val_loss: float) -> bool:
        """Record ``val_loss`` and report whether training should stop."""
        improved = val_loss < self.best_loss - self.min_delta
        if not improved:
            self.counter += 1
            return self.counter >= self.patience

        # An improvement resets the patience window.
        self.best_loss = val_loss
        self.counter = 0
        return False
|
||||||
|
|
||||||
|
|
||||||
|
def warmup_cache(model: nn.Module, smiles_list: List[str], batch_size: int = 256) -> None:
    """Pre-populate the model's RDKit feature cache.

    De-duplicates ``smiles_list`` and feeds the unique SMILES to
    ``model.rdkit_encoder`` in chunks of ``batch_size``. The number of entries
    in ``model.rdkit_encoder._cache`` is logged afterwards.

    Args:
        model: Model exposing a callable ``rdkit_encoder`` with a ``_cache``.
        smiles_list: SMILES strings to encode (duplicates allowed).
        batch_size: Number of SMILES passed to the encoder per call.
    """
    deduped = list(set(smiles_list))
    logger.info(f"Warming up RDKit cache for {len(deduped)} unique SMILES...")

    chunk_starts = range(0, len(deduped), batch_size)
    for start in tqdm(chunk_starts, desc="Cache warmup"):
        # The return value is discarded; the call is made for its caching side effect.
        model.rdkit_encoder(deduped[start:start + batch_size])

    cached_count = len(model.rdkit_encoder._cache)
    logger.success(f"Cache warmup complete. Cached {cached_count} SMILES.")
|
||||||
|
|
||||||
|
|
||||||
|
def train_epoch_delivery(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int = 0,
) -> Dict[str, float]:
    """Run one training epoch on the delivery task only.

    For every batch the delivery prediction is compared with the target using
    MSE over the entries selected by the batch's delivery mask. Batches whose
    mask is entirely false contribute nothing and trigger no optimizer step.
    Gradients are clipped to a max norm of 1.0 before each step.

    Args:
        model: Model exposing ``forward_delivery(smiles, tabular)``.
        loader: Yields batches with ``smiles``, ``tabular``, ``targets`` and ``mask``.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the tabular features, targets and mask are moved to.
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        ``{"loss": sample-weighted mean MSE, "n_samples": masked sample count}``.
    """
    model.train()
    running_loss = 0.0
    seen = 0

    progress = tqdm(loader, desc=f"Epoch {epoch+1} [Train]", leave=False)
    for batch in progress:
        smiles = batch["smiles"]
        tabular = {name: tensor.to(device) for name, tensor in batch["tabular"].items()}
        targets = batch["targets"]["delivery"].to(device)
        mask = batch["mask"]["delivery"].to(device)

        optimizer.zero_grad()
        pred = model.forward_delivery(smiles, tabular).squeeze(-1)

        if mask.any():
            loss = nn.functional.mse_loss(pred[mask], targets[mask])
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # Weight the batch loss by its number of valid targets so the
            # epoch average is per-sample rather than per-batch.
            n_valid = mask.sum().item()
            running_loss += loss.item() * n_valid
            seen += n_valid

        progress.set_postfix({"loss": running_loss / max(seen, 1)})

    return {"loss": running_loss / max(seen, 1), "n_samples": seen}
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def validate_delivery(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
) -> Dict[str, float]:
    """Evaluate the delivery task on a data loader without gradients.

    Args:
        model: Model exposing ``forward_delivery(smiles, tabular)``.
        loader: Yields batches with ``smiles``, ``tabular``, ``targets`` and ``mask``.
        device: Device the tabular features, targets and mask are moved to.

    Returns:
        A dict with ``loss`` (sample-weighted mean MSE) and ``n_samples``
        (number of masked-in targets). When at least one target was
        evaluated it also contains ``rmse`` and ``r2``.
    """
    model.eval()
    loss_sum = 0.0
    n_valid_total = 0
    preds_seen = []
    targets_seen = []

    for batch in loader:
        smiles = batch["smiles"]
        tabular = {name: tensor.to(device) for name, tensor in batch["tabular"].items()}
        targets = batch["targets"]["delivery"].to(device)
        mask = batch["mask"]["delivery"].to(device)

        pred = model.forward_delivery(smiles, tabular).squeeze(-1)

        # Batches without any valid delivery target contribute nothing.
        if not mask.any():
            continue

        kept_pred = pred[mask]
        kept_target = targets[mask]
        n_valid = mask.sum().item()

        loss_sum += nn.functional.mse_loss(kept_pred, kept_target).item() * n_valid
        n_valid_total += n_valid
        preds_seen.extend(kept_pred.cpu().numpy().tolist())
        targets_seen.extend(kept_target.cpu().numpy().tolist())

    metrics = {"loss": loss_sum / max(n_valid_total, 1), "n_samples": n_valid_total}

    # Extra regression metrics only make sense with at least one prediction.
    if preds_seen:
        y_pred = np.array(preds_seen)
        y_true = np.array(targets_seen)
        metrics["rmse"] = float(np.sqrt(mean_squared_error(y_true, y_pred)))
        metrics["r2"] = float(r2_score(y_true, y_pred))

    return metrics
|
||||||
|
|
||||||
|
|
||||||
|
def train_fold(
    fold_idx: int,
    train_loader: DataLoader,
    val_loader: DataLoader,
    model: nn.Module,
    device: torch.device,
    output_dir: Path,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    epochs: int = 50,
    patience: int = 10,
    config: Optional[Dict] = None,
) -> Dict:
    """Train a single cross-validation fold and persist its artefacts.

    Trains with AdamW, halves the learning rate when the validation loss
    plateaus (``ReduceLROnPlateau``, patience 5) and stops early after
    ``patience`` epochs without improvement. The weights of the epoch with
    the lowest validation loss are written to
    ``output_dir/fold_<fold_idx>/model.pt`` and the per-epoch history to
    ``history.json`` in the same directory.

    Args:
        fold_idx: Index of the fold; used for logging and the output path.
        train_loader: Loader for the fold's training split.
        val_loader: Loader for the fold's validation split.
        model: Freshly initialised model, already on ``device``.
        device: Device used for training.
        output_dir: Root directory for all fold outputs.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        epochs: Maximum number of epochs.
        patience: Early-stopping patience in epochs.
        config: Model/training configuration stored inside the checkpoint.

    Returns:
        A dict with ``fold_idx``, ``best_val_loss``, ``best_val_rmse``,
        ``best_val_r2``, ``best_epoch`` and ``epochs_trained``. The ``best_*``
        values all refer to the epoch whose weights were saved.
    """
    logger.info(f"\n{'='*60}")
    logger.info(f"Training Fold {fold_idx}")
    logger.info(f"{'='*60}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )
    early_stopping = EarlyStopping(patience=patience)

    best_val_loss = float("inf")
    best_state = None
    best_record = None  # history entry of the epoch that produced best_state
    history = []

    for epoch in range(epochs):
        train_metrics = train_epoch_delivery(model, train_loader, optimizer, device, epoch)
        val_metrics = validate_delivery(model, val_loader, device)

        current_lr = optimizer.param_groups[0]["lr"]
        logger.info(
            f"Fold {fold_idx} Epoch {epoch+1}/{epochs} | "
            f"Train Loss: {train_metrics['loss']:.4f} | "
            f"Val Loss: {val_metrics['loss']:.4f} | "
            f"Val RMSE: {val_metrics.get('rmse', 0):.4f} | "
            f"Val R²: {val_metrics.get('r2', 0):.4f} | "
            f"LR: {current_lr:.2e}"
        )

        record = {
            "epoch": epoch + 1,
            "train_loss": train_metrics["loss"],
            "val_loss": val_metrics["loss"],
            "val_rmse": val_metrics.get("rmse", 0),
            "val_r2": val_metrics.get("r2", 0),
            "lr": current_lr,
        }
        history.append(record)

        scheduler.step(val_metrics["loss"])

        if val_metrics["loss"] < best_val_loss:
            best_val_loss = val_metrics["loss"]
            # Snapshot on CPU so the copy does not pin accelerator memory.
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            best_record = record
            logger.info(f"  -> New best val_loss: {best_val_loss:.4f}")

        if early_stopping(val_metrics["loss"]):
            logger.info(f"Early stopping at epoch {epoch+1}")
            break

    if best_state is None:
        # Happens when no epoch ran or the validation loss was never finite.
        # Save the current weights so the checkpoint is still loadable.
        logger.warning(
            f"Fold {fold_idx}: validation loss never improved; saving final weights"
        )
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    # Save the best model
    fold_output_dir = output_dir / f"fold_{fold_idx}"
    fold_output_dir.mkdir(parents=True, exist_ok=True)

    checkpoint_path = fold_output_dir / "model.pt"
    torch.save({
        "model_state_dict": best_state,
        "config": config,
        "best_val_loss": best_val_loss,
        "fold_idx": fold_idx,
    }, checkpoint_path)
    logger.success(f"Saved fold {fold_idx} model to {checkpoint_path}")

    # Save the training history
    history_path = fold_output_dir / "history.json"
    with open(history_path, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=2)

    # Report the metrics of the best epoch (the saved weights), not of the
    # last epoch, which after early stopping is `patience` epochs past it.
    return {
        "fold_idx": fold_idx,
        "best_val_loss": best_val_loss,
        "best_val_rmse": best_record["val_rmse"] if best_record else 0,
        "best_val_r2": best_record["val_r2"] if best_record else 0,
        "best_epoch": best_record["epoch"] if best_record else 0,
        "epochs_trained": len(history),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def create_model(
    d_model: int = 256,
    num_heads: int = 8,
    n_attn_layers: int = 4,
    fusion_strategy: str = "attention",
    head_hidden_dim: int = 128,
    dropout: float = 0.1,
    use_mpnn: bool = False,
    mpnn_ensemble_paths: Optional[List[str]] = None,
    mpnn_device: str = "cpu",
) -> nn.Module:
    """Build a model instance from hyperparameters.

    Args:
        d_model: Forwarded to the model constructor as ``d_model``.
        num_heads: Forwarded as ``num_heads``.
        n_attn_layers: Forwarded as ``n_attn_layers``.
        fusion_strategy: Forwarded as ``fusion_strategy``.
        head_hidden_dim: Forwarded as ``head_hidden_dim``.
        dropout: Forwarded as ``dropout``.
        use_mpnn: Selects ``LNPModel`` when true, ``LNPModelWithoutMPNN`` otherwise.
        mpnn_ensemble_paths: MPNN checkpoint paths; only used when ``use_mpnn``.
        mpnn_device: Device string for the MPNN; only used when ``use_mpnn``.

    Returns:
        The constructed model.
    """
    # Arguments accepted by both model variants.
    shared_kwargs = dict(
        d_model=d_model,
        num_heads=num_heads,
        n_attn_layers=n_attn_layers,
        fusion_strategy=fusion_strategy,
        head_hidden_dim=head_hidden_dim,
        dropout=dropout,
    )

    if not use_mpnn:
        return LNPModelWithoutMPNN(**shared_kwargs)

    return LNPModel(
        **shared_kwargs,
        mpnn_ensemble_paths=mpnn_ensemble_paths,
        mpnn_device=mpnn_device,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def main(
    data_dir: Path = PROCESSED_DATA_DIR / "cv",
    output_dir: Path = MODELS_DIR / "pretrain_cv",
    # Model hyperparameters
    d_model: int = 256,
    num_heads: int = 8,
    n_attn_layers: int = 4,
    fusion_strategy: str = "attention",
    head_hidden_dim: int = 128,
    dropout: float = 0.1,
    # MPNN options
    use_mpnn: bool = False,
    mpnn_checkpoint: Optional[str] = None,
    mpnn_ensemble_paths: Optional[str] = None,
    mpnn_device: str = "cpu",
    # Training options
    batch_size: int = 64,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    epochs: int = 50,
    patience: int = 10,
    # Device
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Pretrain the LNP model with cross-validation (delivery task only).

    One model is trained per ``fold_<n>`` directory found in ``data_dir``
    (each must contain ``train.parquet`` and ``valid.parquet``) and saved to
    ``output_dir/fold_<n>/model.pt``. The configuration is written to
    ``output_dir/config.json`` and the aggregated validation metrics to
    ``output_dir/cv_results.json``.

    Use --use-mpnn to enable the MPNN encoder. MPNN checkpoints are taken
    from --mpnn-ensemble-paths (comma-separated), else --mpnn-checkpoint,
    else auto-detected under the default ensemble directory.
    """
    logger.info(f"Using device: {device}")
    device = torch.device(device)

    # Resolve the MPNN checkpoints
    mpnn_paths = None
    if use_mpnn:
        if mpnn_ensemble_paths:
            # Tolerate spaces and stray commas in the comma-separated list.
            mpnn_paths = [p.strip() for p in mpnn_ensemble_paths.split(",") if p.strip()]
            logger.info(f"Using provided MPNN ensemble paths: {len(mpnn_paths)} models")
        elif mpnn_checkpoint:
            mpnn_paths = [mpnn_checkpoint]
            logger.info(f"Using single MPNN checkpoint: {mpnn_checkpoint}")
        else:
            logger.info(f"Auto-detecting MPNN ensemble from {DEFAULT_MPNN_ENSEMBLE_DIR}")
            mpnn_paths = find_mpnn_ensemble_paths()
            logger.info(f"Found {len(mpnn_paths)} MPNN models")

    # Find all fold directories. Only names of the form fold_<integer> are
    # accepted, and they are ordered numerically so fold_10 follows fold_9.
    prefix = "fold_"
    fold_dirs = sorted(
        (
            d for d in data_dir.iterdir()
            if d.is_dir()
            and d.name.startswith(prefix)
            and d.name[len(prefix):].isdecimal()
        ),
        key=lambda d: int(d.name[len(prefix):]),
    )

    if not fold_dirs:
        logger.error(f"No fold_* directories found in {data_dir}")
        logger.info("Please run 'make data_cv' first to process CV data.")
        raise typer.Exit(1)

    logger.info(f"Found {len(fold_dirs)} folds: {[d.name for d in fold_dirs]}")

    output_dir.mkdir(parents=True, exist_ok=True)

    # Model / training configuration (also stored inside every checkpoint)
    config = {
        "d_model": d_model,
        "num_heads": num_heads,
        "n_attn_layers": n_attn_layers,
        "fusion_strategy": fusion_strategy,
        "head_hidden_dim": head_hidden_dim,
        "dropout": dropout,
        "use_mpnn": use_mpnn,
        "mpnn_ensemble_paths": mpnn_paths,
        "lr": lr,
        "weight_decay": weight_decay,
        "batch_size": batch_size,
        "epochs": epochs,
        "patience": patience,
    }

    # Save the configuration
    config_path = output_dir / "config.json"
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    logger.info(f"Saved config to {config_path}")

    # Collect every train/valid SMILES up front for the cache warm-up
    all_smiles = set()
    for fold_dir in fold_dirs:
        for split in ["train", "valid"]:
            df = pd.read_parquet(fold_dir / f"{split}.parquet")
            all_smiles.update(df["smiles"].tolist())

    # Train each fold
    fold_results = []
    for position, fold_dir in enumerate(fold_dirs):
        fold_idx = int(fold_dir.name[len(prefix):])

        # Load the data
        train_df = pd.read_parquet(fold_dir / "train.parquet")
        val_df = pd.read_parquet(fold_dir / "valid.parquet")

        logger.info(f"\nFold {fold_idx}: train={len(train_df)}, val={len(val_df)}")

        # Build datasets and loaders
        train_dataset = ExternalDeliveryDataset(train_df)
        val_dataset = ExternalDeliveryDataset(val_df)

        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
        )
        val_loader = DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
        )

        # Fresh model per fold so folds are initialised independently.
        # NOTE(review): the `mpnn_device` option is not used here; the MPNN is
        # placed on `device.type` (which also drops any device index) — confirm
        # whether that is intended.
        model = create_model(
            d_model=d_model,
            num_heads=num_heads,
            n_attn_layers=n_attn_layers,
            fusion_strategy=fusion_strategy,
            head_hidden_dim=head_hidden_dim,
            dropout=dropout,
            use_mpnn=use_mpnn,
            mpnn_ensemble_paths=mpnn_paths,
            mpnn_device=device.type,
        )
        model = model.to(device)

        # Warm the cache once, on the first fold processed (not on the literal
        # index 0, so it still runs when fold numbering starts elsewhere).
        # NOTE(review): whether later folds' models reuse this cache depends
        # on the encoder implementation — confirm.
        if position == 0:
            warmup_cache(model, list(all_smiles), batch_size=256)

        logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

        # Train
        result = train_fold(
            fold_idx=fold_idx,
            train_loader=train_loader,
            val_loader=val_loader,
            model=model,
            device=device,
            output_dir=output_dir,
            lr=lr,
            weight_decay=weight_decay,
            epochs=epochs,
            patience=patience,
            config=config,
        )
        fold_results.append(result)

    # Aggregate the results
    logger.info("\n" + "=" * 60)
    logger.info("CROSS-VALIDATION TRAINING COMPLETE")
    logger.info("=" * 60)

    val_losses = [r["best_val_loss"] for r in fold_results]
    val_rmses = [r["best_val_rmse"] for r in fold_results]
    val_r2s = [r["best_val_r2"] for r in fold_results]

    logger.info("\n[Per-Fold Results]")
    for r in fold_results:
        logger.info(
            f"  Fold {r['fold_idx']}: "
            f"Val Loss={r['best_val_loss']:.4f}, "
            f"RMSE={r['best_val_rmse']:.4f}, "
            f"R²={r['best_val_r2']:.4f}, "
            f"Epochs={r['epochs_trained']}"
        )

    logger.info("\n[Summary Statistics]")
    logger.info(f"  Val Loss: {np.mean(val_losses):.4f} ± {np.std(val_losses):.4f}")
    logger.info(f"  Val RMSE: {np.mean(val_rmses):.4f} ± {np.std(val_rmses):.4f}")
    logger.info(f"  Val R²: {np.mean(val_r2s):.4f} ± {np.std(val_r2s):.4f}")

    # Save the CV results
    cv_results = {
        "fold_results": fold_results,
        "summary": {
            "val_loss_mean": float(np.mean(val_losses)),
            "val_loss_std": float(np.std(val_losses)),
            "val_rmse_mean": float(np.mean(val_rmses)),
            "val_rmse_std": float(np.std(val_rmses)),
            "val_r2_mean": float(np.mean(val_r2s)),
            "val_r2_std": float(np.std(val_r2s)),
        },
        "config": config,
    }

    results_path = output_dir / "cv_results.json"
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(cv_results, f, indent=2)
    logger.success(f"Saved CV results to {results_path}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def test(
    data_dir: Path = PROCESSED_DATA_DIR / "cv",
    model_dir: Path = MODELS_DIR / "pretrain_cv",
    output_path: Path = MODELS_DIR / "pretrain_cv" / "test_results.json",
    batch_size: int = 64,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Evaluate the CV-pretrained models on their test sets.

    For every ``fold_<n>`` directory in ``data_dir`` the checkpoint
    ``model_dir/fold_<n>/model.pt`` is evaluated on that fold's
    ``test.parquet``. Folds with a missing checkpoint, missing test data or no
    valid delivery targets are skipped with a warning. Per-fold metrics,
    their mean/std and metrics over the pooled predictions are written to
    ``output_path`` as JSON. Exits with status 1 if no fold could be
    evaluated.
    """
    logger.info(f"Using device: {device}")
    device = torch.device(device)

    # Find all fold directories. Only names of the form fold_<integer> are
    # accepted, and they are ordered numerically so fold_10 follows fold_9.
    prefix = "fold_"
    fold_dirs = sorted(
        (
            d for d in data_dir.iterdir()
            if d.is_dir()
            and d.name.startswith(prefix)
            and d.name[len(prefix):].isdecimal()
        ),
        key=lambda d: int(d.name[len(prefix):]),
    )

    if not fold_dirs:
        logger.error(f"No fold_* directories found in {data_dir}")
        raise typer.Exit(1)

    logger.info(f"Found {len(fold_dirs)} folds")

    fold_results = []
    all_preds = []
    all_targets = []

    for fold_dir in fold_dirs:
        fold_idx = int(fold_dir.name[len(prefix):])
        model_path = model_dir / f"fold_{fold_idx}" / "model.pt"
        test_path = fold_dir / "test.parquet"

        if not model_path.exists():
            logger.warning(f"Fold {fold_idx}: model not found at {model_path}, skipping")
            continue

        if not test_path.exists():
            logger.warning(f"Fold {fold_idx}: test data not found at {test_path}, skipping")
            continue

        # Load the model; hyperparameters come from the checkpoint itself
        checkpoint = torch.load(model_path, map_location=device)
        config = checkpoint["config"]

        use_mpnn = config.get("use_mpnn", False)
        mpnn_paths = config.get("mpnn_ensemble_paths")

        if use_mpnn and not mpnn_paths:
            mpnn_paths = find_mpnn_ensemble_paths()

        model = create_model(
            d_model=config["d_model"],
            num_heads=config["num_heads"],
            n_attn_layers=config["n_attn_layers"],
            fusion_strategy=config["fusion_strategy"],
            head_hidden_dim=config["head_hidden_dim"],
            dropout=config["dropout"],
            use_mpnn=use_mpnn,
            mpnn_ensemble_paths=mpnn_paths,
            mpnn_device=device.type,
        )
        model.load_state_dict(checkpoint["model_state_dict"])
        model = model.to(device)
        model.eval()

        # Load the test data
        test_df = pd.read_parquet(test_path)
        test_dataset = ExternalDeliveryDataset(test_df)
        test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
        )

        # Evaluate
        pred_values = []
        target_values = []

        with torch.no_grad():
            for batch in test_loader:
                smiles = batch["smiles"]
                tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
                targets = batch["targets"]["delivery"].to(device)
                mask = batch["mask"]["delivery"].to(device)

                pred = model.forward_delivery(smiles, tabular).squeeze(-1)

                if mask.any():
                    pred_values.extend(pred[mask].cpu().numpy().tolist())
                    target_values.extend(targets[mask].cpu().numpy().tolist())

        if not pred_values:
            # The sklearn metrics raise on empty input; skip instead of crashing.
            logger.warning(f"Fold {fold_idx}: no valid delivery targets in {test_path}, skipping")
            continue

        # Compute the fold metrics
        fold_preds = np.array(pred_values)
        fold_targets = np.array(target_values)

        mse = float(mean_squared_error(fold_targets, fold_preds))
        rmse = float(np.sqrt(mse))
        r2 = float(r2_score(fold_targets, fold_preds))
        mae = float(np.mean(np.abs(fold_targets - fold_preds)))
        corr = float(np.corrcoef(fold_targets, fold_preds)[0, 1])

        fold_results.append({
            "fold_idx": fold_idx,
            "n_samples": len(fold_preds),
            "mse": mse,
            "rmse": rmse,
            "mae": mae,
            "r2": r2,
            "correlation": corr,
        })

        all_preds.extend(pred_values)
        all_targets.extend(target_values)

        logger.info(
            f"Fold {fold_idx}: n={len(fold_preds)}, "
            f"RMSE={rmse:.4f}, R²={r2:.4f}, MAE={mae:.4f}, Corr={corr:.4f}"
        )

    if not fold_results:
        # Every fold was skipped, so there is nothing to aggregate.
        logger.error(f"No fold could be evaluated (models in {model_dir}, data in {data_dir})")
        raise typer.Exit(1)

    # Compute the pooled metrics
    all_preds = np.array(all_preds)
    all_targets = np.array(all_targets)

    overall_mse = float(mean_squared_error(all_targets, all_preds))
    overall_rmse = float(np.sqrt(overall_mse))
    overall_r2 = float(r2_score(all_targets, all_preds))
    overall_mae = float(np.mean(np.abs(all_targets - all_preds)))
    overall_corr = float(np.corrcoef(all_targets, all_preds)[0, 1])

    # Summary statistics across folds
    rmses = [r["rmse"] for r in fold_results]
    r2s = [r["r2"] for r in fold_results]

    logger.info("\n" + "=" * 60)
    logger.info("CV TEST EVALUATION RESULTS")
    logger.info("=" * 60)

    logger.info(f"\n[Summary Statistics (across {len(fold_results)} folds)]")
    logger.info(f"  RMSE: {np.mean(rmses):.4f} ± {np.std(rmses):.4f}")
    logger.info(f"  R²: {np.mean(r2s):.4f} ± {np.std(r2s):.4f}")

    logger.info(f"\n[Overall (all {len(all_preds)} samples pooled)]")
    logger.info(f"  RMSE: {overall_rmse:.4f}")
    logger.info(f"  R²: {overall_r2:.4f}")
    logger.info(f"  MAE: {overall_mae:.4f}")
    logger.info(f"  Correlation: {overall_corr:.4f}")

    # Save the results
    results = {
        "fold_results": fold_results,
        "summary_stats": {
            "rmse_mean": float(np.mean(rmses)),
            "rmse_std": float(np.std(rmses)),
            "r2_mean": float(np.mean(r2s)),
            "r2_std": float(np.std(r2s)),
        },
        "overall": {
            "n_samples": len(all_preds),
            "mse": overall_mse,
            "rmse": overall_rmse,
            "mae": overall_mae,
            "r2": overall_r2,
            "correlation": overall_corr,
        },
    }

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    logger.success(f"Saved test results to {output_path}")
|
||||||
|
|
||||||
|
|
||||||
|
# Run the Typer CLI when the module is executed as a script
# (e.g. `python -m lnp_ml.modeling.pretrain_cv main`).
if __name__ == "__main__":
    app()
|
||||||
|
|
||||||
2538
models/history.json
2538
models/history.json
File diff suppressed because it is too large
Load Diff
BIN
models/model.pt
BIN
models/model.pt
Binary file not shown.
21
models/pretrain_cv/config.json
Normal file
21
models/pretrain_cv/config.json
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
{
|
||||||
|
"d_model": 256,
|
||||||
|
"num_heads": 8,
|
||||||
|
"n_attn_layers": 4,
|
||||||
|
"fusion_strategy": "attention",
|
||||||
|
"head_hidden_dim": 128,
|
||||||
|
"dropout": 0.1,
|
||||||
|
"use_mpnn": true,
|
||||||
|
"mpnn_ensemble_paths": [
|
||||||
|
"/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/models/mpnn/all_amine_split_for_LiON/cv_0/fold_0/model_0/model.pt",
|
||||||
|
"/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/models/mpnn/all_amine_split_for_LiON/cv_1/fold_0/model_0/model.pt",
|
||||||
|
"/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/models/mpnn/all_amine_split_for_LiON/cv_2/fold_0/model_0/model.pt",
|
||||||
|
"/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/models/mpnn/all_amine_split_for_LiON/cv_3/fold_0/model_0/model.pt",
|
||||||
|
"/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/models/mpnn/all_amine_split_for_LiON/cv_4/fold_0/model_0/model.pt"
|
||||||
|
],
|
||||||
|
"lr": 0.0001,
|
||||||
|
"weight_decay": 1e-05,
|
||||||
|
"batch_size": 64,
|
||||||
|
"epochs": 50,
|
||||||
|
"patience": 10
|
||||||
|
}
|
||||||
154
models/pretrain_cv/fold_0/history.json
Normal file
154
models/pretrain_cv/fold_0/history.json
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"epoch": 1,
|
||||||
|
"train_loss": 0.7841362829904129,
|
||||||
|
"val_loss": 0.9230800325498075,
|
||||||
|
"val_rmse": 0.9607705443260125,
|
||||||
|
"val_r2": 0.20751899565844845,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 2,
|
||||||
|
"train_loss": 0.6463805462366321,
|
||||||
|
"val_loss": 0.9147881584284676,
|
||||||
|
"val_rmse": 0.9564455868655657,
|
||||||
|
"val_r2": 0.2146377239324061,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 3,
|
||||||
|
"train_loss": 0.5832159742587586,
|
||||||
|
"val_loss": 0.8564491166460212,
|
||||||
|
"val_rmse": 0.9254453497069719,
|
||||||
|
"val_r2": 0.2647228727252957,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 4,
|
||||||
|
"train_loss": 0.5554495169353261,
|
||||||
|
"val_loss": 0.9521924139543092,
|
||||||
|
"val_rmse": 0.9758034723689635,
|
||||||
|
"val_r2": 0.18252548972094385,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 5,
|
||||||
|
"train_loss": 0.5284749799126208,
|
||||||
|
"val_loss": 0.9002334602240055,
|
||||||
|
"val_rmse": 0.9488063310427656,
|
||||||
|
"val_r2": 0.22713320447964735,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 6,
|
||||||
|
"train_loss": 0.5008455182478028,
|
||||||
|
"val_loss": 0.8869625280360692,
|
||||||
|
"val_rmse": 0.9417868917406855,
|
||||||
|
"val_r2": 0.23852651728338659,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 7,
|
||||||
|
"train_loss": 0.4812185961924912,
|
||||||
|
"val_loss": 0.8656982819227316,
|
||||||
|
"val_rmse": 0.9304290898141719,
|
||||||
|
"val_r2": 0.25678226981933505,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 8,
|
||||||
|
"train_loss": 0.46716415395603483,
|
||||||
|
"val_loss": 0.8974583646500499,
|
||||||
|
"val_rmse": 0.9473427898099364,
|
||||||
|
"val_r2": 0.22951567180333565,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 9,
|
||||||
|
"train_loss": 0.44227657012969945,
|
||||||
|
"val_loss": 0.8423525478929029,
|
||||||
|
"val_rmse": 0.917797665041863,
|
||||||
|
"val_r2": 0.2768250098825846,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 10,
|
||||||
|
"train_loss": 0.4192214831283541,
|
||||||
|
"val_loss": 0.8834660291599854,
|
||||||
|
"val_rmse": 0.9399287338143215,
|
||||||
|
"val_r2": 0.24152834743071827,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 11,
|
||||||
|
"train_loss": 0.4261854484762579,
|
||||||
|
"val_loss": 0.8901423802745356,
|
||||||
|
"val_rmse": 0.9434735654804721,
|
||||||
|
"val_r2": 0.23579658456782115,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 12,
|
||||||
|
"train_loss": 0.40934500416213515,
|
||||||
|
"val_loss": 0.8902992367456848,
|
||||||
|
"val_rmse": 0.9435566978740579,
|
||||||
|
"val_r2": 0.2356619059496775,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 13,
|
||||||
|
"train_loss": 0.3993083212922134,
|
||||||
|
"val_loss": 0.8952986713233702,
|
||||||
|
"val_rmse": 0.946202241268488,
|
||||||
|
"val_r2": 0.23136979638836497,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 14,
|
||||||
|
"train_loss": 0.38611710345412054,
|
||||||
|
"val_loss": 0.898042505591281,
|
||||||
|
"val_rmse": 0.9476510417548493,
|
||||||
|
"val_r2": 0.22901418082171254,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 15,
|
||||||
|
"train_loss": 0.3662278820512592,
|
||||||
|
"val_loss": 0.8985377061208856,
|
||||||
|
"val_rmse": 0.9479122944969771,
|
||||||
|
"val_r2": 0.2285890244824833,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 16,
|
||||||
|
"train_loss": 0.33336883667983097,
|
||||||
|
"val_loss": 0.8594161927771942,
|
||||||
|
"val_rmse": 0.9270470292726928,
|
||||||
|
"val_r2": 0.26217556409941023,
|
||||||
|
"lr": 5e-05
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 17,
|
||||||
|
"train_loss": 0.33174229304183017,
|
||||||
|
"val_loss": 0.8942263270711439,
|
||||||
|
"val_rmse": 0.945635409387207,
|
||||||
|
"val_r2": 0.23229043171319064,
|
||||||
|
"lr": 5e-05
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 18,
|
||||||
|
"train_loss": 0.3200935872264761,
|
||||||
|
"val_loss": 0.8843722421096231,
|
||||||
|
"val_rmse": 0.9404106786278892,
|
||||||
|
"val_r2": 0.24075034122444328,
|
||||||
|
"lr": 5e-05
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 19,
|
||||||
|
"train_loss": 0.3140528901624022,
|
||||||
|
"val_loss": 0.8707393039336969,
|
||||||
|
"val_rmse": 0.9331341268749274,
|
||||||
|
"val_r2": 0.2524544731269053,
|
||||||
|
"lr": 5e-05
|
||||||
|
}
|
||||||
|
]
|
||||||
BIN
models/pretrain_cv/fold_0/model.pt
Normal file
BIN
models/pretrain_cv/fold_0/model.pt
Normal file
Binary file not shown.
130
models/pretrain_cv/fold_1/history.json
Normal file
130
models/pretrain_cv/fold_1/history.json
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"epoch": 1,
|
||||||
|
"train_loss": 0.7586579631889938,
|
||||||
|
"val_loss": 0.7181247283430661,
|
||||||
|
"val_rmse": 0.8474223956247073,
|
||||||
|
"val_r2": 0.27072978816447013,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 2,
|
||||||
|
"train_loss": 0.6169086206378913,
|
||||||
|
"val_loss": 0.7122436411609591,
|
||||||
|
"val_rmse": 0.8439452896464879,
|
||||||
|
"val_r2": 0.27670212861312427,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 3,
|
||||||
|
"train_loss": 0.5716009976577638,
|
||||||
|
"val_loss": 0.724399374180903,
|
||||||
|
"val_rmse": 0.8511165408178202,
|
||||||
|
"val_r2": 0.2643577544134178,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 4,
|
||||||
|
"train_loss": 0.5406081575331337,
|
||||||
|
"val_loss": 0.8086859301147815,
|
||||||
|
"val_rmse": 0.8992696715048705,
|
||||||
|
"val_r2": 0.17876302728821758,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 5,
|
||||||
|
"train_loss": 0.5032305184360722,
|
||||||
|
"val_loss": 0.7551608490501026,
|
||||||
|
"val_rmse": 0.8689999128971181,
|
||||||
|
"val_r2": 0.23311884509134373,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 6,
|
||||||
|
"train_loss": 0.49779197957664073,
|
||||||
|
"val_loss": 0.7009093638175043,
|
||||||
|
"val_rmse": 0.8372033005770025,
|
||||||
|
"val_r2": 0.2882123252930564,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 7,
|
||||||
|
"train_loss": 0.4720957023516007,
|
||||||
|
"val_loss": 0.7939711763762837,
|
||||||
|
"val_rmse": 0.8910506016905562,
|
||||||
|
"val_r2": 0.1937061718828037,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 8,
|
||||||
|
"train_loss": 0.45880254612337634,
|
||||||
|
"val_loss": 0.7781202286020521,
|
||||||
|
"val_rmse": 0.8821112329914439,
|
||||||
|
"val_r2": 0.209803130396238,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 9,
|
||||||
|
"train_loss": 0.43299420961863133,
|
||||||
|
"val_loss": 0.7305147618332145,
|
||||||
|
"val_rmse": 0.8547015607274253,
|
||||||
|
"val_r2": 0.25814744997562133,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 10,
|
||||||
|
"train_loss": 0.4271428178197404,
|
||||||
|
"val_loss": 0.736324011356838,
|
||||||
|
"val_rmse": 0.8580932378454483,
|
||||||
|
"val_r2": 0.25224804192225836,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 11,
|
||||||
|
"train_loss": 0.4049153788182662,
|
||||||
|
"val_loss": 0.7328011674777642,
|
||||||
|
"val_rmse": 0.8560380622891226,
|
||||||
|
"val_r2": 0.2558255581383374,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 12,
|
||||||
|
"train_loss": 0.4011265030265379,
|
||||||
|
"val_loss": 0.8308254960889787,
|
||||||
|
"val_rmse": 0.9114962980441127,
|
||||||
|
"val_r2": 0.15627985591441118,
|
||||||
|
"lr": 0.0001
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 13,
|
||||||
|
"train_loss": 0.36400350090994316,
|
||||||
|
"val_loss": 0.7314743397036573,
|
||||||
|
"val_rmse": 0.8552627260066816,
|
||||||
|
"val_r2": 0.25717298455584947,
|
||||||
|
"lr": 5e-05
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 14,
|
||||||
|
"train_loss": 0.35224626399225445,
|
||||||
|
"val_loss": 0.7428758944520272,
|
||||||
|
"val_rmse": 0.8619024879725893,
|
||||||
|
"val_r2": 0.24559446076977343,
|
||||||
|
"lr": 5e-05
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 15,
|
||||||
|
"train_loss": 0.33604996105846857,
|
||||||
|
"val_loss": 0.7853316709722159,
|
||||||
|
"val_rmse": 0.8861894055728078,
|
||||||
|
"val_r2": 0.20247977173777976,
|
||||||
|
"lr": 5e-05
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"epoch": 16,
|
||||||
|
"train_loss": 0.3375160908226994,
|
||||||
|
"val_loss": 0.7890364486366602,
|
||||||
|
"val_rmse": 0.888277234295553,
|
||||||
|
"val_r2": 0.19871749006650596,
|
||||||
|
"lr": 5e-05
|
||||||
|
}
|
||||||
|
]
|
||||||
BIN
models/pretrain_cv/fold_1/model.pt
Normal file
BIN
models/pretrain_cv/fold_1/model.pt
Normal file
Binary file not shown.
Binary file not shown.
@ -1,285 +1,309 @@
|
|||||||
{
|
{
|
||||||
"train": [
|
"train": [
|
||||||
{
|
{
|
||||||
"loss": 0.8151603855680482,
|
"loss": 0.8244676398801744,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6867990841792819,
|
"loss": 0.6991508170533461,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.645540308540888,
|
"loss": 0.6388374940987616,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5923176541020599,
|
"loss": 0.6008581508669937,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5720762926262872,
|
"loss": 0.584832567446085,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5477570670417328,
|
"loss": 0.5481657371815157,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5280393017717573,
|
"loss": 0.5368926340308079,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5122504676513313,
|
"loss": 0.5210388793613561,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.49667307051028314,
|
"loss": 0.49758357966374045,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.486139440352648,
|
"loss": 0.49256294099457043,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.4749755339466122,
|
"loss": 0.4697267088016886,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.4636757543530298,
|
"loss": 0.45763822707571084,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.4543497681877452,
|
"loss": 0.4495221330627172,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.4408158337956461,
|
"loss": 0.446159594079631,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.4419790126221837,
|
"loss": 0.4327090857889029,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.42850686623585116,
|
"loss": 0.4249273364101852,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.41607048387867007,
|
"loss": 0.4216959138704459,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.427172136486513,
|
"loss": 0.416526201182502,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.4125568530569382,
|
"loss": 0.40368679039741573,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.39480836287767923,
|
"loss": 0.4051084730032182,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.3885056775666858,
|
"loss": 0.38971701020385785,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.3894976457588827,
|
"loss": 0.39155546386038786,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.3890058272899995,
|
"loss": 0.37976963541784114,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.3741690826284791,
|
"loss": 0.36484339719805037,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.3534914434345719,
|
"loss": 0.36232607571196496,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.3349389765134386,
|
"loss": 0.3345973272380199,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.32965143874976194,
|
"loss": 0.31767916518768957,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.32094062546116675,
|
"loss": 0.32065429246052457,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.32526135008251184,
|
"loss": 0.3171297926146043,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.31289531808423826,
|
"loss": 0.3122120894173009,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.3088379208288558,
|
"loss": 0.3135035038404461,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.2994744991261045,
|
"loss": 0.2987745178222875,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.2981521815160671,
|
"loss": 0.2914867957853393,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.29143649446979303,
|
"loss": 0.2983839795507705,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.29075756723379653,
|
"loss": 0.2826709597875678,
|
||||||
|
"n_samples": 8721
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.2731766632569382,
|
||||||
|
"n_samples": 8721
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.27726896305742266,
|
||||||
|
"n_samples": 8721
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.27864557847067956,
|
||||||
"n_samples": 8721
|
"n_samples": 8721
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"val": [
|
"val": [
|
||||||
{
|
{
|
||||||
"loss": 0.7625711447683281,
|
"loss": 0.7601077516012517,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.7092331695236781,
|
"loss": 0.7119935319611901,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.7014068689723995,
|
"loss": 0.6461842978148269,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6595172673863646,
|
"loss": 0.7006978391063226,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6312279044905191,
|
"loss": 0.6533874032943979,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6349272860831151,
|
"loss": 0.6413641451743611,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6587623598744133,
|
"loss": 0.6168395132979742,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6093261837651732,
|
"loss": 0.6095251602162025,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6125607111474924,
|
"loss": 0.5887809592626905,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6005943137518024,
|
"loss": 0.5655298325376368,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6876292386783289,
|
"loss": 0.5809201743872788,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5940848466228036,
|
"loss": 0.5897585974912033,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5820883587079644,
|
"loss": 0.5732012489662573,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6302792748938035,
|
"loss": 0.5607388911786094,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5849901610914275,
|
"loss": 0.5717580675371414,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5830434826428553,
|
"loss": 0.5553950037657291,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5643168952858116,
|
"loss": 0.5778171792857049,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5592790719340829,
|
"loss": 0.5602665468127734,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.600335100686833,
|
"loss": 0.5475307451359259,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5646457721097674,
|
"loss": 0.551515599314827,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.6288956836004376,
|
"loss": 0.5755438121541243,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5771863222183704,
|
"loss": 0.5798238261811381,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5738056593250687,
|
"loss": 0.5739961433828923,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5636531712085593,
|
"loss": 0.5742599932540312,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5465074849879163,
|
"loss": 0.5834948123885382,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5701294839843508,
|
"loss": 0.554078846570139,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5570075802420438,
|
"loss": 0.5714933996322354,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5711473401701241,
|
"loss": 0.5384107524350331,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5576858864741552,
|
"loss": 0.570854394451568,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5624132422716871,
|
"loss": 0.5767292551642478,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5655298555506272,
|
"loss": 0.5660079547556808,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5568078993151677,
|
"loss": 0.5608972411514312,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.567752199383958,
|
"loss": 0.5620947442987263,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5683093779442603,
|
"loss": 0.5706970894361305,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"loss": 0.5741443767974497,
|
"loss": 0.5702376298690974,
|
||||||
|
"n_samples": 969
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.5758474825259579,
|
||||||
|
"n_samples": 969
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.5673816067284844,
|
||||||
|
"n_samples": 969
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"loss": 0.5671441179879925,
|
||||||
"n_samples": 969
|
"n_samples": 969
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
18
models/pretrain_test_results.json
Normal file
18
models/pretrain_test_results.json
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
{
|
||||||
|
"model_path": "/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/models/pretrain_delivery.pt",
|
||||||
|
"val_path": "/Users/ryde/Documents/workspaces/\u8102\u8d28\u5206\u5b50\u836f\u7269\u9012\u9001\u6027\u80fd\u9884\u6d4b/\u6700\u65b0\u6574\u7406/lnp-ml/data/processed/val_pretrain.parquet",
|
||||||
|
"n_samples": 969,
|
||||||
|
"metrics": {
|
||||||
|
"mse": 0.5384107828140259,
|
||||||
|
"rmse": 0.7337648007461423,
|
||||||
|
"mae": 0.5158410668373108,
|
||||||
|
"r2": 0.4877620374678041,
|
||||||
|
"correlation": 0.7059544816157891
|
||||||
|
},
|
||||||
|
"statistics": {
|
||||||
|
"y_true_mean": 0.009126194752752781,
|
||||||
|
"y_true_std": 1.0252292156219482,
|
||||||
|
"y_pred_mean": 0.008337541483342648,
|
||||||
|
"y_pred_std": 0.8293642997741699
|
||||||
|
}
|
||||||
|
}
|
||||||
234
scripts/process_data_cv.py
Normal file
234
scripts/process_data_cv.py
Normal file
@ -0,0 +1,234 @@
|
|||||||
|
"""处理 cross-validation 数据脚本:将 CV splits 转换为模型所需的 parquet 格式"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import typer
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from lnp_ml.config import EXTERNAL_DATA_DIR, PROCESSED_DATA_DIR
|
||||||
|
from lnp_ml.dataset import (
|
||||||
|
LNPDatasetConfig,
|
||||||
|
COMP_COLS,
|
||||||
|
HELP_COLS,
|
||||||
|
get_phys_cols,
|
||||||
|
get_exp_cols,
|
||||||
|
EXP_ONEHOT_SPECS,
|
||||||
|
PHYS_ONEHOT_SPECS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Typer application object; functions decorated with @app.command() in this
# file are registered on it, and the __main__ guard runs it.
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
# Mapping from CV extra_x column names to the column names used by the model.
# Both Batch_or_individual_or_barcoded_* columns currently map to themselves.
# Helper_lipid_ID_None is not used by the model, so it is deliberately left out.
CV_COL_MAPPING = {
    column: column
    for column in (
        "Batch_or_individual_or_barcoded_Barcoded",
        "Batch_or_individual_or_barcoded_Individual",
    )
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_cv_split(cv_dir: Path) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Load the train / valid / test tables of a single CV split.

    For each split name the function reads ``<split>.csv`` (required),
    appends the columns of ``<split>_extra_x.csv`` when that file exists, and
    copies a fixed set of raw categorical columns from
    ``<split>_metadata.csv`` when that file exists and the column is not
    already present in the merged frame.

    Args:
        cv_dir: Directory of one CV split, e.g. ``cv_0/``.

    Returns:
        ``(train_df, valid_df, test_df)``: the merged DataFrame of each split.

    Raises:
        FileNotFoundError: If ``<split>.csv`` is missing for any split.
        ValueError: If an extra_x or metadata file does not have the same
            number of rows as the main file of the same split.
    """
    # Raw categorical columns that process_cv_dataframe can expand into one-hot
    # columns. "Cargo_type" is included because the phys-token one-hot step is
    # documented as deriving it from metadata as well.
    metadata_cols = (
        "Purity",
        "Mix_type",
        "Cargo_type",
        "Target_or_delivered_gene",
        "Value_name",
    )

    splits = {}
    for split_name in ("train", "valid", "test"):
        main_path = cv_dir / f"{split_name}.csv"
        extra_x_path = cv_dir / f"{split_name}_extra_x.csv"
        metadata_path = cv_dir / f"{split_name}_metadata.csv"

        if not main_path.exists():
            raise FileNotFoundError(f"Missing {main_path}")

        # Main table (smiles, quantified_delivery).
        main_df = pd.read_csv(main_path)

        if extra_x_path.exists():
            # extra_x holds the already one-hot-encoded features, row-aligned
            # with the main table.
            extra_x_df = pd.read_csv(extra_x_path)
            # Explicit raise instead of `assert`: asserts are stripped under
            # `python -O`, which would let misaligned rows through silently.
            if len(main_df) != len(extra_x_df):
                raise ValueError(
                    f"Row count mismatch in {extra_x_path}: "
                    f"{len(main_df)} vs {len(extra_x_df)}"
                )
            # Positional merge: both frames carry the default RangeIndex.
            df = pd.concat([main_df, extra_x_df], axis=1)
        else:
            df = main_df
            logger.warning(f"  {split_name}_extra_x.csv not found, using main data only")

        if metadata_path.exists():
            metadata_df = pd.read_csv(metadata_path)
            # Column assignment aligns on the index, so a metadata file with a
            # different length would be truncated or NaN-padded without notice.
            if len(metadata_df) != len(df):
                raise ValueError(
                    f"Row count mismatch in {metadata_path}: "
                    f"{len(df)} vs {len(metadata_df)}"
                )
            for col in metadata_cols:
                if col in metadata_df.columns and col not in df.columns:
                    df[col] = metadata_df[col]

        splits[split_name] = df

    return splits["train"], splits["valid"], splits["test"]
|
||||||
|
|
||||||
|
|
||||||
|
def process_cv_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """
    Align one CV split DataFrame with the column layout the model expects.

    The input is not modified; all work happens on a copy. Steps:

    1. Every column in ``COMP_COLS`` is coerced to numeric (unparseable or
       missing values become 0.0); absent columns are created as 0.0.
    2. Every column in ``HELP_COLS`` has NaN filled with 0.0 and is cast to
       float; absent columns are created as 0.0.
    3. For every ``{col}_{value}`` pair in ``PHYS_ONEHOT_SPECS`` and then
       ``EXP_ONEHOT_SPECS``: an existing one-hot column is NaN-filled and cast
       to float; otherwise it is derived from the raw ``col`` column when that
       exists, and set to 0.0 when it does not.
    4. ``quantified_delivery``, if present, is coerced to numeric (unparseable
       values become NaN).

    Args:
        df: Merged frame of one split.

    Returns:
        A new DataFrame with the columns above added or normalized.
    """
    out = df.copy()

    # Composition columns.
    for name in COMP_COLS:
        if name not in out.columns:
            out[name] = 0.0
            continue
        out[name] = pd.to_numeric(out[name], errors="coerce").fillna(0.0)

    # Helper-lipid columns.
    for name in HELP_COLS:
        if name in out.columns:
            out[name] = out[name].fillna(0.0).astype(float)
        else:
            out[name] = 0.0

    # One-hot columns: phys-token specs first, then exp-token specs.
    for specs in (PHYS_ONEHOT_SPECS, EXP_ONEHOT_SPECS):
        for raw_col, categories in specs.items():
            for category in categories:
                flag_col = f"{raw_col}_{category}"
                if flag_col in out.columns:
                    out[flag_col] = out[flag_col].fillna(0.0).astype(float)
                elif raw_col in out.columns:
                    # Derive the indicator from the raw categorical column.
                    out[flag_col] = (out[raw_col] == category).astype(float)
                else:
                    out[flag_col] = 0.0

    # Regression target.
    if "quantified_delivery" in out.columns:
        out["quantified_delivery"] = pd.to_numeric(out["quantified_delivery"], errors="coerce")

    return out
|
||||||
|
|
||||||
|
|
||||||
|
def get_feature_columns() -> List[str]:
    """
    Return the ordered column names written to each output parquet file.

    The order is: ``smiles``, then the comp, phys, help and exp columns of a
    default ``LNPDatasetConfig``, then the target ``quantified_delivery``.
    """
    cfg = LNPDatasetConfig()
    return [
        "smiles",
        *cfg.comp_cols,
        *cfg.phys_cols,
        *cfg.help_cols,
        *cfg.exp_cols,
        "quantified_delivery",
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def _cv_dir_sort_key(path: Path) -> Tuple[int, int, str]:
    """Order ``cv_<n>`` directories by integer ``n``; non-numeric suffixes sort last, by name."""
    suffix = path.name[len("cv_"):]
    if suffix.isdecimal():
        return (0, int(suffix), "")
    return (1, 0, suffix)


@app.command()
def main(
    data_dir: Path = EXTERNAL_DATA_DIR / "all_amine_split_for_LiON",
    output_dir: Path = PROCESSED_DATA_DIR / "cv",
    n_folds: int = 5,
):
    """
    Process cross-validation data into the parquet files the model consumes.

    Every ``cv_*`` sub-directory of ``data_dir`` is loaded, aligned to the
    model's feature columns and written to ``output_dir``.

    Output layout:
        - <output_dir>/fold_0/train.parquet
        - <output_dir>/fold_0/valid.parquet
        - <output_dir>/fold_0/test.parquet
        - <output_dir>/fold_1/...
        - <output_dir>/feature_columns.txt

    Args:
        data_dir: Directory containing the ``cv_*`` split directories.
        output_dir: Directory that receives the ``fold_*`` outputs.
        n_folds: Expected number of folds; a mismatch only logs a warning.

    Raises:
        typer.Exit: With code 1 when ``data_dir`` is not a directory or holds
            no ``cv_*`` directories.
    """
    logger.info(f"Processing CV data from {data_dir}")

    if not data_dir.is_dir():
        logger.error(f"Data directory not found: {data_dir}")
        raise typer.Exit(1)

    # Sort by the numeric suffix: a plain lexicographic sort would place
    # cv_10 before cv_2 and mislabel the folds once there are more than ten.
    cv_dirs = sorted(
        (d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith("cv_")),
        key=_cv_dir_sort_key,
    )

    if not cv_dirs:
        logger.error(f"No cv_* directories found in {data_dir}")
        raise typer.Exit(1)

    if len(cv_dirs) != n_folds:
        logger.warning(f"Expected {n_folds} folds, found {len(cv_dirs)}")

    logger.info(f"Found {len(cv_dirs)} folds: {[d.name for d in cv_dirs]}")

    feature_cols = get_feature_columns()
    output_dir.mkdir(parents=True, exist_ok=True)

    for i, cv_dir in enumerate(cv_dirs):
        fold_name = f"fold_{i}"
        fold_output_dir = output_dir / fold_name
        fold_output_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Processing {cv_dir.name} -> {fold_name}")

        train_df, valid_df, test_df = load_cv_split(cv_dir)
        logger.info(f"  Loaded: train={len(train_df)}, valid={len(valid_df)}, test={len(test_df)}")

        frames = {
            "train": process_cv_dataframe(train_df),
            "valid": process_cv_dataframe(valid_df),
            "test": process_cv_dataframe(test_df),
        }

        for split_name, frame in frames.items():
            # Guarantee every expected column exists; report what had to be
            # defaulted so that an absent target or SMILES column is noticed.
            missing = [col for col in feature_cols if col not in frame.columns]
            if missing:
                logger.warning(f"  {split_name}: filling missing columns with defaults: {missing}")
            for col in missing:
                frame[col] = 0.0 if col != "smiles" else ""

            # Keep only the model columns, in the canonical order.
            frame[feature_cols].to_parquet(fold_output_dir / f"{split_name}.parquet", index=False)

        logger.success(f"  Saved to {fold_output_dir}")

    # Persist the column order next to the folds.
    cols_path = output_dir / "feature_columns.txt"
    cols_path.write_text("\n".join(feature_cols), encoding="utf-8")
    logger.success(f"Saved feature columns to {cols_path}")

    logger.info("\n" + "=" * 60)
    logger.info("CV DATA PROCESSING COMPLETE")
    logger.info("=" * 60)
    logger.info(f"Output directory: {output_dir}")
    logger.info(f"Number of folds: {len(cv_dirs)}")
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run the Typer app (and thus the `main` command) only
# when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    app()
|
||||||
|
|
||||||
Loading…
x
Reference in New Issue
Block a user