统一所有脚本参数统计方式

This commit is contained in:
RYDE-WORK 2026-02-28 17:08:22 +08:00
parent 985f3a1bb0
commit e3f9d9e9db
2 changed files with 6 additions and 3 deletions

View File

@ -416,7 +416,9 @@ def main(
model.rdkit_encoder._cache = rdkit_cache
logger.info(f"Reusing RDKit cache with {len(rdkit_cache)} entries")
logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
n_params_total = sum(p.numel() for p in model.parameters())
n_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"Model parameters: {n_params_total:,} total, {n_params_trainable:,} trainable")
# 训练
result = train_fold(

View File

@ -303,8 +303,9 @@ def main(
dropout=dropout,
)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"Model parameters: {n_params:,}")
n_params_total = sum(p.numel() for p in model.parameters())
n_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"Model parameters: {n_params_total:,} total, {n_params_trainable:,} trainable")
# 预热 RDKit 缓存(避免训练时阻塞)
all_smiles = train_df["smiles"].tolist() + val_df["smiles"].tolist()