# lnp_ml/lnp_ml/modeling/nested_cv_optuna.py
"""
嵌套交叉验证 + Optuna 超参调优
外层 5-fold StratifiedKFold20% test / 80% train
内层 3-fold StratifiedKFold在 80% 上做 Optuna 超参搜索)
使用方法:
python -m lnp_ml.modeling.nested_cv_optuna
或通过 Makefile:
make nested_cv_tune DEVICE=cuda
"""
import json
import math
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from loguru import logger
import typer
try:
import optuna
from optuna.samplers import TPESampler
except ImportError:
optuna = None
TPESampler = None
from lnp_ml.config import MODELS_DIR, INTERIM_DATA_DIR
from lnp_ml.dataset import (
LNPDataset,
collate_fn,
process_dataframe,
TARGET_CLASSIFICATION_PDI,
TARGET_CLASSIFICATION_EE,
TARGET_TOXIC,
)
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
from lnp_ml.modeling.trainer_balanced import (
ClassWeights,
LossWeightsBalanced,
compute_class_weights_from_loader,
train_with_early_stopping,
train_fixed_epochs,
validate_balanced,
)
# Default path of the MPNN ensemble
DEFAULT_MPNN_ENSEMBLE_DIR = MODELS_DIR / "mpnn" / "all_amine_split_for_LiON"
app = typer.Typer()
# ============ Composite stratification labels ============
def build_composite_strata(
    df: pd.DataFrame,
    min_stratum_count: int = 5,
) -> Tuple[np.ndarray, Dict]:
    """Build composite stratification labels (toxic x PDI x EE).

    Args:
        df: processed DataFrame.
        min_stratum_count: minimum number of samples per stratum; strata
            below this size are merged into a single "RARE" bucket.

    Returns:
        (strata_array, strata_info)
        - strata_array: integer stratum code per sample
        - strata_info: bookkeeping statistics
    """
    def _toxic_part(row: int) -> str:
        # Toxic class index as a string, or "NA" when the column is
        # missing or the value is NaN/negative.
        if TARGET_TOXIC not in df.columns:
            return "NA"
        val = df[TARGET_TOXIC].iloc[row]
        return str(int(val)) if pd.notna(val) and val >= 0 else "NA"

    def _onehot_part(row: int, cols) -> str:
        # Argmax of a one-hot block as a string, or "NA" when the block is
        # absent or all-zero (unlabeled).
        if not all(c in df.columns for c in cols):
            return "NA"
        vals = df[cols].iloc[row].values
        return str(int(np.argmax(vals))) if vals.sum() > 0 else "NA"

    raw_labels = [
        f"T{_toxic_part(i)}|P{_onehot_part(i, TARGET_CLASSIFICATION_PDI)}|E{_onehot_part(i, TARGET_CLASSIFICATION_EE)}"
        for i in range(len(df))
    ]
    # Count samples per stratum.
    uniq, cnts = np.unique(raw_labels, return_counts=True)
    strata_counts = dict(zip(uniq, cnts))
    # Merge sparse strata into "RARE".
    rare_strata = [s for s, c in strata_counts.items() if c < min_stratum_count]
    merged_labels = [("RARE" if lbl in rare_strata else lbl) for lbl in raw_labels]
    # Encode the final labels as integers.
    unique_final, encoded = np.unique(merged_labels, return_inverse=True)
    strata_info = {
        "original_strata_counts": strata_counts,
        "rare_strata": rare_strata,
        "final_strata": list(unique_final),
        "final_strata_counts": dict(zip(*np.unique(merged_labels, return_counts=True))),
        "n_rare_merged": sum(strata_counts[s] for s in rare_strata) if rare_strata else 0,
    }
    logger.info(f"Built composite strata: {len(unique_final)} unique strata")
    logger.info(f"  Rare strata merged: {len(rare_strata)} types, {strata_info['n_rare_merged']} samples")
    return encoded.astype(np.int64), strata_info
# ============ Model creation ============
def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
    """Locate the model.pt files of the MPNN ensemble under *base_dir*."""
    found = sorted(base_dir.glob("cv_*/fold_*/model_*/model.pt"))
    if not found:
        raise FileNotFoundError(f"No model.pt files found in {base_dir}")
    return [str(path) for path in found]
def create_model(
    d_model: int = 256,
    num_heads: int = 8,
    n_attn_layers: int = 4,
    fusion_strategy: str = "attention",
    head_hidden_dim: int = 128,
    dropout: float = 0.1,
    use_mpnn: bool = False,
    mpnn_device: str = "cpu",
) -> Union[LNPModel, LNPModelWithoutMPNN]:
    """Instantiate the LNP model, with or without the MPNN ensemble branch."""
    # Arguments shared by both model variants.
    shared = dict(
        d_model=d_model,
        num_heads=num_heads,
        n_attn_layers=n_attn_layers,
        fusion_strategy=fusion_strategy,
        head_hidden_dim=head_hidden_dim,
        dropout=dropout,
    )
    if not use_mpnn:
        return LNPModelWithoutMPNN(**shared)
    # MPNN variant additionally needs the ensemble checkpoints.
    return LNPModel(
        mpnn_ensemble_paths=find_mpnn_ensemble_paths(),
        mpnn_device=mpnn_device,
        **shared,
    )
# ============ Evaluation metrics ============
def evaluate_on_test(
    model: torch.nn.Module,
    test_loader: DataLoader,
    device: torch.device,
) -> Dict:
    """Evaluate the model on the outer test set.

    Gathers masked predictions and targets per task, then computes:
      - size / delivery: regression metrics (MSE, RMSE, MAE, R2)
      - pdi / ee / toxic: macro-averaged classification metrics
      - biodist: KL and JS divergence between target and predicted
        distributions

    Returns:
        Dict mapping task name to its metric dict; tasks with no labeled
        samples in the loader are omitted.
    """
    from scipy.special import rel_entr
    from sklearn.metrics import (
        mean_squared_error,
        mean_absolute_error,
        r2_score,
        accuracy_score,
        precision_score,
        recall_score,
        f1_score,
    )
    model.eval()
    preds = {
        "size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [], "biodist": []
    }
    targets = {
        "size": [], "delivery": [], "pdi": [], "ee": [], "toxic": [], "biodist": []
    }
    with torch.no_grad():
        for batch in test_loader:
            smiles = batch["smiles"]
            tabular = {k: v.to(device) for k, v in batch["tabular"].items()}
            tgts = batch["targets"]
            masks = batch["mask"]
            outputs = model(smiles, tabular)
            # Regression heads: scalar output per sample, masked gather.
            # (Removed the redundant `key = task if task == "size" else
            # "delivery"` of the original — it always equaled `task`.)
            for task in ["size", "delivery"]:
                if task in masks and masks[task].any():
                    m = masks[task]
                    preds[task].extend(outputs[task].squeeze(-1)[m].cpu().numpy().tolist())
                    targets[task].extend(tgts[task][m].cpu().numpy().tolist())
            # Classification heads: argmax over class logits.
            for task in ["pdi", "ee", "toxic"]:
                if task in masks and masks[task].any():
                    m = masks[task]
                    preds[task].extend(outputs[task][m].argmax(dim=-1).cpu().numpy().tolist())
                    targets[task].extend(tgts[task][m].cpu().numpy().tolist())
            # Distribution head: full probability vector per sample.
            if "biodist" in masks and masks["biodist"].any():
                m = masks["biodist"]
                preds["biodist"].extend(outputs["biodist"][m].cpu().numpy().tolist())
                targets["biodist"].extend(tgts["biodist"][m].cpu().numpy().tolist())
    # Compute metrics per task.
    results = {}
    # Regression tasks
    for task in ["size", "delivery"]:
        if preds[task]:
            p = np.array(preds[task])
            t = np.array(targets[task])
            results[task] = {
                "n_samples": len(p),
                "mse": float(mean_squared_error(t, p)),
                "rmse": float(np.sqrt(mean_squared_error(t, p))),
                "mae": float(mean_absolute_error(t, p)),
                "r2": float(r2_score(t, p)),
            }
    # Classification tasks (macro averaging; zero_division=0 prevents
    # warnings/NaNs when a class is absent from the test fold)
    for task in ["pdi", "ee", "toxic"]:
        if preds[task]:
            p = np.array(preds[task])
            t = np.array(targets[task])
            results[task] = {
                "n_samples": len(p),
                "accuracy": float(accuracy_score(t, p)),
                "precision": float(precision_score(t, p, average="macro", zero_division=0)),
                "recall": float(recall_score(t, p, average="macro", zero_division=0)),
                "f1": float(f1_score(t, p, average="macro", zero_division=0)),
            }
    # Distribution task
    if preds["biodist"]:
        p = np.array(preds["biodist"])
        t = np.array(targets["biodist"])
        def kl_divergence(p_arr, q_arr, eps=1e-10):
            # Clip to avoid log(0); arrays are assumed to already be
            # (approximately) normalized distributions.
            p_arr = np.clip(p_arr, eps, 1.0)
            q_arr = np.clip(q_arr, eps, 1.0)
            return float(np.sum(rel_entr(p_arr, q_arr), axis=-1).mean())
        def js_divergence(p_arr, q_arr, eps=1e-10):
            p_arr = np.clip(p_arr, eps, 1.0)
            q_arr = np.clip(q_arr, eps, 1.0)
            m = 0.5 * (p_arr + q_arr)
            return float(0.5 * (np.sum(rel_entr(p_arr, m), axis=-1) + np.sum(rel_entr(q_arr, m), axis=-1)).mean())
        results["biodist"] = {
            "n_samples": len(p),
            "kl_divergence": kl_divergence(t, p),  # KL(target || pred)
            "js_divergence": js_divergence(t, p),
        }
    return results
# ============ Pretrained weight loading ============
def load_pretrain_weights_to_model(
    model: Union[LNPModel, LNPModelWithoutMPNN],
    pretrain_state_dict: Dict,
    d_model: int,
    pretrain_config: Dict,
    load_delivery_head: bool = True,
) -> bool:
    """Load pretrained weights into *model*.

    Returns:
        True if the weights were loaded, False on a d_model mismatch
        (loading is skipped with a warning in that case).
    """
    pretrain_d_model = pretrain_config.get("d_model")
    if pretrain_d_model != d_model:
        # Incompatible embedding width: the pretrained tensors cannot be
        # mapped onto this model, so skip loading entirely.
        logger.warning(
            f"d_model mismatch: pretrain={pretrain_d_model}, "
            f"current={d_model}. Skipping pretrain loading."
        )
        return False
    model.load_pretrain_weights(
        pretrain_state_dict=pretrain_state_dict,
        load_delivery_head=load_delivery_head,
        strict=False,
    )
    return True
# ============ Inner Optuna tuning ============
def run_inner_optuna(
    full_dataset: LNPDataset,
    inner_train_indices: np.ndarray,
    strata: np.ndarray,
    device: torch.device,
    n_trials: int = 20,
    epochs_per_trial: int = 30,
    patience: int = 10,
    batch_size: int = 32,
    n_inner_folds: int = 3,
    use_mpnn: bool = False,
    seed: int = 42,
    study_path: Optional[Path] = None,
    pretrain_state_dict: Optional[Dict] = None,
    pretrain_config: Optional[Dict] = None,
    load_delivery_head: bool = True,
) -> Tuple[Dict, int, "optuna.Study"]:
    """Run the inner-loop Optuna hyperparameter search.

    The return annotation is quoted on purpose: the module sets
    ``optuna = None`` when the package is missing, and an eager
    ``optuna.Study`` annotation would raise AttributeError at import time,
    defeating the try/except import guard.

    Args:
        full_dataset: full dataset.
        inner_train_indices: indices of the inner training data
            (relative to full_dataset).
        strata: stratification label per sample.
        device: torch device.
        n_trials: number of Optuna trials.
        epochs_per_trial: max epochs per trial.
        patience: early-stopping patience.
        batch_size: batch size.
        n_inner_folds: number of inner CV folds.
        use_mpnn: whether to use the MPNN branch.
        seed: random seed.
        study_path: optional study persistence path (SQLite file).
        pretrain_state_dict: pretrained weights.
        pretrain_config: pretrain configuration.
        load_delivery_head: whether to load the delivery head weights.

    Returns:
        (best_params, epoch_mean, study)

    Raises:
        ImportError: if optuna is not installed.
    """
    if optuna is None:
        raise ImportError("Optuna not installed. Run: pip install optuna")
    inner_strata = strata[inner_train_indices]
    # Architecture parameters are fixed (kept consistent with the pretrain
    # config so pretrained weights load completely).
    _cfg = pretrain_config or {}
    fixed_d_model = _cfg.get("d_model", 256)
    fixed_num_heads = _cfg.get("num_heads", 8)
    fixed_n_attn_layers = _cfg.get("n_attn_layers", 4)
    fixed_fusion_strategy = _cfg.get("fusion_strategy", "attention")
    fixed_head_hidden_dim = _cfg.get("head_hidden_dim", 128)
    logger.info(
        f"Fixed architecture params: d_model={fixed_d_model}, num_heads={fixed_num_heads}, "
        f"n_attn_layers={fixed_n_attn_layers}, fusion={fixed_fusion_strategy}, "
        f"head_hidden_dim={fixed_head_hidden_dim}"
    )

    def objective(trial: optuna.Trial) -> float:
        """One Optuna trial: inner k-fold CV with sampled training hyperparams."""
        d_model = fixed_d_model
        num_heads = fixed_num_heads
        n_attn_layers = fixed_n_attn_layers
        fusion_strategy = fixed_fusion_strategy
        head_hidden_dim = fixed_head_hidden_dim
        # Search only the training hyperparameters.
        dropout = trial.suggest_float("dropout", 0.1, 0.5)
        lr = trial.suggest_float("lr", 1e-5, 3e-4, log=True)
        weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
        # Inner stratified k-fold CV.
        inner_cv = StratifiedKFold(
            n_splits=n_inner_folds, shuffle=True, random_state=seed
        )
        fold_val_losses = []
        fold_best_epochs = []
        for inner_fold, (inner_train_idx, inner_val_idx) in enumerate(
            inner_cv.split(inner_train_indices, inner_strata)
        ):
            # Map fold-local positions back to dataset-level indices.
            actual_train_idx = inner_train_indices[inner_train_idx]
            actual_val_idx = inner_train_indices[inner_val_idx]
            # Build DataLoaders.
            train_subset = Subset(full_dataset, actual_train_idx.tolist())
            val_subset = Subset(full_dataset, actual_val_idx.tolist())
            train_loader = DataLoader(
                train_subset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
            )
            val_loader = DataLoader(
                val_subset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
            )
            # Class weights from the training fold only (no leakage).
            class_weights = compute_class_weights_from_loader(train_loader)
            # Build the model.
            model = create_model(
                d_model=d_model,
                num_heads=num_heads,
                n_attn_layers=n_attn_layers,
                fusion_strategy=fusion_strategy,
                head_hidden_dim=head_hidden_dim,
                dropout=dropout,
                use_mpnn=use_mpnn,
                mpnn_device=device.type,
            )
            # Load pretrained weights when provided.
            if pretrain_state_dict is not None and pretrain_config is not None:
                load_pretrain_weights_to_model(
                    model, pretrain_state_dict, d_model, pretrain_config, load_delivery_head
                )
            # Train with early stopping.
            result = train_with_early_stopping(
                model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                device=device,
                lr=lr,
                weight_decay=weight_decay,
                epochs=epochs_per_trial,
                patience=patience,
                class_weights=class_weights,
            )
            fold_val_losses.append(result["best_val_loss"])
            fold_best_epochs.append(result["best_epoch"])
        # Record epoch_mean on the trial so the outer loop can reuse it as
        # a fixed epoch budget.
        epoch_mean = int(round(np.mean(fold_best_epochs)))
        trial.set_user_attr("epoch_mean", epoch_mean)
        trial.set_user_attr("fold_best_epochs", fold_best_epochs)
        return np.mean(fold_val_losses)

    # Create the study (optionally persisted to SQLite for resumability).
    storage = None
    if study_path is not None:
        storage = f"sqlite:///{study_path}"
    study = optuna.create_study(
        direction="minimize",
        sampler=TPESampler(seed=seed),
        storage=storage,
        study_name="inner_optuna",
        load_if_exists=True,
    )
    study.optimize(objective, n_trials=n_trials, show_progress_bar=True)
    best_params = dict(study.best_trial.params)
    # Merge the fixed architecture params so callers get a complete config.
    best_params.update({
        "d_model": fixed_d_model,
        "num_heads": fixed_num_heads,
        "n_attn_layers": fixed_n_attn_layers,
        "fusion_strategy": fixed_fusion_strategy,
        "head_hidden_dim": fixed_head_hidden_dim,
    })
    epoch_mean = study.best_trial.user_attrs.get("epoch_mean", epochs_per_trial)
    logger.info(f"Best trial: {study.best_trial.number}")
    logger.info(f"Best val_loss: {study.best_trial.value:.4f}")
    logger.info(f"Best params: {best_params}")
    logger.info(f"Epoch mean: {epoch_mean}")
    return best_params, epoch_mean, study
# ============ Main pipeline ============
@app.command()
def main(
    input_path: Path = INTERIM_DATA_DIR / "internal.csv",
    output_dir: Path = MODELS_DIR / "nested_cv",
    # CV parameters
    n_outer_folds: int = 5,
    n_inner_folds: int = 3,
    min_stratum_count: int = 5,
    seed: int = 42,
    # Optuna parameters
    n_trials: int = 20,
    epochs_per_trial: int = 30,
    inner_patience: int = 10,
    # Training parameters
    batch_size: int = 32,
    # Pretrained weights
    init_from_pretrain: Optional[Path] = None,
    load_delivery_head: bool = True,
    # MPNN
    use_mpnn: bool = False,
    # Device
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Nested cross-validation + Optuna hyperparameter tuning.

    Outer 5-fold (20% test / 80% train); inner 3-fold Optuna search.
    Outer training does not use early stopping: the epoch count is the
    epoch_mean of the inner best trial.
    Use --init-from-pretrain to initialize model weights from a pretrained
    checkpoint.
    """
    if optuna is None:
        logger.error("Optuna not installed. Run: pip install optuna")
        raise typer.Exit(1)
    logger.info(f"Using device: {device}")
    device = torch.device(device)
    # Load pretrained weights (if specified)
    pretrain_state_dict = None
    pretrain_config = None
    if init_from_pretrain is not None:
        if init_from_pretrain.exists():
            logger.info(f"Loading pretrain weights from {init_from_pretrain}")
            checkpoint = torch.load(init_from_pretrain, map_location="cpu")
            pretrain_state_dict = checkpoint["model_state_dict"]
            pretrain_config = checkpoint.get("config", {})
            logger.success(f"Loaded pretrain checkpoint (d_model={pretrain_config.get('d_model')})")
        else:
            # Missing checkpoint is non-fatal: continue from random init.
            logger.warning(f"Pretrain checkpoint not found: {init_from_pretrain}, skipping")
    # Create the output directory (timestamped per run)
    run_name = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = output_dir / run_name
    run_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Output directory: {run_dir}")
    # Load data
    logger.info(f"Loading data from {input_path}")
    df = pd.read_csv(input_path)
    logger.info(f"Loaded {len(df)} samples")
    # Process data
    logger.info("Processing dataframe...")
    df = process_dataframe(df)
    # Build composite stratification labels
    logger.info("Building composite strata...")
    strata, strata_info = build_composite_strata(df, min_stratum_count)
    # Persist strata info (default=str handles numpy scalar values)
    with open(run_dir / "strata_info.json", "w") as f:
        json.dump(strata_info, f, indent=2, default=str)
    # Build the full dataset
    full_dataset = LNPDataset(df)
    n_samples = len(full_dataset)
    # Outer CV
    outer_cv = StratifiedKFold(
        n_splits=n_outer_folds, shuffle=True, random_state=seed
    )
    outer_results = []
    for outer_fold, (outer_train_idx, outer_test_idx) in enumerate(
        outer_cv.split(np.arange(n_samples), strata)
    ):
        logger.info(f"\n{'='*60}")
        logger.info(f"OUTER FOLD {outer_fold}")
        logger.info(f"{'='*60}")
        logger.info(f"Train: {len(outer_train_idx)}, Test: {len(outer_test_idx)}")
        fold_dir = run_dir / f"outer_fold_{outer_fold}"
        fold_dir.mkdir(parents=True, exist_ok=True)
        # Save split indices for reproducibility
        splits = {
            "outer_train_idx": outer_train_idx.tolist(),
            "outer_test_idx": outer_test_idx.tolist(),
        }
        with open(fold_dir / "splits.json", "w") as f:
            json.dump(splits, f)
        # Inner Optuna tuning (seed offset per fold so folds differ)
        logger.info(f"\nRunning inner Optuna with {n_trials} trials...")
        study_path = fold_dir / "optuna_study.sqlite3"
        best_params, epoch_mean, study = run_inner_optuna(
            full_dataset=full_dataset,
            inner_train_indices=outer_train_idx,
            strata=strata,
            device=device,
            n_trials=n_trials,
            epochs_per_trial=epochs_per_trial,
            patience=inner_patience,
            batch_size=batch_size,
            n_inner_folds=n_inner_folds,
            use_mpnn=use_mpnn,
            seed=seed + outer_fold,
            study_path=study_path,
            pretrain_state_dict=pretrain_state_dict,
            pretrain_config=pretrain_config,
            load_delivery_head=load_delivery_head,
        )
        # Save best params
        with open(fold_dir / "best_params.json", "w") as f:
            json.dump(best_params, f, indent=2)
        with open(fold_dir / "epoch_mean.json", "w") as f:
            json.dump({"epoch_mean": epoch_mean}, f)
        # Outer training (best hyperparams, fixed epoch count, no early stopping)
        logger.info(f"\nTraining outer fold with best params, epochs={epoch_mean}...")
        # Build DataLoaders
        train_subset = Subset(full_dataset, outer_train_idx.tolist())
        test_subset = Subset(full_dataset, outer_test_idx.tolist())
        train_loader = DataLoader(
            train_subset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
        )
        test_loader = DataLoader(
            test_subset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
        )
        # Compute class weights from the outer training split
        class_weights = compute_class_weights_from_loader(train_loader)
        # Build the model
        model = create_model(
            d_model=best_params["d_model"],
            num_heads=best_params["num_heads"],
            n_attn_layers=best_params["n_attn_layers"],
            fusion_strategy=best_params["fusion_strategy"],
            head_hidden_dim=best_params["head_hidden_dim"],
            dropout=best_params["dropout"],
            use_mpnn=use_mpnn,
            mpnn_device=device.type,
        )
        # Load pretrained weights
        if pretrain_state_dict is not None and pretrain_config is not None:
            loaded = load_pretrain_weights_to_model(
                model, pretrain_state_dict, best_params["d_model"],
                pretrain_config, load_delivery_head
            )
            if loaded:
                logger.info(f"Loaded pretrain weights for outer fold {outer_fold}")
        # Train (fixed epochs, no early stopping)
        train_result = train_fixed_epochs(
            model=model,
            train_loader=train_loader,
            val_loader=None,  # no validation set in the outer loop
            device=device,
            lr=best_params["lr"],
            weight_decay=best_params["weight_decay"],
            epochs=epoch_mean,
            class_weights=class_weights,
            use_cosine_annealing=True,
        )
        # Load the final weights
        model.load_state_dict(train_result["final_state"])
        model = model.to(device)
        # Save the model checkpoint with its full config
        config = {
            "d_model": best_params["d_model"],
            "num_heads": best_params["num_heads"],
            "n_attn_layers": best_params["n_attn_layers"],
            "fusion_strategy": best_params["fusion_strategy"],
            "head_hidden_dim": best_params["head_hidden_dim"],
            "dropout": best_params["dropout"],
            "use_mpnn": use_mpnn,
        }
        torch.save({
            "model_state_dict": train_result["final_state"],
            "config": config,
            "epoch_mean": epoch_mean,
            "best_params": best_params,
        }, fold_dir / "model.pt")
        # Save training history
        with open(fold_dir / "history.json", "w") as f:
            json.dump(train_result["history"], f, indent=2)
        # Evaluate on the held-out test set
        logger.info("Evaluating on outer test set...")
        test_metrics = evaluate_on_test(model, test_loader, device)
        with open(fold_dir / "test_metrics.json", "w") as f:
            json.dump(test_metrics, f, indent=2)
        # Print test results (format depends on task type)
        logger.info(f"\nOuter Fold {outer_fold} Test Results:")
        for task, metrics in test_metrics.items():
            if "rmse" in metrics:
                logger.info(f"  {task}: RMSE={metrics['rmse']:.4f}, R²={metrics['r2']:.4f}")
            elif "accuracy" in metrics:
                logger.info(f"  {task}: Acc={metrics['accuracy']:.4f}, F1={metrics['f1']:.4f}")
            elif "kl_divergence" in metrics:
                logger.info(f"  {task}: KL={metrics['kl_divergence']:.4f}, JS={metrics['js_divergence']:.4f}")
        outer_results.append({
            "fold": outer_fold,
            "best_params": best_params,
            "epoch_mean": epoch_mean,
            "test_metrics": test_metrics,
        })
    # Aggregate results across outer folds
    logger.info("\n" + "=" * 60)
    logger.info("NESTED CV COMPLETE")
    logger.info("=" * 60)
    # Compute summary statistics
    summary = {"fold_results": outer_results}
    # Collect per-task metric values across folds (n_samples excluded)
    tasks_with_metrics = {}
    for result in outer_results:
        for task, metrics in result["test_metrics"].items():
            if task not in tasks_with_metrics:
                tasks_with_metrics[task] = {k: [] for k in metrics.keys() if k != "n_samples"}
            for k, v in metrics.items():
                if k != "n_samples":
                    tasks_with_metrics[task][k].append(v)
    summary["summary_stats"] = {}
    for task, metrics_dict in tasks_with_metrics.items():
        summary["summary_stats"][task] = {}
        for metric_name, values in metrics_dict.items():
            summary["summary_stats"][task][f"{metric_name}_mean"] = float(np.mean(values))
            summary["summary_stats"][task][f"{metric_name}_std"] = float(np.std(values))
    # Print summary
    logger.info("\n[Summary Statistics]")
    for task, stats in summary["summary_stats"].items():
        if "rmse_mean" in stats:
            logger.info(
                f"  {task}: RMSE={stats['rmse_mean']:.4f}±{stats['rmse_std']:.4f}, "
                f"R²={stats['r2_mean']:.4f}±{stats['r2_std']:.4f}"
            )
        elif "accuracy_mean" in stats:
            logger.info(
                f"  {task}: Acc={stats['accuracy_mean']:.4f}±{stats['accuracy_std']:.4f}, "
                f"F1={stats['f1_mean']:.4f}±{stats['f1_std']:.4f}"
            )
        elif "kl_divergence_mean" in stats:
            logger.info(
                f"  {task}: KL={stats['kl_divergence_mean']:.4f}±{stats['kl_divergence_std']:.4f}, "
                f"JS={stats['js_divergence_mean']:.4f}±{stats['js_divergence_std']:.4f}"
            )
    # Save summary
    with open(run_dir / "summary.json", "w") as f:
        json.dump(summary, f, indent=2)
    logger.success(f"\nAll results saved to {run_dir}")
if __name__ == "__main__":
app()