mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 09:36:32 +08:00
移除微调时对于预训练已固定参数的搜索,增加超参搜索时正则化搜索范围上限(加强正则化)
This commit is contained in:
parent
3a45c0641c
commit
73921bb353
@ -260,18 +260,30 @@ def run_optuna_cv(
|
|||||||
n_samples = len(full_dataset)
|
n_samples = len(full_dataset)
|
||||||
indices = np.arange(n_samples)
|
indices = np.arange(n_samples)
|
||||||
|
|
||||||
def objective(trial: optuna.Trial) -> float:
|
# 固定架构参数(与预训练一致,确保权重完整加载)
|
||||||
# 采样超参数
|
_cfg = pretrain_config or {}
|
||||||
d_model = trial.suggest_categorical("d_model", [128, 256, 512])
|
fixed_d_model = _cfg.get("d_model", 256)
|
||||||
num_heads = trial.suggest_categorical("num_heads", [4, 8])
|
fixed_num_heads = _cfg.get("num_heads", 8)
|
||||||
n_attn_layers = trial.suggest_int("n_attn_layers", 2, 6)
|
fixed_n_attn_layers = _cfg.get("n_attn_layers", 4)
|
||||||
fusion_strategy = trial.suggest_categorical(
|
fixed_fusion_strategy = _cfg.get("fusion_strategy", "attention")
|
||||||
"fusion_strategy", ["attention", "avg", "max"]
|
fixed_head_hidden_dim = _cfg.get("head_hidden_dim", 128)
|
||||||
|
logger.info(
|
||||||
|
f"Fixed architecture params: d_model={fixed_d_model}, num_heads={fixed_num_heads}, "
|
||||||
|
f"n_attn_layers={fixed_n_attn_layers}, fusion={fixed_fusion_strategy}, "
|
||||||
|
f"head_hidden_dim={fixed_head_hidden_dim}"
|
||||||
)
|
)
|
||||||
head_hidden_dim = trial.suggest_categorical("head_hidden_dim", [64, 128, 256])
|
|
||||||
dropout = trial.suggest_float("dropout", 0.05, 0.3)
|
def objective(trial: optuna.Trial) -> float:
|
||||||
lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
|
d_model = fixed_d_model
|
||||||
weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True)
|
num_heads = fixed_num_heads
|
||||||
|
n_attn_layers = fixed_n_attn_layers
|
||||||
|
fusion_strategy = fixed_fusion_strategy
|
||||||
|
head_hidden_dim = fixed_head_hidden_dim
|
||||||
|
|
||||||
|
# 搜索训练超参数
|
||||||
|
dropout = trial.suggest_float("dropout", 0.1, 0.5)
|
||||||
|
lr = trial.suggest_float("lr", 1e-5, 3e-4, log=True)
|
||||||
|
weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
|
||||||
|
|
||||||
# 3-fold CV
|
# 3-fold CV
|
||||||
cv = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
|
cv = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
|
||||||
@ -351,7 +363,14 @@ def run_optuna_cv(
|
|||||||
|
|
||||||
study.optimize(objective, n_trials=n_trials, show_progress_bar=True)
|
study.optimize(objective, n_trials=n_trials, show_progress_bar=True)
|
||||||
|
|
||||||
best_params = study.best_trial.params
|
best_params = dict(study.best_trial.params)
|
||||||
|
best_params.update({
|
||||||
|
"d_model": fixed_d_model,
|
||||||
|
"num_heads": fixed_num_heads,
|
||||||
|
"n_attn_layers": fixed_n_attn_layers,
|
||||||
|
"fusion_strategy": fixed_fusion_strategy,
|
||||||
|
"head_hidden_dim": fixed_head_hidden_dim,
|
||||||
|
})
|
||||||
epoch_mean = study.best_trial.user_attrs.get("epoch_mean", epochs_per_trial)
|
epoch_mean = study.best_trial.user_attrs.get("epoch_mean", epochs_per_trial)
|
||||||
|
|
||||||
logger.info(f"Best trial: {study.best_trial.number}")
|
logger.info(f"Best trial: {study.best_trial.number}")
|
||||||
|
|||||||
@ -373,18 +373,30 @@ def run_inner_optuna(
|
|||||||
|
|
||||||
inner_strata = strata[inner_train_indices]
|
inner_strata = strata[inner_train_indices]
|
||||||
|
|
||||||
def objective(trial: optuna.Trial) -> float:
|
# 固定架构参数(与预训练一致,确保权重完整加载)
|
||||||
# 采样超参数
|
_cfg = pretrain_config or {}
|
||||||
d_model = trial.suggest_categorical("d_model", [128, 256, 512])
|
fixed_d_model = _cfg.get("d_model", 256)
|
||||||
num_heads = trial.suggest_categorical("num_heads", [4, 8])
|
fixed_num_heads = _cfg.get("num_heads", 8)
|
||||||
n_attn_layers = trial.suggest_int("n_attn_layers", 2, 6)
|
fixed_n_attn_layers = _cfg.get("n_attn_layers", 4)
|
||||||
fusion_strategy = trial.suggest_categorical(
|
fixed_fusion_strategy = _cfg.get("fusion_strategy", "attention")
|
||||||
"fusion_strategy", ["attention", "avg", "max"]
|
fixed_head_hidden_dim = _cfg.get("head_hidden_dim", 128)
|
||||||
|
logger.info(
|
||||||
|
f"Fixed architecture params: d_model={fixed_d_model}, num_heads={fixed_num_heads}, "
|
||||||
|
f"n_attn_layers={fixed_n_attn_layers}, fusion={fixed_fusion_strategy}, "
|
||||||
|
f"head_hidden_dim={fixed_head_hidden_dim}"
|
||||||
)
|
)
|
||||||
head_hidden_dim = trial.suggest_categorical("head_hidden_dim", [64, 128, 256])
|
|
||||||
dropout = trial.suggest_float("dropout", 0.05, 0.3)
|
def objective(trial: optuna.Trial) -> float:
|
||||||
lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
|
d_model = fixed_d_model
|
||||||
weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True)
|
num_heads = fixed_num_heads
|
||||||
|
n_attn_layers = fixed_n_attn_layers
|
||||||
|
fusion_strategy = fixed_fusion_strategy
|
||||||
|
head_hidden_dim = fixed_head_hidden_dim
|
||||||
|
|
||||||
|
# 搜索训练超参数
|
||||||
|
dropout = trial.suggest_float("dropout", 0.1, 0.5)
|
||||||
|
lr = trial.suggest_float("lr", 1e-5, 3e-4, log=True)
|
||||||
|
weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
|
||||||
|
|
||||||
# 内层 3-fold CV
|
# 内层 3-fold CV
|
||||||
inner_cv = StratifiedKFold(
|
inner_cv = StratifiedKFold(
|
||||||
@ -471,7 +483,14 @@ def run_inner_optuna(
|
|||||||
|
|
||||||
study.optimize(objective, n_trials=n_trials, show_progress_bar=True)
|
study.optimize(objective, n_trials=n_trials, show_progress_bar=True)
|
||||||
|
|
||||||
best_params = study.best_trial.params
|
best_params = dict(study.best_trial.params)
|
||||||
|
best_params.update({
|
||||||
|
"d_model": fixed_d_model,
|
||||||
|
"num_heads": fixed_num_heads,
|
||||||
|
"n_attn_layers": fixed_n_attn_layers,
|
||||||
|
"fusion_strategy": fixed_fusion_strategy,
|
||||||
|
"head_hidden_dim": fixed_head_hidden_dim,
|
||||||
|
})
|
||||||
epoch_mean = study.best_trial.user_attrs.get("epoch_mean", epochs_per_trial)
|
epoch_mean = study.best_trial.user_attrs.get("epoch_mean", epochs_per_trial)
|
||||||
|
|
||||||
logger.info(f"Best trial: {study.best_trial.number}")
|
logger.info(f"Best trial: {study.best_trial.number}")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user