mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 09:36:32 +08:00
对backbone与预测头采用分层学习率,缓解预训练任务数与任务定义都不一致的问题
This commit is contained in:
parent
73921bb353
commit
c52b82786d
@ -284,6 +284,7 @@ def run_optuna_cv(
|
|||||||
dropout = trial.suggest_float("dropout", 0.1, 0.5)
|
dropout = trial.suggest_float("dropout", 0.1, 0.5)
|
||||||
lr = trial.suggest_float("lr", 1e-5, 3e-4, log=True)
|
lr = trial.suggest_float("lr", 1e-5, 3e-4, log=True)
|
||||||
weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
|
weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
|
||||||
|
backbone_lr_ratio = trial.suggest_float("backbone_lr_ratio", 0.01, 1.0, log=True)
|
||||||
|
|
||||||
# 3-fold CV
|
# 3-fold CV
|
||||||
cv = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
|
cv = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
|
||||||
@ -335,6 +336,7 @@ def run_optuna_cv(
|
|||||||
epochs=epochs_per_trial,
|
epochs=epochs_per_trial,
|
||||||
patience=patience,
|
patience=patience,
|
||||||
class_weights=class_weights,
|
class_weights=class_weights,
|
||||||
|
backbone_lr_ratio=backbone_lr_ratio,
|
||||||
)
|
)
|
||||||
|
|
||||||
fold_val_losses.append(result["best_val_loss"])
|
fold_val_losses.append(result["best_val_loss"])
|
||||||
@ -569,6 +571,7 @@ def main(
|
|||||||
use_cosine_annealing=True,
|
use_cosine_annealing=True,
|
||||||
use_swa=use_swa,
|
use_swa=use_swa,
|
||||||
swa_start_epoch=swa_start,
|
swa_start_epoch=swa_start,
|
||||||
|
backbone_lr_ratio=best_params.get("backbone_lr_ratio", 1.0),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 加载最终权重
|
# 加载最终权重
|
||||||
|
|||||||
@ -397,6 +397,7 @@ def run_inner_optuna(
|
|||||||
dropout = trial.suggest_float("dropout", 0.1, 0.5)
|
dropout = trial.suggest_float("dropout", 0.1, 0.5)
|
||||||
lr = trial.suggest_float("lr", 1e-5, 3e-4, log=True)
|
lr = trial.suggest_float("lr", 1e-5, 3e-4, log=True)
|
||||||
weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
|
weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
|
||||||
|
backbone_lr_ratio = trial.suggest_float("backbone_lr_ratio", 0.01, 1.0, log=True)
|
||||||
|
|
||||||
# 内层 3-fold CV
|
# 内层 3-fold CV
|
||||||
inner_cv = StratifiedKFold(
|
inner_cv = StratifiedKFold(
|
||||||
@ -456,6 +457,7 @@ def run_inner_optuna(
|
|||||||
epochs=epochs_per_trial,
|
epochs=epochs_per_trial,
|
||||||
patience=patience,
|
patience=patience,
|
||||||
class_weights=class_weights,
|
class_weights=class_weights,
|
||||||
|
backbone_lr_ratio=backbone_lr_ratio,
|
||||||
)
|
)
|
||||||
|
|
||||||
fold_val_losses.append(result["best_val_loss"])
|
fold_val_losses.append(result["best_val_loss"])
|
||||||
@ -685,6 +687,7 @@ def main(
|
|||||||
epochs=epoch_mean,
|
epochs=epoch_mean,
|
||||||
class_weights=class_weights,
|
class_weights=class_weights,
|
||||||
use_cosine_annealing=True,
|
use_cosine_annealing=True,
|
||||||
|
backbone_lr_ratio=best_params.get("backbone_lr_ratio", 1.0),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 加载最终权重
|
# 加载最终权重
|
||||||
|
|||||||
@ -286,6 +286,43 @@ def validate_balanced(
|
|||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
# Parameter-name prefixes that identify the (pretrained) backbone portion of the
# model; every other parameter is treated as a task head and trained at the full lr.
BACKBONE_PREFIXES = ("token_projector.", "cross_attention.", "fusion.")


def build_optimizer(
    model: nn.Module,
    lr: float,
    weight_decay: float,
    backbone_lr_ratio: float = 1.0,
) -> torch.optim.AdamW:
    """Build an AdamW optimizer with optional layered (discriminative) learning rates.

    When ``backbone_lr_ratio < 1.0``, backbone parameters (those whose names start
    with one of ``BACKBONE_PREFIXES``) use ``lr * backbone_lr_ratio`` while the
    remaining parameters (task heads, etc.) use ``lr``. ``backbone_lr_ratio = 1.0``
    is equivalent to a single uniform learning rate.

    Args:
        model: Model whose trainable parameters are to be optimized.
        lr: Base learning rate, applied to head parameters.
        weight_decay: AdamW weight decay, applied to all parameter groups.
        backbone_lr_ratio: Backbone lr as a fraction of ``lr``; ``1.0`` disables
            the split (single parameter group, identical to the previous behavior).

    Returns:
        A configured ``torch.optim.AdamW`` instance.
    """
    # Ratio >= 1.0 means "no split": keep one parameter group so optimizer and
    # scheduler behavior is byte-identical to the pre-layered-lr setup.
    if backbone_lr_ratio >= 1.0:
        return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    backbone_params = []
    head_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters get no optimizer state at all
        # str.startswith accepts a tuple of prefixes and matches any of them.
        if name.startswith(BACKBONE_PREFIXES):
            backbone_params.append(param)
        else:
            head_params.append(param)

    if not backbone_params:
        # The prefixes matched nothing (e.g. the architecture's module names
        # changed): the layered-lr request would silently become a no-op, so
        # surface the misconfiguration instead of hiding it.
        import warnings

        warnings.warn(
            "build_optimizer: backbone_lr_ratio < 1.0 but no parameters matched "
            f"BACKBONE_PREFIXES={BACKBONE_PREFIXES}; all parameters use lr={lr}.",
            stacklevel=2,
        )

    return torch.optim.AdamW(
        [
            {"params": backbone_params, "lr": lr * backbone_lr_ratio},
            {"params": head_params, "lr": lr},
        ],
        weight_decay=weight_decay,
    )
||||||
|
|
||||||
|
|
||||||
class EarlyStoppingBalanced:
|
class EarlyStoppingBalanced:
|
||||||
"""早停机制(与 trainer.py 的 EarlyStopping 兼容)"""
|
"""早停机制(与 trainer.py 的 EarlyStopping 兼容)"""
|
||||||
|
|
||||||
@ -324,6 +361,7 @@ def train_with_early_stopping(
|
|||||||
patience: int = 15,
|
patience: int = 15,
|
||||||
task_weights: Optional[LossWeightsBalanced] = None,
|
task_weights: Optional[LossWeightsBalanced] = None,
|
||||||
class_weights: Optional[ClassWeights] = None,
|
class_weights: Optional[ClassWeights] = None,
|
||||||
|
backbone_lr_ratio: float = 1.0,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""
|
"""
|
||||||
带早停的完整训练流程。
|
带早停的完整训练流程。
|
||||||
@ -332,7 +370,7 @@ def train_with_early_stopping(
|
|||||||
Dict with keys: history, best_val_loss, best_epoch, best_state
|
Dict with keys: history, best_val_loss, best_epoch, best_state
|
||||||
"""
|
"""
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
|
optimizer = build_optimizer(model, lr, weight_decay, backbone_lr_ratio)
|
||||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||||
optimizer, mode="min", factor=0.5, patience=5
|
optimizer, mode="min", factor=0.5, patience=5
|
||||||
)
|
)
|
||||||
@ -394,6 +432,7 @@ def train_fixed_epochs(
|
|||||||
use_cosine_annealing: bool = True,
|
use_cosine_annealing: bool = True,
|
||||||
use_swa: bool = False,
|
use_swa: bool = False,
|
||||||
swa_start_epoch: Optional[int] = None,
|
swa_start_epoch: Optional[int] = None,
|
||||||
|
backbone_lr_ratio: float = 1.0,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""
|
"""
|
||||||
固定 epoch 数的训练(不使用 early stopping)。
|
固定 epoch 数的训练(不使用 early stopping)。
|
||||||
@ -413,12 +452,13 @@ def train_fixed_epochs(
|
|||||||
use_cosine_annealing: 是否使用 CosineAnnealingLR
|
use_cosine_annealing: 是否使用 CosineAnnealingLR
|
||||||
use_swa: 是否使用 SWA
|
use_swa: 是否使用 SWA
|
||||||
swa_start_epoch: SWA 开始的 epoch(默认为 epochs * 0.75)
|
swa_start_epoch: SWA 开始的 epoch(默认为 epochs * 0.75)
|
||||||
|
backbone_lr_ratio: backbone 学习率相对于 head 的比例(1.0 = 统一学习率)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with keys: history, final_state
|
Dict with keys: history, final_state
|
||||||
"""
|
"""
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
|
optimizer = build_optimizer(model, lr, weight_decay, backbone_lr_ratio)
|
||||||
|
|
||||||
if use_cosine_annealing:
|
if use_cosine_annealing:
|
||||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
|
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user