mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 01:27:00 +08:00
统一所有脚本参数统计方式
This commit is contained in:
parent
985f3a1bb0
commit
e3f9d9e9db
@ -416,7 +416,9 @@ def main(
|
||||
model.rdkit_encoder._cache = rdkit_cache
|
||||
logger.info(f"Reusing RDKit cache with {len(rdkit_cache)} entries")
|
||||
|
||||
logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
|
||||
n_params_total = sum(p.numel() for p in model.parameters())
|
||||
n_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
logger.info(f"Model parameters: {n_params_total:,} total, {n_params_trainable:,} trainable")
|
||||
|
||||
# 训练
|
||||
result = train_fold(
|
||||
|
||||
@ -303,8 +303,9 @@ def main(
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
logger.info(f"Model parameters: {n_params:,}")
|
||||
n_params_total = sum(p.numel() for p in model.parameters())
|
||||
n_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
logger.info(f"Model parameters: {n_params_total:,} total, {n_params_trainable:,} trainable")
|
||||
|
||||
# 预热 RDKit 缓存(避免训练时阻塞)
|
||||
all_smiles = train_df["smiles"].tolist() + val_df["smiles"].tolist()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user