mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 09:36:32 +08:00
增加5-fold训练的并行功能
This commit is contained in:
parent
48de4d5d37
commit
13b357ce05
5
Makefile
5
Makefile
@ -15,6 +15,7 @@ EPOCHS_PER_TRIAL_FLAG = $(if $(EPOCHS_PER_TRIAL),--epochs-per-trial $(EPOCHS_PER
|
|||||||
MIN_STRATUM_FLAG = $(if $(MIN_STRATUM_COUNT),--min-stratum-count $(MIN_STRATUM_COUNT),)
|
MIN_STRATUM_FLAG = $(if $(MIN_STRATUM_COUNT),--min-stratum-count $(MIN_STRATUM_COUNT),)
|
||||||
OUTPUT_DIR_FLAG = $(if $(OUTPUT_DIR),--output-dir $(OUTPUT_DIR),)
|
OUTPUT_DIR_FLAG = $(if $(OUTPUT_DIR),--output-dir $(OUTPUT_DIR),)
|
||||||
USE_SWA_FLAG = $(if $(USE_SWA),--use-swa,)
|
USE_SWA_FLAG = $(if $(USE_SWA),--use-swa,)
|
||||||
|
PARALLEL_FLAG = $(if $(PARALLEL),--parallel,)
|
||||||
INIT_PRETRAIN_FLAG = $(if $(NO_PRETRAIN),,--init-from-pretrain $(or $(INIT_PRETRAIN),models/pretrain_delivery.pt))
|
INIT_PRETRAIN_FLAG = $(if $(NO_PRETRAIN),,--init-from-pretrain $(or $(INIT_PRETRAIN),models/pretrain_delivery.pt))
|
||||||
|
|
||||||
#################################################################################
|
#################################################################################
|
||||||
@ -91,10 +92,12 @@ benchmark: requirements
|
|||||||
# INIT_PRETRAIN 预训练权重路径 (默认: models/pretrain_delivery.pt)
|
# INIT_PRETRAIN 预训练权重路径 (默认: models/pretrain_delivery.pt)
|
||||||
# NO_PRETRAIN=1 禁用预训练权重
|
# NO_PRETRAIN=1 禁用预训练权重
|
||||||
# USE_SWA=1 启用 SWA (final train 阶段)
|
# USE_SWA=1 启用 SWA (final train 阶段)
|
||||||
|
# PARALLEL=1 并行运行外层 fold (nested CV 阶段,需足够 GPU 显存)
|
||||||
#
|
#
|
||||||
# 使用示例:
|
# 使用示例:
|
||||||
# make pretrain
|
# make pretrain
|
||||||
# make train DEVICE=cuda N_TRIALS=30 USE_SWA=1 INIT_PRETRAIN=models/pretrain_delivery.pt
|
# make train DEVICE=cuda N_TRIALS=30 USE_SWA=1 INIT_PRETRAIN=models/pretrain_delivery.pt
|
||||||
|
# make train DEVICE=cuda PARALLEL=1
|
||||||
|
|
||||||
|
|
||||||
## Pretrain on external data (delivery only)
|
## Pretrain on external data (delivery only)
|
||||||
@ -109,7 +112,7 @@ pretrain: requirements
|
|||||||
train: requirements
|
train: requirements
|
||||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.nested_cv_optuna \
|
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.nested_cv_optuna \
|
||||||
$(DEVICE_FLAG) $(MPNN_FLAG) $(SEED_FLAG) $(INIT_PRETRAIN_FLAG) \
|
$(DEVICE_FLAG) $(MPNN_FLAG) $(SEED_FLAG) $(INIT_PRETRAIN_FLAG) \
|
||||||
$(N_TRIALS_FLAG) $(EPOCHS_PER_TRIAL_FLAG) $(MIN_STRATUM_FLAG) $(OUTPUT_DIR_FLAG)
|
$(N_TRIALS_FLAG) $(EPOCHS_PER_TRIAL_FLAG) $(MIN_STRATUM_FLAG) $(OUTPUT_DIR_FLAG) $(PARALLEL_FLAG)
|
||||||
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.final_train_optuna_cv \
|
$(PYTHON_INTERPRETER) -m lnp_ml.modeling.final_train_optuna_cv \
|
||||||
$(DEVICE_FLAG) $(MPNN_FLAG) $(SEED_FLAG) $(INIT_PRETRAIN_FLAG) \
|
$(DEVICE_FLAG) $(MPNN_FLAG) $(SEED_FLAG) $(INIT_PRETRAIN_FLAG) \
|
||||||
$(N_TRIALS_FLAG) $(EPOCHS_PER_TRIAL_FLAG) $(MIN_STRATUM_FLAG) $(OUTPUT_DIR_FLAG) $(USE_SWA_FLAG)
|
$(N_TRIALS_FLAG) $(EPOCHS_PER_TRIAL_FLAG) $(MIN_STRATUM_FLAG) $(OUTPUT_DIR_FLAG) $(USE_SWA_FLAG)
|
||||||
|
|||||||
@ -503,6 +503,173 @@ def run_inner_optuna(
|
|||||||
return best_params, epoch_mean, study
|
return best_params, epoch_mean, study
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 单 fold 执行(可跨进程调用) ============
|
||||||
|
|
||||||
|
def _run_single_outer_fold(
|
||||||
|
outer_fold: int,
|
||||||
|
outer_train_idx: np.ndarray,
|
||||||
|
outer_test_idx: np.ndarray,
|
||||||
|
df: pd.DataFrame,
|
||||||
|
strata: np.ndarray,
|
||||||
|
fold_dir: Path,
|
||||||
|
n_trials: int,
|
||||||
|
epochs_per_trial: int,
|
||||||
|
inner_patience: int,
|
||||||
|
batch_size: int,
|
||||||
|
n_inner_folds: int,
|
||||||
|
use_mpnn: bool,
|
||||||
|
seed: int,
|
||||||
|
pretrain_state_dict: Optional[Dict],
|
||||||
|
pretrain_config: Optional[Dict],
|
||||||
|
load_delivery_head: bool,
|
||||||
|
device_str: str,
|
||||||
|
) -> Dict:
|
||||||
|
"""
|
||||||
|
执行单个外层 fold 的完整流程(内层调参 + 外层训练 + 评估)。
|
||||||
|
|
||||||
|
所有参数均为可序列化类型,以支持 spawn 多进程。
|
||||||
|
"""
|
||||||
|
device = torch.device(device_str)
|
||||||
|
fold_dir = Path(fold_dir)
|
||||||
|
fold_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
full_dataset = LNPDataset(df)
|
||||||
|
|
||||||
|
logger.info(f"\n{'='*60}")
|
||||||
|
logger.info(f"OUTER FOLD {outer_fold}")
|
||||||
|
logger.info(f"{'='*60}")
|
||||||
|
logger.info(f"Train: {len(outer_train_idx)}, Test: {len(outer_test_idx)}")
|
||||||
|
|
||||||
|
# 保存 split indices
|
||||||
|
with open(fold_dir / "splits.json", "w") as f:
|
||||||
|
json.dump({
|
||||||
|
"outer_train_idx": outer_train_idx.tolist(),
|
||||||
|
"outer_test_idx": outer_test_idx.tolist(),
|
||||||
|
}, f)
|
||||||
|
|
||||||
|
# 内层 Optuna 调参
|
||||||
|
logger.info(f"\nRunning inner Optuna with {n_trials} trials...")
|
||||||
|
study_path = fold_dir / "optuna_study.sqlite3"
|
||||||
|
|
||||||
|
best_params, epoch_mean, study = run_inner_optuna(
|
||||||
|
full_dataset=full_dataset,
|
||||||
|
inner_train_indices=outer_train_idx,
|
||||||
|
strata=strata,
|
||||||
|
device=device,
|
||||||
|
n_trials=n_trials,
|
||||||
|
epochs_per_trial=epochs_per_trial,
|
||||||
|
patience=inner_patience,
|
||||||
|
batch_size=batch_size,
|
||||||
|
n_inner_folds=n_inner_folds,
|
||||||
|
use_mpnn=use_mpnn,
|
||||||
|
seed=seed + outer_fold,
|
||||||
|
study_path=study_path,
|
||||||
|
pretrain_state_dict=pretrain_state_dict,
|
||||||
|
pretrain_config=pretrain_config,
|
||||||
|
load_delivery_head=load_delivery_head,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存最佳参数
|
||||||
|
with open(fold_dir / "best_params.json", "w") as f:
|
||||||
|
json.dump(best_params, f, indent=2)
|
||||||
|
|
||||||
|
with open(fold_dir / "epoch_mean.json", "w") as f:
|
||||||
|
json.dump({"epoch_mean": epoch_mean}, f)
|
||||||
|
|
||||||
|
# 外层训练(使用最优超参,固定 epoch 数,不 early-stop)
|
||||||
|
logger.info(f"\nTraining outer fold with best params, epochs={epoch_mean}...")
|
||||||
|
|
||||||
|
train_subset = Subset(full_dataset, outer_train_idx.tolist())
|
||||||
|
test_subset = Subset(full_dataset, outer_test_idx.tolist())
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_subset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
|
||||||
|
)
|
||||||
|
test_loader = DataLoader(
|
||||||
|
test_subset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
|
||||||
|
)
|
||||||
|
|
||||||
|
class_weights = compute_class_weights_from_loader(train_loader)
|
||||||
|
|
||||||
|
model = create_model(
|
||||||
|
d_model=best_params["d_model"],
|
||||||
|
num_heads=best_params["num_heads"],
|
||||||
|
n_attn_layers=best_params["n_attn_layers"],
|
||||||
|
fusion_strategy=best_params["fusion_strategy"],
|
||||||
|
head_hidden_dim=best_params["head_hidden_dim"],
|
||||||
|
dropout=best_params["dropout"],
|
||||||
|
use_mpnn=use_mpnn,
|
||||||
|
mpnn_device=device.type,
|
||||||
|
)
|
||||||
|
|
||||||
|
if pretrain_state_dict is not None and pretrain_config is not None:
|
||||||
|
loaded = load_pretrain_weights_to_model(
|
||||||
|
model, pretrain_state_dict, best_params["d_model"],
|
||||||
|
pretrain_config, load_delivery_head
|
||||||
|
)
|
||||||
|
if loaded:
|
||||||
|
logger.info(f"Loaded pretrain weights for outer fold {outer_fold}")
|
||||||
|
|
||||||
|
train_result = train_fixed_epochs(
|
||||||
|
model=model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=None,
|
||||||
|
device=device,
|
||||||
|
lr=best_params["lr"],
|
||||||
|
weight_decay=best_params["weight_decay"],
|
||||||
|
epochs=epoch_mean,
|
||||||
|
class_weights=class_weights,
|
||||||
|
use_cosine_annealing=True,
|
||||||
|
backbone_lr_ratio=best_params.get("backbone_lr_ratio", 1.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
model.load_state_dict(train_result["final_state"])
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"d_model": best_params["d_model"],
|
||||||
|
"num_heads": best_params["num_heads"],
|
||||||
|
"n_attn_layers": best_params["n_attn_layers"],
|
||||||
|
"fusion_strategy": best_params["fusion_strategy"],
|
||||||
|
"head_hidden_dim": best_params["head_hidden_dim"],
|
||||||
|
"dropout": best_params["dropout"],
|
||||||
|
"use_mpnn": use_mpnn,
|
||||||
|
}
|
||||||
|
|
||||||
|
torch.save({
|
||||||
|
"model_state_dict": train_result["final_state"],
|
||||||
|
"config": config,
|
||||||
|
"epoch_mean": epoch_mean,
|
||||||
|
"best_params": best_params,
|
||||||
|
}, fold_dir / "model.pt")
|
||||||
|
|
||||||
|
with open(fold_dir / "history.json", "w") as f:
|
||||||
|
json.dump(train_result["history"], f, indent=2)
|
||||||
|
|
||||||
|
# 在测试集上评估
|
||||||
|
logger.info("Evaluating on outer test set...")
|
||||||
|
test_metrics = evaluate_on_test(model, test_loader, device)
|
||||||
|
|
||||||
|
with open(fold_dir / "test_metrics.json", "w") as f:
|
||||||
|
json.dump(test_metrics, f, indent=2)
|
||||||
|
|
||||||
|
logger.info(f"\nOuter Fold {outer_fold} Test Results:")
|
||||||
|
for task, metrics in test_metrics.items():
|
||||||
|
if "rmse" in metrics:
|
||||||
|
logger.info(f" {task}: RMSE={metrics['rmse']:.4f}, R²={metrics['r2']:.4f}")
|
||||||
|
elif "accuracy" in metrics:
|
||||||
|
logger.info(f" {task}: Acc={metrics['accuracy']:.4f}, F1={metrics['f1']:.4f}")
|
||||||
|
elif "kl_divergence" in metrics:
|
||||||
|
logger.info(f" {task}: KL={metrics['kl_divergence']:.4f}, JS={metrics['js_divergence']:.4f}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"fold": outer_fold,
|
||||||
|
"best_params": best_params,
|
||||||
|
"epoch_mean": epoch_mean,
|
||||||
|
"test_metrics": test_metrics,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# ============ 主流程 ============
|
# ============ 主流程 ============
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
@ -525,6 +692,8 @@ def main(
|
|||||||
load_delivery_head: bool = False,
|
load_delivery_head: bool = False,
|
||||||
# MPNN
|
# MPNN
|
||||||
use_mpnn: bool = False,
|
use_mpnn: bool = False,
|
||||||
|
# 并行
|
||||||
|
parallel: bool = False,
|
||||||
# 设备
|
# 设备
|
||||||
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
||||||
):
|
):
|
||||||
@ -535,6 +704,7 @@ def main(
|
|||||||
外层训练不使用 early-stopping,epoch 数使用内层 best trial 的 epoch_mean。
|
外层训练不使用 early-stopping,epoch 数使用内层 best trial 的 epoch_mean。
|
||||||
|
|
||||||
使用 --init-from-pretrain 从预训练 checkpoint 初始化模型权重。
|
使用 --init-from-pretrain 从预训练 checkpoint 初始化模型权重。
|
||||||
|
使用 --parallel 同时运行所有外层 fold(需要足够 GPU 显存)。
|
||||||
"""
|
"""
|
||||||
if optuna is None:
|
if optuna is None:
|
||||||
logger.error("Optuna not installed. Run: pip install optuna")
|
logger.error("Optuna not installed. Run: pip install optuna")
|
||||||
@ -579,166 +749,55 @@ def main(
|
|||||||
with open(run_dir / "strata_info.json", "w") as f:
|
with open(run_dir / "strata_info.json", "w") as f:
|
||||||
json.dump(strata_info, f, indent=2, default=str)
|
json.dump(strata_info, f, indent=2, default=str)
|
||||||
|
|
||||||
# 创建完整数据集
|
# 创建完整数据集(仅用于获取样本数做 split)
|
||||||
full_dataset = LNPDataset(df)
|
n_samples = len(LNPDataset(df))
|
||||||
n_samples = len(full_dataset)
|
|
||||||
|
|
||||||
# 外层 CV
|
# 外层 CV split
|
||||||
outer_cv = StratifiedKFold(
|
outer_cv = StratifiedKFold(
|
||||||
n_splits=n_outer_folds, shuffle=True, random_state=seed
|
n_splits=n_outer_folds, shuffle=True, random_state=seed
|
||||||
)
|
)
|
||||||
|
|
||||||
outer_results = []
|
device_str = str(device)
|
||||||
|
fold_args = []
|
||||||
for outer_fold, (outer_train_idx, outer_test_idx) in enumerate(
|
for outer_fold, (outer_train_idx, outer_test_idx) in enumerate(
|
||||||
outer_cv.split(np.arange(n_samples), strata)
|
outer_cv.split(np.arange(n_samples), strata)
|
||||||
):
|
):
|
||||||
logger.info(f"\n{'='*60}")
|
fold_args.append(dict(
|
||||||
logger.info(f"OUTER FOLD {outer_fold}")
|
outer_fold=outer_fold,
|
||||||
logger.info(f"{'='*60}")
|
outer_train_idx=outer_train_idx,
|
||||||
logger.info(f"Train: {len(outer_train_idx)}, Test: {len(outer_test_idx)}")
|
outer_test_idx=outer_test_idx,
|
||||||
|
df=df,
|
||||||
fold_dir = run_dir / f"outer_fold_{outer_fold}"
|
|
||||||
fold_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# 保存 split indices
|
|
||||||
splits = {
|
|
||||||
"outer_train_idx": outer_train_idx.tolist(),
|
|
||||||
"outer_test_idx": outer_test_idx.tolist(),
|
|
||||||
}
|
|
||||||
with open(fold_dir / "splits.json", "w") as f:
|
|
||||||
json.dump(splits, f)
|
|
||||||
|
|
||||||
# 内层 Optuna 调参
|
|
||||||
logger.info(f"\nRunning inner Optuna with {n_trials} trials...")
|
|
||||||
study_path = fold_dir / "optuna_study.sqlite3"
|
|
||||||
|
|
||||||
best_params, epoch_mean, study = run_inner_optuna(
|
|
||||||
full_dataset=full_dataset,
|
|
||||||
inner_train_indices=outer_train_idx,
|
|
||||||
strata=strata,
|
strata=strata,
|
||||||
device=device,
|
fold_dir=run_dir / f"outer_fold_{outer_fold}",
|
||||||
n_trials=n_trials,
|
n_trials=n_trials,
|
||||||
epochs_per_trial=epochs_per_trial,
|
epochs_per_trial=epochs_per_trial,
|
||||||
patience=inner_patience,
|
inner_patience=inner_patience,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
n_inner_folds=n_inner_folds,
|
n_inner_folds=n_inner_folds,
|
||||||
use_mpnn=use_mpnn,
|
use_mpnn=use_mpnn,
|
||||||
seed=seed + outer_fold,
|
seed=seed,
|
||||||
study_path=study_path,
|
|
||||||
pretrain_state_dict=pretrain_state_dict,
|
pretrain_state_dict=pretrain_state_dict,
|
||||||
pretrain_config=pretrain_config,
|
pretrain_config=pretrain_config,
|
||||||
load_delivery_head=load_delivery_head,
|
load_delivery_head=load_delivery_head,
|
||||||
)
|
device_str=device_str,
|
||||||
|
))
|
||||||
|
|
||||||
|
if parallel:
|
||||||
|
import multiprocessing as mp
|
||||||
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
|
|
||||||
# 保存最佳参数
|
ctx = mp.get_context("spawn")
|
||||||
with open(fold_dir / "best_params.json", "w") as f:
|
logger.info(f"Running {n_outer_folds} outer folds in PARALLEL (spawn)")
|
||||||
json.dump(best_params, f, indent=2)
|
with ProcessPoolExecutor(max_workers=n_outer_folds, mp_context=ctx) as executor:
|
||||||
|
futures = [executor.submit(_run_single_outer_fold, **args) for args in fold_args]
|
||||||
with open(fold_dir / "epoch_mean.json", "w") as f:
|
outer_results = [f.result() for f in futures]
|
||||||
json.dump({"epoch_mean": epoch_mean}, f)
|
outer_results.sort(key=lambda r: r["fold"])
|
||||||
|
else:
|
||||||
# 外层训练(使用最优超参,固定 epoch 数,不 early-stop)
|
logger.info(f"Running {n_outer_folds} outer folds SEQUENTIALLY")
|
||||||
logger.info(f"\nTraining outer fold with best params, epochs={epoch_mean}...")
|
outer_results = []
|
||||||
|
for args in fold_args:
|
||||||
# 创建 DataLoader
|
result = _run_single_outer_fold(**args)
|
||||||
train_subset = Subset(full_dataset, outer_train_idx.tolist())
|
outer_results.append(result)
|
||||||
test_subset = Subset(full_dataset, outer_test_idx.tolist())
|
|
||||||
|
|
||||||
train_loader = DataLoader(
|
|
||||||
train_subset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
|
|
||||||
)
|
|
||||||
test_loader = DataLoader(
|
|
||||||
test_subset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
|
|
||||||
)
|
|
||||||
|
|
||||||
# 计算类权重
|
|
||||||
class_weights = compute_class_weights_from_loader(train_loader)
|
|
||||||
|
|
||||||
# 创建模型
|
|
||||||
model = create_model(
|
|
||||||
d_model=best_params["d_model"],
|
|
||||||
num_heads=best_params["num_heads"],
|
|
||||||
n_attn_layers=best_params["n_attn_layers"],
|
|
||||||
fusion_strategy=best_params["fusion_strategy"],
|
|
||||||
head_hidden_dim=best_params["head_hidden_dim"],
|
|
||||||
dropout=best_params["dropout"],
|
|
||||||
use_mpnn=use_mpnn,
|
|
||||||
mpnn_device=device.type,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 加载预训练权重
|
|
||||||
if pretrain_state_dict is not None and pretrain_config is not None:
|
|
||||||
loaded = load_pretrain_weights_to_model(
|
|
||||||
model, pretrain_state_dict, best_params["d_model"],
|
|
||||||
pretrain_config, load_delivery_head
|
|
||||||
)
|
|
||||||
if loaded:
|
|
||||||
logger.info(f"Loaded pretrain weights for outer fold {outer_fold}")
|
|
||||||
|
|
||||||
# 训练(固定 epoch,不 early-stop)
|
|
||||||
train_result = train_fixed_epochs(
|
|
||||||
model=model,
|
|
||||||
train_loader=train_loader,
|
|
||||||
val_loader=None, # 外层不用验证集
|
|
||||||
device=device,
|
|
||||||
lr=best_params["lr"],
|
|
||||||
weight_decay=best_params["weight_decay"],
|
|
||||||
epochs=epoch_mean,
|
|
||||||
class_weights=class_weights,
|
|
||||||
use_cosine_annealing=True,
|
|
||||||
backbone_lr_ratio=best_params.get("backbone_lr_ratio", 1.0),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 加载最终权重
|
|
||||||
model.load_state_dict(train_result["final_state"])
|
|
||||||
model = model.to(device)
|
|
||||||
|
|
||||||
# 保存模型
|
|
||||||
config = {
|
|
||||||
"d_model": best_params["d_model"],
|
|
||||||
"num_heads": best_params["num_heads"],
|
|
||||||
"n_attn_layers": best_params["n_attn_layers"],
|
|
||||||
"fusion_strategy": best_params["fusion_strategy"],
|
|
||||||
"head_hidden_dim": best_params["head_hidden_dim"],
|
|
||||||
"dropout": best_params["dropout"],
|
|
||||||
"use_mpnn": use_mpnn,
|
|
||||||
}
|
|
||||||
|
|
||||||
torch.save({
|
|
||||||
"model_state_dict": train_result["final_state"],
|
|
||||||
"config": config,
|
|
||||||
"epoch_mean": epoch_mean,
|
|
||||||
"best_params": best_params,
|
|
||||||
}, fold_dir / "model.pt")
|
|
||||||
|
|
||||||
# 保存训练历史
|
|
||||||
with open(fold_dir / "history.json", "w") as f:
|
|
||||||
json.dump(train_result["history"], f, indent=2)
|
|
||||||
|
|
||||||
# 在测试集上评估
|
|
||||||
logger.info("Evaluating on outer test set...")
|
|
||||||
test_metrics = evaluate_on_test(model, test_loader, device)
|
|
||||||
|
|
||||||
with open(fold_dir / "test_metrics.json", "w") as f:
|
|
||||||
json.dump(test_metrics, f, indent=2)
|
|
||||||
|
|
||||||
# 打印测试结果
|
|
||||||
logger.info(f"\nOuter Fold {outer_fold} Test Results:")
|
|
||||||
for task, metrics in test_metrics.items():
|
|
||||||
if "rmse" in metrics:
|
|
||||||
logger.info(f" {task}: RMSE={metrics['rmse']:.4f}, R²={metrics['r2']:.4f}")
|
|
||||||
elif "accuracy" in metrics:
|
|
||||||
logger.info(f" {task}: Acc={metrics['accuracy']:.4f}, F1={metrics['f1']:.4f}")
|
|
||||||
elif "kl_divergence" in metrics:
|
|
||||||
logger.info(f" {task}: KL={metrics['kl_divergence']:.4f}, JS={metrics['js_divergence']:.4f}")
|
|
||||||
|
|
||||||
outer_results.append({
|
|
||||||
"fold": outer_fold,
|
|
||||||
"best_params": best_params,
|
|
||||||
"epoch_mean": epoch_mean,
|
|
||||||
"test_metrics": test_metrics,
|
|
||||||
})
|
|
||||||
|
|
||||||
# 汇总结果
|
# 汇总结果
|
||||||
logger.info("\n" + "=" * 60)
|
logger.info("\n" + "=" * 60)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user