Reuse cacje

This commit is contained in:
RYDE-WORK 2026-02-28 15:35:10 +08:00
parent 13b357ce05
commit f952033b09

View File

@ -372,6 +372,8 @@ def main(
df = pd.read_parquet(fold_dir / f"{split}.parquet")
all_smiles.update(df["smiles"].tolist())
rdkit_cache = None # 跨 fold 共享 RDKit 特征缓存
for fold_dir in fold_dirs:
fold_idx = int(fold_dir.name.split("_")[1])
@ -406,9 +408,13 @@ def main(
)
model = model.to(device)
# 第一个 fold 时做 cache warmup
if fold_idx == 0:
# 第一个 fold 时做 cache warmup,之后复用缓存
if rdkit_cache is None:
warmup_cache(model, list(all_smiles), batch_size=256)
rdkit_cache = model.rdkit_encoder._cache
else:
model.rdkit_encoder._cache = rdkit_cache
logger.info(f"Reusing RDKit cache with {len(rdkit_cache)} entries")
logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")