缩小d_model/num_heads/n_attn_layers/hidden_dim或者增大正则化后,欠拟合,恢复默认参数

This commit is contained in:
RYDE-WORK 2026-02-28 17:32:49 +08:00
parent e3f9d9e9db
commit a7db8ffc15

View File

@ -284,11 +284,11 @@ def main(
data_dir: Path = PROCESSED_DATA_DIR / "benchmark", data_dir: Path = PROCESSED_DATA_DIR / "benchmark",
output_dir: Path = MODELS_DIR / "benchmark", output_dir: Path = MODELS_DIR / "benchmark",
# 模型参数 # 模型参数
d_model: int = 128, d_model: int = 256,
num_heads: int = 4, num_heads: int = 8,
n_attn_layers: int = 2, n_attn_layers: int = 4,
fusion_strategy: str = "attention", fusion_strategy: str = "attention",
head_hidden_dim: int = 64, head_hidden_dim: int = 128,
dropout: float = 0.1, dropout: float = 0.1,
# MPNN 参数 # MPNN 参数
use_mpnn: bool = False, use_mpnn: bool = False,