mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 01:27:00 +08:00
Reuse cacje
This commit is contained in:
parent
13b357ce05
commit
f952033b09
@ -372,6 +372,8 @@ def main(
|
||||
df = pd.read_parquet(fold_dir / f"{split}.parquet")
|
||||
all_smiles.update(df["smiles"].tolist())
|
||||
|
||||
rdkit_cache = None # 跨 fold 共享 RDKit 特征缓存
|
||||
|
||||
for fold_dir in fold_dirs:
|
||||
fold_idx = int(fold_dir.name.split("_")[1])
|
||||
|
||||
@ -406,9 +408,13 @@ def main(
|
||||
)
|
||||
model = model.to(device)
|
||||
|
||||
# 第一个 fold 时做 cache warmup
|
||||
if fold_idx == 0:
|
||||
# 第一个 fold 时做 cache warmup,之后复用缓存
|
||||
if rdkit_cache is None:
|
||||
warmup_cache(model, list(all_smiles), batch_size=256)
|
||||
rdkit_cache = model.rdkit_encoder._cache
|
||||
else:
|
||||
model.rdkit_encoder._cache = rdkit_cache
|
||||
logger.info(f"Reusing RDKit cache with {len(rdkit_cache)} entries")
|
||||
|
||||
logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user