微调阶段rdkit缓存共享

This commit is contained in:
RYDE-WORK 2026-02-28 15:40:29 +08:00
parent f952033b09
commit 7c69d47238
2 changed files with 50 additions and 0 deletions

View File

@ -41,7 +41,10 @@ from lnp_ml.dataset import (
TARGET_CLASSIFICATION_EE, TARGET_CLASSIFICATION_EE,
TARGET_TOXIC, TARGET_TOXIC,
) )
from tqdm import tqdm
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
from lnp_ml.modeling.encoders.rdkit_encoder import CachedRDKitEncoder
from lnp_ml.modeling.trainer_balanced import ( from lnp_ml.modeling.trainer_balanced import (
ClassWeights, ClassWeights,
LossWeightsBalanced, LossWeightsBalanced,
@ -140,6 +143,20 @@ def build_composite_strata(
return encoded.astype(np.int64), strata_info return encoded.astype(np.int64), strata_info
# ============ RDKit 缓存预热 ============
def warmup_rdkit_cache(smiles_list: List[str], batch_size: int = 256) -> Dict:
    """Warm up the RDKit feature cache and return it for cross-model sharing.

    Encodes every unique SMILES once, in fixed-size batches, so that model
    instances created later can reuse the precomputed features instead of
    recomputing them per trial/fold.

    Args:
        smiles_list: SMILES strings to precompute features for (duplicates OK).
        batch_size: Number of SMILES passed to the encoder per call.

    Returns:
        The encoder's internal cache dict, keyed by SMILES.
    """
    encoder = CachedRDKitEncoder()
    # Deduplicate up front; each SMILES only needs to be featurized once.
    unique_smiles = list(set(smiles_list))
    logger.info(f"Warming up RDKit cache for {len(unique_smiles)} unique SMILES...")
    batch_starts = range(0, len(unique_smiles), batch_size)
    for start in tqdm(batch_starts, desc="Cache warmup"):
        # Calling the encoder populates its cache as a side effect.
        encoder(unique_smiles[start:start + batch_size])
    # NOTE(review): reaches into the encoder's private `_cache` attribute —
    # consider exposing a public accessor on CachedRDKitEncoder.
    logger.success(f"Cache warmup complete. Cached {len(encoder._cache)} SMILES.")
    return encoder._cache
# ============ 模型创建 ============ # ============ 模型创建 ============
def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]: def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
@ -231,6 +248,7 @@ def run_optuna_cv(
pretrain_state_dict: Optional[Dict] = None, pretrain_state_dict: Optional[Dict] = None,
pretrain_config: Optional[Dict] = None, pretrain_config: Optional[Dict] = None,
load_delivery_head: bool = True, load_delivery_head: bool = True,
rdkit_cache: Optional[Dict] = None,
) -> Tuple[Dict, int, optuna.Study]: ) -> Tuple[Dict, int, optuna.Study]:
""" """
使用全量数据做 3-fold CV Optuna 超参搜索 使用全量数据做 3-fold CV Optuna 超参搜索
@ -318,6 +336,8 @@ def run_optuna_cv(
use_mpnn=use_mpnn, use_mpnn=use_mpnn,
mpnn_device=device.type, mpnn_device=device.type,
) )
if rdkit_cache is not None:
model.rdkit_encoder._cache = rdkit_cache
# 加载预训练权重 # 加载预训练权重
if pretrain_state_dict is not None and pretrain_config is not None: if pretrain_state_dict is not None and pretrain_config is not None:
@ -464,6 +484,9 @@ def main(
# 创建完整数据集 # 创建完整数据集
full_dataset = LNPDataset(df) full_dataset = LNPDataset(df)
# 预热 RDKit 缓存(在整个训练流程中共享)
rdkit_cache = warmup_rdkit_cache(full_dataset.smiles)
# 运行 Optuna 调参 # 运行 Optuna 调参
logger.info(f"\nRunning {n_folds}-fold Optuna with {n_trials} trials...") logger.info(f"\nRunning {n_folds}-fold Optuna with {n_trials} trials...")
study_path = output_dir / "optuna_study.sqlite3" study_path = output_dir / "optuna_study.sqlite3"
@ -483,6 +506,7 @@ def main(
pretrain_state_dict=pretrain_state_dict, pretrain_state_dict=pretrain_state_dict,
pretrain_config=pretrain_config, pretrain_config=pretrain_config,
load_delivery_head=load_delivery_head, load_delivery_head=load_delivery_head,
rdkit_cache=rdkit_cache,
) )
# 保存最佳参数 # 保存最佳参数
@ -541,6 +565,7 @@ def main(
use_mpnn=use_mpnn, use_mpnn=use_mpnn,
mpnn_device=device.type, mpnn_device=device.type,
) )
model.rdkit_encoder._cache = rdkit_cache
# 加载预训练权重 # 加载预训练权重
if pretrain_state_dict is not None and pretrain_config is not None: if pretrain_state_dict is not None and pretrain_config is not None:

View File

@ -41,7 +41,10 @@ from lnp_ml.dataset import (
TARGET_CLASSIFICATION_EE, TARGET_CLASSIFICATION_EE,
TARGET_TOXIC, TARGET_TOXIC,
) )
from tqdm import tqdm
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
from lnp_ml.modeling.encoders.rdkit_encoder import CachedRDKitEncoder
from lnp_ml.modeling.trainer_balanced import ( from lnp_ml.modeling.trainer_balanced import (
ClassWeights, ClassWeights,
LossWeightsBalanced, LossWeightsBalanced,
@ -142,6 +145,20 @@ def build_composite_strata(
return encoded.astype(np.int64), strata_info return encoded.astype(np.int64), strata_info
# ============ RDKit 缓存预热 ============
def warmup_rdkit_cache(smiles_list: List[str], batch_size: int = 256) -> Dict:
    """Precompute RDKit features for all SMILES; return the shareable cache dict.

    The returned dict is the encoder's own cache, so it can be attached to
    other encoder instances to avoid recomputation across models/folds.

    Args:
        smiles_list: SMILES strings to warm the cache with (may contain dups).
        batch_size: How many SMILES to feed the encoder at a time.

    Returns:
        Dict mapping SMILES to their cached RDKit features.
    """
    encoder = CachedRDKitEncoder()
    unique_smiles = list(set(smiles_list))  # featurize each SMILES exactly once
    logger.info(f"Warming up RDKit cache for {len(unique_smiles)} unique SMILES...")
    total = len(unique_smiles)
    for offset in tqdm(range(0, total, batch_size), desc="Cache warmup"):
        chunk = unique_smiles[offset:offset + batch_size]
        encoder(chunk)  # side effect: fills encoder._cache
    # NOTE(review): depends on the private `_cache` attribute of the encoder —
    # a public export method would make this contract explicit.
    logger.success(f"Cache warmup complete. Cached {len(encoder._cache)} SMILES.")
    return encoder._cache
# ============ 模型创建 ============ # ============ 模型创建 ============
def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]: def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
@ -344,6 +361,7 @@ def run_inner_optuna(
pretrain_state_dict: Optional[Dict] = None, pretrain_state_dict: Optional[Dict] = None,
pretrain_config: Optional[Dict] = None, pretrain_config: Optional[Dict] = None,
load_delivery_head: bool = True, load_delivery_head: bool = True,
rdkit_cache: Optional[Dict] = None,
) -> Tuple[Dict, int, optuna.Study]: ) -> Tuple[Dict, int, optuna.Study]:
""" """
在内层数据上运行 Optuna 超参搜索 在内层数据上运行 Optuna 超参搜索
@ -439,6 +457,8 @@ def run_inner_optuna(
use_mpnn=use_mpnn, use_mpnn=use_mpnn,
mpnn_device=device.type, mpnn_device=device.type,
) )
if rdkit_cache is not None:
model.rdkit_encoder._cache = rdkit_cache
# 加载预训练权重 # 加载预训练权重
if pretrain_state_dict is not None and pretrain_config is not None: if pretrain_state_dict is not None and pretrain_config is not None:
@ -540,6 +560,9 @@ def _run_single_outer_fold(
logger.info(f"{'='*60}") logger.info(f"{'='*60}")
logger.info(f"Train: {len(outer_train_idx)}, Test: {len(outer_test_idx)}") logger.info(f"Train: {len(outer_train_idx)}, Test: {len(outer_test_idx)}")
# 预热 RDKit 缓存(在整个 fold 内共享)
rdkit_cache = warmup_rdkit_cache(full_dataset.smiles)
# 保存 split indices # 保存 split indices
with open(fold_dir / "splits.json", "w") as f: with open(fold_dir / "splits.json", "w") as f:
json.dump({ json.dump({
@ -567,6 +590,7 @@ def _run_single_outer_fold(
pretrain_state_dict=pretrain_state_dict, pretrain_state_dict=pretrain_state_dict,
pretrain_config=pretrain_config, pretrain_config=pretrain_config,
load_delivery_head=load_delivery_head, load_delivery_head=load_delivery_head,
rdkit_cache=rdkit_cache,
) )
# 保存最佳参数 # 保存最佳参数
@ -601,6 +625,7 @@ def _run_single_outer_fold(
use_mpnn=use_mpnn, use_mpnn=use_mpnn,
mpnn_device=device.type, mpnn_device=device.type,
) )
model.rdkit_encoder._cache = rdkit_cache
if pretrain_state_dict is not None and pretrain_config is not None: if pretrain_state_dict is not None and pretrain_config is not None:
loaded = load_pretrain_weights_to_model( loaded = load_pretrain_weights_to_model(