mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 09:36:32 +08:00
Reuse cacje
This commit is contained in:
parent
13b357ce05
commit
f952033b09
@ -372,6 +372,8 @@ def main(
|
|||||||
df = pd.read_parquet(fold_dir / f"{split}.parquet")
|
df = pd.read_parquet(fold_dir / f"{split}.parquet")
|
||||||
all_smiles.update(df["smiles"].tolist())
|
all_smiles.update(df["smiles"].tolist())
|
||||||
|
|
||||||
|
rdkit_cache = None # 跨 fold 共享 RDKit 特征缓存
|
||||||
|
|
||||||
for fold_dir in fold_dirs:
|
for fold_dir in fold_dirs:
|
||||||
fold_idx = int(fold_dir.name.split("_")[1])
|
fold_idx = int(fold_dir.name.split("_")[1])
|
||||||
|
|
||||||
@ -406,9 +408,13 @@ def main(
|
|||||||
)
|
)
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
# 第一个 fold 时做 cache warmup
|
# 第一个 fold 时做 cache warmup,之后复用缓存
|
||||||
if fold_idx == 0:
|
if rdkit_cache is None:
|
||||||
warmup_cache(model, list(all_smiles), batch_size=256)
|
warmup_cache(model, list(all_smiles), batch_size=256)
|
||||||
|
rdkit_cache = model.rdkit_encoder._cache
|
||||||
|
else:
|
||||||
|
model.rdkit_encoder._cache = rdkit_cache
|
||||||
|
logger.info(f"Reusing RDKit cache with {len(rdkit_cache)} entries")
|
||||||
|
|
||||||
logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
|
logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user