统一所有脚本参数统计方式

This commit is contained in:
RYDE-WORK 2026-02-28 17:08:22 +08:00
parent 985f3a1bb0
commit e3f9d9e9db
2 changed files with 6 additions and 3 deletions

View File

@ -416,7 +416,9 @@ def main(
model.rdkit_encoder._cache = rdkit_cache model.rdkit_encoder._cache = rdkit_cache
logger.info(f"Reusing RDKit cache with {len(rdkit_cache)} entries") logger.info(f"Reusing RDKit cache with {len(rdkit_cache)} entries")
logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") n_params_total = sum(p.numel() for p in model.parameters())
n_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"Model parameters: {n_params_total:,} total, {n_params_trainable:,} trainable")
# 训练 # 训练
result = train_fold( result = train_fold(

View File

@ -303,8 +303,9 @@ def main(
dropout=dropout, dropout=dropout,
) )
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) n_params_total = sum(p.numel() for p in model.parameters())
logger.info(f"Model parameters: {n_params:,}") n_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"Model parameters: {n_params_total:,} total, {n_params_trainable:,} trainable")
# 预热 RDKit 缓存(避免训练时阻塞) # 预热 RDKit 缓存(避免训练时阻塞)
all_smiles = train_df["smiles"].tolist() + val_df["smiles"].tolist() all_smiles = train_df["smiles"].tolist() + val_df["smiles"].tolist()