mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-22 01:56:54 +08:00
794 lines
27 KiB
Python
794 lines
27 KiB
Python
"""
|
||
嵌套交叉验证 + Optuna 超参调优
|
||
|
||
外层 5-fold StratifiedKFold(20% test / 80% train)
|
||
内层 3-fold StratifiedKFold(在 80% 上做 Optuna 超参搜索)
|
||
|
||
使用方法:
|
||
python -m lnp_ml.modeling.nested_cv_optuna
|
||
|
||
或通过 Makefile:
|
||
make nested_cv_tune DEVICE=cuda
|
||
"""
|
||
|
||
import json
|
||
import math
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Dict, List, Optional, Tuple, Union
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import torch
|
||
from torch.utils.data import DataLoader, Subset
|
||
from sklearn.model_selection import StratifiedKFold
|
||
from loguru import logger
|
||
import typer
|
||
|
||
try:
|
||
import optuna
|
||
from optuna.samplers import TPESampler
|
||
except ImportError:
|
||
optuna = None
|
||
TPESampler = None
|
||
|
||
from lnp_ml.config import MODELS_DIR, INTERIM_DATA_DIR
|
||
from lnp_ml.dataset import (
|
||
LNPDataset,
|
||
collate_fn,
|
||
process_dataframe,
|
||
TARGET_CLASSIFICATION_PDI,
|
||
TARGET_CLASSIFICATION_EE,
|
||
TARGET_TOXIC,
|
||
)
|
||
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
|
||
from lnp_ml.modeling.trainer_balanced import (
|
||
ClassWeights,
|
||
LossWeightsBalanced,
|
||
compute_class_weights_from_loader,
|
||
train_with_early_stopping,
|
||
train_fixed_epochs,
|
||
validate_balanced,
|
||
)
|
||
|
||
# Default location of the trained MPNN ensemble checkpoints.
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"


# Typer CLI application; `main` below is registered as its command.
app = typer.Typer()
|
||
|
||
|
||
# ============ CompositeStrata 复合分层标签 ============
|
||
|
||
def build_composite_strata(
    df: pd.DataFrame,
    min_stratum_count: int = 5,
) -> Tuple[np.ndarray, Dict]:
    """
    Build composite stratification labels (toxic x PDI x EE).

    Each sample gets a string label "T{t}|P{p}|E{e}" combining its toxic
    class, PDI class (argmax over the one-hot PDI columns) and EE class
    (argmax over the one-hot EE columns); missing values map to "NA".
    Strata with fewer than ``min_stratum_count`` samples are merged into a
    single "RARE" stratum so StratifiedKFold does not fail on tiny groups.

    Args:
        df: processed DataFrame.
        min_stratum_count: minimum samples per stratum; smaller strata are
            merged into "RARE".

    Returns:
        (strata_array, strata_info)
        - strata_array: int64 stratum code per sample.
        - strata_info: bookkeeping statistics (counts, merged rare strata).
    """
    n = len(df)

    # Hoist column-presence checks out of the per-row loop: they are
    # invariant over rows (the original re-evaluated them per sample).
    has_toxic = TARGET_TOXIC in df.columns
    has_pdi = all(col in df.columns for col in TARGET_CLASSIFICATION_PDI)
    has_ee = all(col in df.columns for col in TARGET_CLASSIFICATION_EE)

    strata_labels = []
    for i in range(n):
        # Toxic stratum: integer class, or "NA" when missing/negative.
        toxic_str = "NA"
        if has_toxic:
            toxic_val = df[TARGET_TOXIC].iloc[i]
            if pd.notna(toxic_val) and toxic_val >= 0:
                toxic_str = str(int(toxic_val))

        # PDI stratum: argmax of the one-hot columns; "NA" if all-zero.
        pdi_str = "NA"
        if has_pdi:
            pdi_vals = df[TARGET_CLASSIFICATION_PDI].iloc[i].values
            if pdi_vals.sum() > 0:
                pdi_str = str(int(np.argmax(pdi_vals)))

        # EE stratum: same encoding as PDI.
        ee_str = "NA"
        if has_ee:
            ee_vals = df[TARGET_CLASSIFICATION_EE].iloc[i].values
            if ee_vals.sum() > 0:
                ee_str = str(int(np.argmax(ee_vals)))

        strata_labels.append(f"T{toxic_str}|P{pdi_str}|E{ee_str}")

    # Count samples per stratum.
    unique_strata, counts = np.unique(strata_labels, return_counts=True)
    strata_counts = dict(zip(unique_strata, counts))

    # Merge sparse strata into "RARE"; a set gives O(1) membership tests
    # (the original used a list, O(len(rare)) per sample).
    rare_strata = {s for s, c in strata_counts.items() if c < min_stratum_count}
    final_labels = ["RARE" if label in rare_strata else label for label in strata_labels]

    # Encode the final labels as integers.
    unique_final, encoded = np.unique(final_labels, return_inverse=True)

    strata_info = {
        "original_strata_counts": strata_counts,
        # sorted() keeps the JSON-friendly list form of the original output
        # (np.unique already yields sorted labels).
        "rare_strata": sorted(rare_strata),
        "final_strata": list(unique_final),
        "final_strata_counts": dict(zip(*np.unique(final_labels, return_counts=True))),
        "n_rare_merged": sum(strata_counts[s] for s in rare_strata) if rare_strata else 0,
    }

    logger.info(f"Built composite strata: {len(unique_final)} unique strata")
    logger.info(f" Rare strata merged: {len(rare_strata)} types, {strata_info['n_rare_merged']} samples")

    return encoded.astype(np.int64), strata_info
|
||
|
||
|
||
# ============ 模型创建 ============
|
||
|
||
def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
    """Locate every ``model.pt`` of the MPNN ensemble under *base_dir*.

    Returns:
        Sorted list of checkpoint paths as strings.

    Raises:
        FileNotFoundError: when the glob matches nothing.
    """
    matches = [str(path) for path in sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt"))]
    if not matches:
        raise FileNotFoundError(f"No model.pt files found in {base_dir}")
    return matches
|
||
|
||
|
||
def create_model(
    d_model: int = 256,
    num_heads: int = 8,
    n_attn_layers: int = 4,
    fusion_strategy: str = "attention",
    head_hidden_dim: int = 128,
    dropout: float = 0.1,
    use_mpnn: bool = False,
    mpnn_device: str = "cpu",
) -> Union[LNPModel, LNPModelWithoutMPNN]:
    """Instantiate the LNP model, with or without the MPNN ensemble branch."""
    # Arguments common to both model variants.
    shared_kwargs = dict(
        d_model=d_model,
        num_heads=num_heads,
        n_attn_layers=n_attn_layers,
        fusion_strategy=fusion_strategy,
        head_hidden_dim=head_hidden_dim,
        dropout=dropout,
    )
    if not use_mpnn:
        return LNPModelWithoutMPNN(**shared_kwargs)
    # MPNN variant additionally needs the ensemble checkpoints and device.
    return LNPModel(
        mpnn_ensemble_paths=find_mpnn_ensemble_paths(),
        mpnn_device=mpnn_device,
        **shared_kwargs,
    )
|
||
|
||
|
||
# ============ 评估指标 ============
|
||
|
||
def evaluate_on_test(
    model: torch.nn.Module,
    test_loader: DataLoader,
    device: torch.device,
) -> Dict:
    """
    Evaluate the model on a held-out test loader.

    Collects predictions and targets per task, restricted to samples whose
    validity mask is set for that task, then computes:
      - regression metrics (MSE/RMSE/MAE/R2) for "size" and "delivery";
      - macro classification metrics (acc/precision/recall/F1) for
        "pdi", "ee" and "toxic" (logits are argmax'ed to class ids);
      - mean KL and JS divergence for the "biodist" distribution head.

    Returns:
        Dict mapping task name -> metrics dict; tasks with no labeled
        samples in the loader are omitted.
    """
    # Imported lazily so the module can be imported without these installed.
    from scipy.special import rel_entr
    from sklearn.metrics import (
        mean_squared_error,
        mean_absolute_error,
        r2_score,
        accuracy_score,
        precision_score,
        recall_score,
        f1_score,
    )

    model.eval()

    # Per-task accumulators, filled batch by batch.
    preds = {
        "size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [], "biodist": []
    }
    targets = {
        "size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [], "biodist": []
    }

    with torch.no_grad():
        for batch in test_loader:
            smiles = batch["smiles"]
            # Only tabular tensors are moved; SMILES stay as Python objects.
            tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
            tgts = batch["targets"]
            masks = batch["mask"]

            outputs = model(smiles, tabular)

            # Collect predictions and ground truth under each task's mask.
            for task in ["size", "delivery"]:
                if task in masks and masks[task].any():
                    m = masks[task]
                    # NOTE(review): key == task in both branches; kept as-is.
                    key = task if task == "size" else "delivery"
                    preds[task].extend(outputs[key].squeeze(-1)[m].cpu().numpy().tolist())
                    targets[task].extend(tgts[key][m].cpu().numpy().tolist())

            for task in ["pdi", "ee", "toxic"]:
                if task in masks and masks[task].any():
                    m = masks[task]
                    # Classification heads emit logits; argmax -> class id.
                    preds[task].extend(outputs[task][m].argmax(dim=-1).cpu().numpy().tolist())
                    targets[task].extend(tgts[task][m].cpu().numpy().tolist())

            if "biodist" in masks and masks["biodist"].any():
                m = masks["biodist"]
                preds["biodist"].extend(outputs["biodist"][m].cpu().numpy().tolist())
                targets["biodist"].extend(tgts["biodist"][m].cpu().numpy().tolist())

    # Compute metrics per task.
    results = {}

    # Regression tasks
    for task in ["size", "delivery"]:
        if preds[task]:
            p = np.array(preds[task])
            t = np.array(targets[task])
            results[task] = {
                "n_samples": len(p),
                "mse": float(mean_squared_error(t, p)),
                "rmse": float(np.sqrt(mean_squared_error(t, p))),
                "mae": float(mean_absolute_error(t, p)),
                "r2": float(r2_score(t, p)),
            }

    # Classification tasks (macro-averaged; zero_division=0 silences
    # warnings for classes absent from the fold)
    for task in ["pdi", "ee", "toxic"]:
        if preds[task]:
            p = np.array(preds[task])
            t = np.array(targets[task])
            results[task] = {
                "n_samples": len(p),
                "accuracy": float(accuracy_score(t, p)),
                "precision": float(precision_score(t, p, average="macro", zero_division=0)),
                "recall": float(recall_score(t, p, average="macro", zero_division=0)),
                "f1": float(f1_score(t, p, average="macro", zero_division=0)),
            }

    # Distribution task
    if preds["biodist"]:
        p = np.array(preds["biodist"])
        t = np.array(targets["biodist"])

        def kl_divergence(p_arr, q_arr, eps=1e-10):
            # Clip away zeros before rel_entr (= p * log(p/q) elementwise),
            # then sum over the distribution axis and average over samples.
            p_arr = np.clip(p_arr, eps, 1.0)
            q_arr = np.clip(q_arr, eps, 1.0)
            return float(np.sum(rel_entr(p_arr, q_arr), axis=-1).mean())

        def js_divergence(p_arr, q_arr, eps=1e-10):
            # Symmetrized KL against the midpoint distribution m.
            p_arr = np.clip(p_arr, eps, 1.0)
            q_arr = np.clip(q_arr, eps, 1.0)
            m = 0.5 * (p_arr + q_arr)
            return float(0.5 * (np.sum(rel_entr(p_arr, m), axis=-1) + np.sum(rel_entr(q_arr, m), axis=-1)).mean())

        results["biodist"] = {
            "n_samples": len(p),
            # Divergences are computed target-vs-prediction (t first).
            "kl_divergence": kl_divergence(t, p),
            "js_divergence": js_divergence(t, p),
        }

    return results
|
||
|
||
|
||
# ============ 预训练权重加载 ============
|
||
|
||
def load_pretrain_weights_to_model(
    model: Union[LNPModel, LNPModelWithoutMPNN],
    pretrain_state_dict: Dict,
    d_model: int,
    pretrain_config: Dict,
    load_delivery_head: bool = True,
) -> bool:
    """
    Copy pretraining weights into *model* when the architectures agree.

    The transfer is refused when the checkpoint's d_model differs from the
    current model's, since the weight tensors would not line up.

    Returns:
        True when weights were loaded, False when skipped.
    """
    pretrain_d_model = pretrain_config.get("d_model")
    if pretrain_d_model != d_model:
        logger.warning(
            f"d_model mismatch: pretrain={pretrain_d_model}, "
            f"current={d_model}. Skipping pretrain loading."
        )
        return False

    # strict=False tolerates heads present in only one of the two models.
    model.load_pretrain_weights(
        pretrain_state_dict=pretrain_state_dict,
        load_delivery_head=load_delivery_head,
        strict=False,
    )
    return True
|
||
|
||
|
||
# ============ 内层 Optuna 调参 ============
|
||
|
||
def run_inner_optuna(
    full_dataset: LNPDataset,
    inner_train_indices: np.ndarray,
    strata: np.ndarray,
    device: torch.device,
    n_trials: int = 20,
    epochs_per_trial: int = 30,
    patience: int = 10,
    batch_size: int = 32,
    n_inner_folds: int = 3,
    use_mpnn: bool = False,
    seed: int = 42,
    study_path: Optional[Path] = None,
    pretrain_state_dict: Optional[Dict] = None,
    pretrain_config: Optional[Dict] = None,
    load_delivery_head: bool = True,
) -> Tuple[Dict, int, "optuna.Study"]:
    # BUGFIX: the return annotation must be a forward-reference string.
    # Annotations are evaluated at `def` time (i.e. at module import), and
    # when optuna is not installed this module sets `optuna = None`, so the
    # unquoted `optuna.Study` raised AttributeError on import and defeated
    # the graceful `if optuna is None` guards below.
    """
    Run an Optuna hyper-parameter search on the inner (outer-train) data.

    Architecture parameters are pinned to the pretraining configuration so
    pretrained weights remain fully loadable; only training
    hyper-parameters (dropout, lr, weight_decay) are searched. Each trial
    is scored by the mean best validation loss over an inner
    StratifiedKFold CV, and the mean best epoch is recorded as a trial
    user attribute for later fixed-epoch outer training.

    Args:
        full_dataset: the full dataset.
        inner_train_indices: indices of the inner-train pool (into full_dataset).
        strata: stratification label per sample of the full dataset.
        device: torch device.
        n_trials: number of Optuna trials.
        epochs_per_trial: max epochs per trial.
        patience: early-stopping patience.
        batch_size: batch size.
        n_inner_folds: number of inner CV folds.
        use_mpnn: whether to use the MPNN ensemble branch.
        seed: random seed (sampler and inner CV).
        study_path: optional path for SQLite study persistence.
        pretrain_state_dict: pretrained weights.
        pretrain_config: pretraining configuration.
        load_delivery_head: whether to load delivery-head weights.

    Returns:
        (best_params, epoch_mean, study)

    Raises:
        ImportError: when optuna is not installed.
    """
    if optuna is None:
        raise ImportError("Optuna not installed. Run: pip install optuna")

    inner_strata = strata[inner_train_indices]

    # Fix architecture params (match pretraining so weights load cleanly).
    _cfg = pretrain_config or {}
    fixed_d_model = _cfg.get("d_model", 256)
    fixed_num_heads = _cfg.get("num_heads", 8)
    fixed_n_attn_layers = _cfg.get("n_attn_layers", 4)
    fixed_fusion_strategy = _cfg.get("fusion_strategy", "attention")
    fixed_head_hidden_dim = _cfg.get("head_hidden_dim", 128)
    logger.info(
        f"Fixed architecture params: d_model={fixed_d_model}, num_heads={fixed_num_heads}, "
        f"n_attn_layers={fixed_n_attn_layers}, fusion={fixed_fusion_strategy}, "
        f"head_hidden_dim={fixed_head_hidden_dim}"
    )

    def objective(trial: optuna.Trial) -> float:
        # Architecture is pinned; only training hyper-params are searched.
        d_model = fixed_d_model
        num_heads = fixed_num_heads
        n_attn_layers = fixed_n_attn_layers
        fusion_strategy = fixed_fusion_strategy
        head_hidden_dim = fixed_head_hidden_dim

        # Search training hyper-parameters.
        dropout = trial.suggest_float("dropout", 0.1, 0.5)
        lr = trial.suggest_float("lr", 1e-5, 3e-4, log=True)
        weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)

        # Inner stratified CV.
        inner_cv = StratifiedKFold(
            n_splits=n_inner_folds, shuffle=True, random_state=seed
        )

        fold_val_losses = []
        fold_best_epochs = []

        for inner_fold, (inner_train_idx, inner_val_idx) in enumerate(
            inner_cv.split(inner_train_indices, inner_strata)
        ):
            # Map fold-local positions back to full-dataset indices.
            actual_train_idx = inner_train_indices[inner_train_idx]
            actual_val_idx = inner_train_indices[inner_val_idx]

            # Build DataLoaders over Subsets of the full dataset.
            train_subset = Subset(full_dataset, actual_train_idx.tolist())
            val_subset = Subset(full_dataset, actual_val_idx.tolist())

            train_loader = DataLoader(
                train_subset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
            )
            val_loader = DataLoader(
                val_subset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
            )

            # Class weights from the training fold only (no leakage).
            class_weights = compute_class_weights_from_loader(train_loader)

            # Fresh model per fold.
            model = create_model(
                d_model=d_model,
                num_heads=num_heads,
                n_attn_layers=n_attn_layers,
                fusion_strategy=fusion_strategy,
                head_hidden_dim=head_hidden_dim,
                dropout=dropout,
                use_mpnn=use_mpnn,
                mpnn_device=device.type,
            )

            # Initialise from pretraining when available.
            if pretrain_state_dict is not None and pretrain_config is not None:
                load_pretrain_weights_to_model(
                    model, pretrain_state_dict, d_model, pretrain_config, load_delivery_head
                )

            # Train with early stopping.
            result = train_with_early_stopping(
                model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                device=device,
                lr=lr,
                weight_decay=weight_decay,
                epochs=epochs_per_trial,
                patience=patience,
                class_weights=class_weights,
            )

            fold_val_losses.append(result["best_val_loss"])
            fold_best_epochs.append(result["best_epoch"])

        # Record the mean best epoch on the trial for outer fixed-epoch use.
        epoch_mean = int(round(np.mean(fold_best_epochs)))
        trial.set_user_attr("epoch_mean", epoch_mean)
        trial.set_user_attr("fold_best_epochs", fold_best_epochs)

        return np.mean(fold_val_losses)

    # Create the study (optionally persisted to SQLite for resumability).
    storage = None
    if study_path is not None:
        storage = f"sqlite:///{study_path}"

    study = optuna.create_study(
        direction="minimize",
        sampler=TPESampler(seed=seed),
        storage=storage,
        study_name="inner_optuna",
        load_if_exists=True,
    )

    study.optimize(objective, n_trials=n_trials, show_progress_bar=True)

    # Merge searched params with the fixed architecture params so callers
    # get one complete configuration dict.
    best_params = dict(study.best_trial.params)
    best_params.update({
        "d_model": fixed_d_model,
        "num_heads": fixed_num_heads,
        "n_attn_layers": fixed_n_attn_layers,
        "fusion_strategy": fixed_fusion_strategy,
        "head_hidden_dim": fixed_head_hidden_dim,
    })
    epoch_mean = study.best_trial.user_attrs.get("epoch_mean", epochs_per_trial)

    logger.info(f"Best trial: {study.best_trial.number}")
    logger.info(f"Best val_loss: {study.best_trial.value:.4f}")
    logger.info(f"Best params: {best_params}")
    logger.info(f"Epoch mean: {epoch_mean}")

    return best_params, epoch_mean, study
|
||
|
||
|
||
# ============ 主流程 ============
|
||
|
||
@app.command()
def main(
    input_path: Path = INTERIM_DATA_DIR / "internal.csv",
    output_dir: Path = MODELS_DIR / "nested_cv",
    # CV parameters
    n_outer_folds: int = 5,
    n_inner_folds: int = 3,
    min_stratum_count: int = 5,
    seed: int = 42,
    # Optuna parameters
    n_trials: int = 20,
    epochs_per_trial: int = 30,
    inner_patience: int = 10,
    # Training parameters
    batch_size: int = 32,
    # Pretrained weights
    init_from_pretrain: Optional[Path] = None,
    load_delivery_head: bool = True,
    # MPNN
    use_mpnn: bool = False,
    # Device
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Nested cross-validation with Optuna hyper-parameter tuning.

    Outer n_outer_folds-fold StratifiedKFold (held-out test per fold);
    inner n_inner_folds-fold Optuna search on each outer-train split.
    Outer training uses no early stopping: the epoch count is the inner
    best trial's epoch_mean.

    Use --init-from-pretrain to initialise model weights from a
    pretraining checkpoint.
    """
    if optuna is None:
        logger.error("Optuna not installed. Run: pip install optuna")
        raise typer.Exit(1)

    logger.info(f"Using device: {device}")
    device = torch.device(device)

    # Load pretrained weights (when a checkpoint path was given).
    pretrain_state_dict = None
    pretrain_config = None
    if init_from_pretrain is not None:
        if init_from_pretrain.exists():
            logger.info(f"Loading pretrain weights from {init_from_pretrain}")
            checkpoint = torch.load(init_from_pretrain, map_location="cpu")
            pretrain_state_dict = checkpoint["model_state_dict"]
            pretrain_config = checkpoint.get("config", {})
            logger.success(f"Loaded pretrain checkpoint (d_model={pretrain_config.get('d_model')})")
        else:
            # Best-effort: a missing checkpoint downgrades to training
            # from scratch instead of aborting the run.
            logger.warning(f"Pretrain checkpoint not found: {init_from_pretrain}, skipping")

    # Create a timestamped output directory for this run.
    run_name = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = output_dir / run_name
    run_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Output directory: {run_dir}")

    # Load data.
    logger.info(f"Loading data from {input_path}")
    df = pd.read_csv(input_path)
    logger.info(f"Loaded {len(df)} samples")

    # Preprocess.
    logger.info("Processing dataframe...")
    df = process_dataframe(df)

    # Build composite stratification labels (toxic x PDI x EE).
    logger.info("Building composite strata...")
    strata, strata_info = build_composite_strata(df, min_stratum_count)

    # Persist strata bookkeeping (default=str serialises numpy scalars).
    with open(run_dir / "strata_info.json", "w") as f:
        json.dump(strata_info, f, indent=2, default=str)

    # Full dataset; folds are realised as index Subsets of it.
    full_dataset = LNPDataset(df)
    n_samples = len(full_dataset)

    # Outer CV.
    outer_cv = StratifiedKFold(
        n_splits=n_outer_folds, shuffle=True, random_state=seed
    )

    outer_results = []

    for outer_fold, (outer_train_idx, outer_test_idx) in enumerate(
        outer_cv.split(np.arange(n_samples), strata)
    ):
        logger.info(f"\n{'='*60}")
        logger.info(f"OUTER FOLD {outer_fold}")
        logger.info(f"{'='*60}")
        logger.info(f"Train: {len(outer_train_idx)}, Test: {len(outer_test_idx)}")

        fold_dir = run_dir / f"outer_fold_{outer_fold}"
        fold_dir.mkdir(parents=True, exist_ok=True)

        # Save split indices for reproducibility.
        splits = {
            "outer_train_idx": outer_train_idx.tolist(),
            "outer_test_idx": outer_test_idx.tolist(),
        }
        with open(fold_dir / "splits.json", "w") as f:
            json.dump(splits, f)

        # Inner Optuna tuning; seed offset per fold so searches differ.
        logger.info(f"\nRunning inner Optuna with {n_trials} trials...")
        study_path = fold_dir / "optuna_study.sqlite3"

        best_params, epoch_mean, study = run_inner_optuna(
            full_dataset=full_dataset,
            inner_train_indices=outer_train_idx,
            strata=strata,
            device=device,
            n_trials=n_trials,
            epochs_per_trial=epochs_per_trial,
            patience=inner_patience,
            batch_size=batch_size,
            n_inner_folds=n_inner_folds,
            use_mpnn=use_mpnn,
            seed=seed + outer_fold,
            study_path=study_path,
            pretrain_state_dict=pretrain_state_dict,
            pretrain_config=pretrain_config,
            load_delivery_head=load_delivery_head,
        )

        # Save best hyper-parameters.
        with open(fold_dir / "best_params.json", "w") as f:
            json.dump(best_params, f, indent=2)

        with open(fold_dir / "epoch_mean.json", "w") as f:
            json.dump({"epoch_mean": epoch_mean}, f)

        # Outer training: best params, fixed epoch count, no early stop.
        logger.info(f"\nTraining outer fold with best params, epochs={epoch_mean}...")

        # Build DataLoaders.
        train_subset = Subset(full_dataset, outer_train_idx.tolist())
        test_subset = Subset(full_dataset, outer_test_idx.tolist())

        train_loader = DataLoader(
            train_subset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
        )
        test_loader = DataLoader(
            test_subset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
        )

        # Class weights from the outer-train split only.
        class_weights = compute_class_weights_from_loader(train_loader)

        # Create the model with the tuned configuration.
        model = create_model(
            d_model=best_params["d_model"],
            num_heads=best_params["num_heads"],
            n_attn_layers=best_params["n_attn_layers"],
            fusion_strategy=best_params["fusion_strategy"],
            head_hidden_dim=best_params["head_hidden_dim"],
            dropout=best_params["dropout"],
            use_mpnn=use_mpnn,
            mpnn_device=device.type,
        )

        # Load pretrained weights.
        if pretrain_state_dict is not None and pretrain_config is not None:
            loaded = load_pretrain_weights_to_model(
                model, pretrain_state_dict, best_params["d_model"],
                pretrain_config, load_delivery_head
            )
            if loaded:
                logger.info(f"Loaded pretrain weights for outer fold {outer_fold}")

        # Train for a fixed number of epochs (no early stopping).
        train_result = train_fixed_epochs(
            model=model,
            train_loader=train_loader,
            val_loader=None,  # outer training uses no validation set
            device=device,
            lr=best_params["lr"],
            weight_decay=best_params["weight_decay"],
            epochs=epoch_mean,
            class_weights=class_weights,
            use_cosine_annealing=True,
        )

        # Restore the final weights.
        model.load_state_dict(train_result["final_state"])
        model = model.to(device)

        # Save the model checkpoint with its configuration.
        config = {
            "d_model": best_params["d_model"],
            "num_heads": best_params["num_heads"],
            "n_attn_layers": best_params["n_attn_layers"],
            "fusion_strategy": best_params["fusion_strategy"],
            "head_hidden_dim": best_params["head_hidden_dim"],
            "dropout": best_params["dropout"],
            "use_mpnn": use_mpnn,
        }

        torch.save({
            "model_state_dict": train_result["final_state"],
            "config": config,
            "epoch_mean": epoch_mean,
            "best_params": best_params,
        }, fold_dir / "model.pt")

        # Save the training history.
        with open(fold_dir / "history.json", "w") as f:
            json.dump(train_result["history"], f, indent=2)

        # Evaluate on the outer test split.
        logger.info("Evaluating on outer test set...")
        test_metrics = evaluate_on_test(model, test_loader, device)

        with open(fold_dir / "test_metrics.json", "w") as f:
            json.dump(test_metrics, f, indent=2)

        # Print per-fold test results.
        logger.info(f"\nOuter Fold {outer_fold} Test Results:")
        for task, metrics in test_metrics.items():
            if "rmse" in metrics:
                logger.info(f" {task}: RMSE={metrics['rmse']:.4f}, R²={metrics['r2']:.4f}")
            elif "accuracy" in metrics:
                logger.info(f" {task}: Acc={metrics['accuracy']:.4f}, F1={metrics['f1']:.4f}")
            elif "kl_divergence" in metrics:
                logger.info(f" {task}: KL={metrics['kl_divergence']:.4f}, JS={metrics['js_divergence']:.4f}")

        outer_results.append({
            "fold": outer_fold,
            "best_params": best_params,
            "epoch_mean": epoch_mean,
            "test_metrics": test_metrics,
        })

    # Aggregate results across outer folds.
    logger.info("\n" + "=" * 60)
    logger.info("NESTED CV COMPLETE")
    logger.info("=" * 60)

    # Summary statistics.
    summary = {"fold_results": outer_results}

    # Gather per-task metric values across folds (n_samples excluded).
    tasks_with_metrics = {}
    for result in outer_results:
        for task, metrics in result["test_metrics"].items():
            if task not in tasks_with_metrics:
                tasks_with_metrics[task] = {k: [] for k in metrics.keys() if k != "n_samples"}
            for k, v in metrics.items():
                if k != "n_samples":
                    tasks_with_metrics[task][k].append(v)

    summary["summary_stats"] = {}
    for task, metrics_dict in tasks_with_metrics.items():
        summary["summary_stats"][task] = {}
        for metric_name, values in metrics_dict.items():
            summary["summary_stats"][task][f"{metric_name}_mean"] = float(np.mean(values))
            summary["summary_stats"][task][f"{metric_name}_std"] = float(np.std(values))

    # Print the summary.
    logger.info("\n[Summary Statistics]")
    for task, stats in summary["summary_stats"].items():
        if "rmse_mean" in stats:
            logger.info(
                f" {task}: RMSE={stats['rmse_mean']:.4f}±{stats['rmse_std']:.4f}, "
                f"R²={stats['r2_mean']:.4f}±{stats['r2_std']:.4f}"
            )
        elif "accuracy_mean" in stats:
            logger.info(
                f" {task}: Acc={stats['accuracy_mean']:.4f}±{stats['accuracy_std']:.4f}, "
                f"F1={stats['f1_mean']:.4f}±{stats['f1_std']:.4f}"
            )
        elif "kl_divergence_mean" in stats:
            logger.info(
                f" {task}: KL={stats['kl_divergence_mean']:.4f}±{stats['kl_divergence_std']:.4f}, "
                f"JS={stats['js_divergence_mean']:.4f}±{stats['js_divergence_std']:.4f}"
            )

    # Save the summary.
    with open(run_dir / "summary.json", "w") as f:
        json.dump(summary, f, indent=2)

    logger.success(f"\nAll results saved to {run_dir}")
|
||
|
||
|
||
if __name__ == "__main__":
    # Typer CLI entry point: `python -m lnp_ml.modeling.nested_cv_optuna`.
    app()
|
||
|