RYDE-WORK 2026-02-28 18:29:07 +08:00
parent 74fd012f13
commit ead579b25c


@@ -1,7 +1,6 @@
 """Benchmark script: evaluate the model on the CV splits published by the baseline paper (delivery task only)"""
 import json
-import math
 from pathlib import Path
 from typing import Dict, List, Optional
@@ -10,7 +9,6 @@ import pandas as pd
 import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
-from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, SequentialLR
 from loguru import logger
 from tqdm import tqdm
 from sklearn.metrics import mean_squared_error, r2_score
@@ -160,7 +158,6 @@ def train_fold(
     weight_decay: float = 1e-5,
     epochs: int = 50,
     patience: int = 10,
-    warmup_epochs: int = 3,
     config: Optional[Dict] = None,
 ) -> Dict:
     """Train a single fold"""
@@ -169,19 +166,9 @@
     logger.info(f"{'='*60}")
     optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
-    warmup_scheduler = LambdaLR(
-        optimizer, lr_lambda=lambda epoch: (epoch + 1) / warmup_epochs
-    )
-    cosine_scheduler = CosineAnnealingLR(
-        optimizer, T_max=epochs - warmup_epochs
-    )
-    scheduler = SequentialLR(
-        optimizer,
-        schedulers=[warmup_scheduler, cosine_scheduler],
-        milestones=[warmup_epochs],
-    )
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer, mode="min", factor=0.5, patience=5
+    )
     early_stopping = EarlyStopping(patience=patience)
     best_val_loss = float("inf")
@@ -211,7 +198,7 @@
             "lr": current_lr,
         })
-        scheduler.step()
+        scheduler.step(val_metrics["loss"])
         if val_metrics["loss"] < best_val_loss: