diff --git a/lnp_ml/modeling/benchmark.py b/lnp_ml/modeling/benchmark.py index 3ad7ac9..6ac6b20 100644 --- a/lnp_ml/modeling/benchmark.py +++ b/lnp_ml/modeling/benchmark.py @@ -284,11 +284,11 @@ def main( data_dir: Path = PROCESSED_DATA_DIR / "benchmark", output_dir: Path = MODELS_DIR / "benchmark", # 模型参数 - d_model: int = 128, - num_heads: int = 4, - n_attn_layers: int = 2, + d_model: int = 256, + num_heads: int = 8, + n_attn_layers: int = 4, fusion_strategy: str = "attention", - head_hidden_dim: int = 64, + head_hidden_dim: int = 128, dropout: float = 0.1, # MPNN 参数 use_mpnn: bool = False,