mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 09:36:32 +08:00
缩小d_model/num_heads/n_attn_layers/hidden_dim或者增大正则化后,欠拟合,恢复默认参数
This commit is contained in:
parent
e3f9d9e9db
commit
a7db8ffc15
@ -284,11 +284,11 @@ def main(
|
|||||||
data_dir: Path = PROCESSED_DATA_DIR / "benchmark",
|
data_dir: Path = PROCESSED_DATA_DIR / "benchmark",
|
||||||
output_dir: Path = MODELS_DIR / "benchmark",
|
output_dir: Path = MODELS_DIR / "benchmark",
|
||||||
# 模型参数
|
# 模型参数
|
||||||
d_model: int = 128,
|
d_model: int = 256,
|
||||||
num_heads: int = 4,
|
num_heads: int = 8,
|
||||||
n_attn_layers: int = 2,
|
n_attn_layers: int = 4,
|
||||||
fusion_strategy: str = "attention",
|
fusion_strategy: str = "attention",
|
||||||
head_hidden_dim: int = 64,
|
head_hidden_dim: int = 128,
|
||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
# MPNN 参数
|
# MPNN 参数
|
||||||
use_mpnn: bool = False,
|
use_mpnn: bool = False,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user