diff --git a/lnp_ml/modeling/final_train_optuna_cv.py b/lnp_ml/modeling/final_train_optuna_cv.py
index 6b5de89..a439b82 100644
--- a/lnp_ml/modeling/final_train_optuna_cv.py
+++ b/lnp_ml/modeling/final_train_optuna_cv.py
@@ -284,6 +284,7 @@ def run_optuna_cv(
         dropout = trial.suggest_float("dropout", 0.1, 0.5)
         lr = trial.suggest_float("lr", 1e-5, 3e-4, log=True)
         weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
+        backbone_lr_ratio = trial.suggest_float("backbone_lr_ratio", 0.01, 1.0, log=True)

         # 3-fold CV
         cv = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
@@ -335,6 +336,7 @@
                 epochs=epochs_per_trial,
                 patience=patience,
                 class_weights=class_weights,
+                backbone_lr_ratio=backbone_lr_ratio,
             )
             fold_val_losses.append(result["best_val_loss"])

@@ -569,6 +571,7 @@ def main(
         use_cosine_annealing=True,
         use_swa=use_swa,
         swa_start_epoch=swa_start,
+        backbone_lr_ratio=best_params.get("backbone_lr_ratio", 1.0),
     )

     # Load final weights
diff --git a/lnp_ml/modeling/nested_cv_optuna.py b/lnp_ml/modeling/nested_cv_optuna.py
index fa0f38b..dc78aef 100644
--- a/lnp_ml/modeling/nested_cv_optuna.py
+++ b/lnp_ml/modeling/nested_cv_optuna.py
@@ -397,6 +397,7 @@ def run_inner_optuna(
         dropout = trial.suggest_float("dropout", 0.1, 0.5)
         lr = trial.suggest_float("lr", 1e-5, 3e-4, log=True)
         weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
+        backbone_lr_ratio = trial.suggest_float("backbone_lr_ratio", 0.01, 1.0, log=True)

         # Inner 3-fold CV
         inner_cv = StratifiedKFold(
@@ -456,6 +457,7 @@
                 epochs=epochs_per_trial,
                 patience=patience,
                 class_weights=class_weights,
+                backbone_lr_ratio=backbone_lr_ratio,
             )
             fold_val_losses.append(result["best_val_loss"])

@@ -685,6 +687,7 @@ def main(
         epochs=epoch_mean,
         class_weights=class_weights,
         use_cosine_annealing=True,
+        backbone_lr_ratio=best_params.get("backbone_lr_ratio", 1.0),
     )

     # Load final weights
diff --git a/lnp_ml/modeling/trainer_balanced.py b/lnp_ml/modeling/trainer_balanced.py
index bb2b803..e67c577 100644
--- a/lnp_ml/modeling/trainer_balanced.py
+++ b/lnp_ml/modeling/trainer_balanced.py
@@ -286,6 +286,43 @@ def validate_balanced(
     return metrics


+BACKBONE_PREFIXES = ("token_projector.", "cross_attention.", "fusion.")
+
+
+def build_optimizer(
+    model: nn.Module,
+    lr: float,
+    weight_decay: float,
+    backbone_lr_ratio: float = 1.0,
+) -> torch.optim.AdamW:
+    """
+    Build an AdamW optimizer with support for layer-wise learning rates.
+
+    When backbone_lr_ratio < 1.0, backbone parameters use lr * backbone_lr_ratio
+    and the remaining parameters (task heads, etc.) use lr.
+    backbone_lr_ratio = 1.0 is equivalent to a uniform learning rate.
+    """
+    if backbone_lr_ratio >= 1.0:
+        return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+
+    backbone_params = []
+    head_params = []
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue
+        if name.startswith(BACKBONE_PREFIXES):
+            backbone_params.append(param)
+        else:
+            head_params.append(param)
+
+    return torch.optim.AdamW(
+        [
+            {"params": backbone_params, "lr": lr * backbone_lr_ratio},
+            {"params": head_params, "lr": lr},
+        ],
+        weight_decay=weight_decay,
+    )
+
+
 class EarlyStoppingBalanced:
     """Early stopping (compatible with EarlyStopping in trainer.py)."""

@@ -324,6 +361,7 @@ def train_with_early_stopping(
     patience: int = 15,
     task_weights: Optional[LossWeightsBalanced] = None,
     class_weights: Optional[ClassWeights] = None,
+    backbone_lr_ratio: float = 1.0,
 ) -> Dict:
     """
     Full training loop with early stopping.
@@ -332,7 +370,7 @@
         Dict with keys: history, best_val_loss, best_epoch, best_state
     """
     model = model.to(device)
-    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+    optimizer = build_optimizer(model, lr, weight_decay, backbone_lr_ratio)
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
         optimizer, mode="min", factor=0.5, patience=5
     )
@@ -394,6 +432,7 @@ def train_fixed_epochs(
     use_cosine_annealing: bool = True,
     use_swa: bool = False,
     swa_start_epoch: Optional[int] = None,
+    backbone_lr_ratio: float = 1.0,
 ) -> Dict:
     """
     Training for a fixed number of epochs (no early stopping).
@@ -413,12 +452,13 @@
         use_cosine_annealing: whether to use CosineAnnealingLR
         use_swa: whether to use SWA
         swa_start_epoch: epoch at which SWA starts (defaults to epochs * 0.75)
+        backbone_lr_ratio: backbone learning rate as a ratio of the head learning rate (1.0 = uniform)

     Returns:
         Dict with keys: history, final_state
     """
     model = model.to(device)
-    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+    optimizer = build_optimizer(model, lr, weight_decay, backbone_lr_ratio)

     if use_cosine_annealing:
         scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
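
Usage note (not part of the patch): a minimal sketch of the two AdamW parameter groups build_optimizer creates when backbone_lr_ratio < 1.0. ToyModel and the hyperparameter values below are hypothetical stand-ins for illustration; in the real model the backbone submodules are token_projector, cross_attention, and fusion, per BACKBONE_PREFIXES.

import torch.nn as nn

from lnp_ml.modeling.trainer_balanced import build_optimizer


class ToyModel(nn.Module):
    # Hypothetical model: only token_projector matches BACKBONE_PREFIXES.
    def __init__(self):
        super().__init__()
        self.token_projector = nn.Linear(16, 8)  # backbone group (lr * ratio)
        self.task_head = nn.Linear(8, 2)         # everything else: head group (lr)


model = ToyModel()
opt = build_optimizer(model, lr=3e-4, weight_decay=1e-4, backbone_lr_ratio=0.1)
for group in opt.param_groups:
    n = sum(p.numel() for p in group["params"])
    print(f"lr={group['lr']:.0e}  params={n}")
# Expected output:
# lr=3e-05  params=136   (token_projector: 16*8 weights + 8 biases)
# lr=3e-04  params=18    (task_head: 8*2 weights + 2 biases)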