mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 09:36:32 +08:00
Fallback
This commit is contained in:
parent
74fd012f13
commit
ead579b25c
@ -1,7 +1,6 @@
|
||||
"""Benchmark 脚本:在 baseline 论文公开的 CV 划分上评估模型(仅 delivery 任务)"""
|
||||
|
||||
import json
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
@ -10,7 +9,6 @@ import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, SequentialLR
|
||||
from loguru import logger
|
||||
from tqdm import tqdm
|
||||
from sklearn.metrics import mean_squared_error, r2_score
|
||||
@ -160,7 +158,6 @@ def train_fold(
|
||||
weight_decay: float = 1e-5,
|
||||
epochs: int = 50,
|
||||
patience: int = 10,
|
||||
warmup_epochs: int = 3,
|
||||
config: Optional[Dict] = None,
|
||||
) -> Dict:
|
||||
"""训练单个 fold"""
|
||||
@ -169,19 +166,9 @@ def train_fold(
|
||||
logger.info(f"{'='*60}")
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
|
||||
|
||||
warmup_scheduler = LambdaLR(
|
||||
optimizer, lr_lambda=lambda epoch: (epoch + 1) / warmup_epochs
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
optimizer, mode="min", factor=0.5, patience=5
|
||||
)
|
||||
cosine_scheduler = CosineAnnealingLR(
|
||||
optimizer, T_max=epochs - warmup_epochs
|
||||
)
|
||||
scheduler = SequentialLR(
|
||||
optimizer,
|
||||
schedulers=[warmup_scheduler, cosine_scheduler],
|
||||
milestones=[warmup_epochs],
|
||||
)
|
||||
|
||||
early_stopping = EarlyStopping(patience=patience)
|
||||
|
||||
best_val_loss = float("inf")
|
||||
@ -211,7 +198,7 @@ def train_fold(
|
||||
"lr": current_lr,
|
||||
})
|
||||
|
||||
scheduler.step()
|
||||
scheduler.step(val_metrics["loss"])
|
||||
|
||||
if val_metrics["loss"] < best_val_loss:
|
||||
best_val_loss = val_metrics["loss"]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user