From e3f9d9e9db015b4ea5a6155f37f1762911b3c88d Mon Sep 17 00:00:00 2001 From: RYDE-WORK Date: Sat, 28 Feb 2026 17:08:22 +0800 Subject: [PATCH] =?UTF-8?q?=E7=BB=9F=E4=B8=80=E6=89=80=E6=9C=89=E8=84=9A?= =?UTF-8?q?=E6=9C=AC=E5=8F=82=E6=95=B0=E7=BB=9F=E8=AE=A1=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lnp_ml/modeling/benchmark.py | 4 +++- lnp_ml/modeling/pretrain.py | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/lnp_ml/modeling/benchmark.py b/lnp_ml/modeling/benchmark.py index 2c6be78..3ad7ac9 100644 --- a/lnp_ml/modeling/benchmark.py +++ b/lnp_ml/modeling/benchmark.py @@ -416,7 +416,9 @@ def main( model.rdkit_encoder._cache = rdkit_cache logger.info(f"Reusing RDKit cache with {len(rdkit_cache)} entries") - logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") + n_params_total = sum(p.numel() for p in model.parameters()) + n_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + logger.info(f"Model parameters: {n_params_total:,} total, {n_params_trainable:,} trainable") # 训练 result = train_fold( diff --git a/lnp_ml/modeling/pretrain.py b/lnp_ml/modeling/pretrain.py index abb702d..b7d9d66 100644 --- a/lnp_ml/modeling/pretrain.py +++ b/lnp_ml/modeling/pretrain.py @@ -303,8 +303,9 @@ def main( dropout=dropout, ) - n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - logger.info(f"Model parameters: {n_params:,}") + n_params_total = sum(p.numel() for p in model.parameters()) + n_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + logger.info(f"Model parameters: {n_params_total:,} total, {n_params_trainable:,} trainable") # 预热 RDKit 缓存(避免训练时阻塞) all_smiles = train_df["smiles"].tolist() + val_df["smiles"].tolist()