diff --git a/lnp_ml/modeling/benchmark.py b/lnp_ml/modeling/benchmark.py index 44ba209..d023170 100644 --- a/lnp_ml/modeling/benchmark.py +++ b/lnp_ml/modeling/benchmark.py @@ -372,6 +372,8 @@ def main( df = pd.read_parquet(fold_dir / f"{split}.parquet") all_smiles.update(df["smiles"].tolist()) + rdkit_cache = None # 跨 fold 共享 RDKit 特征缓存 + for fold_dir in fold_dirs: fold_idx = int(fold_dir.name.split("_")[1]) @@ -406,9 +408,13 @@ def main( ) model = model.to(device) - # 第一个 fold 时做 cache warmup - if fold_idx == 0: + # 第一个 fold 时做 cache warmup,之后复用缓存 + if rdkit_cache is None: warmup_cache(model, list(all_smiles), batch_size=256) + rdkit_cache = model.rdkit_encoder._cache + else: + model.rdkit_encoder._cache = rdkit_cache + logger.info(f"Reusing RDKit cache with {len(rdkit_cache)} entries") logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")