From 7c69d472389b785b1b9e71019154e8b03d3eec0f Mon Sep 17 00:00:00 2001
From: RYDE-WORK
Date: Sat, 28 Feb 2026 15:40:29 +0800
Subject: [PATCH] =?UTF-8?q?=E5=BE=AE=E8=B0=83=E9=98=B6=E6=AE=B5rdkit?=
 =?UTF-8?q?=E7=BC=93=E5=AD=98=E5=85=B1=E4=BA=AB?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 lnp_ml/modeling/final_train_optuna_cv.py | 25 ++++++++++++++++++++++++
 lnp_ml/modeling/nested_cv_optuna.py      | 25 ++++++++++++++++++++++++
 2 files changed, 50 insertions(+)

diff --git a/lnp_ml/modeling/final_train_optuna_cv.py b/lnp_ml/modeling/final_train_optuna_cv.py
index 8046e4e..9e36152 100644
--- a/lnp_ml/modeling/final_train_optuna_cv.py
+++ b/lnp_ml/modeling/final_train_optuna_cv.py
@@ -41,7 +41,10 @@ from lnp_ml.dataset import (
     TARGET_CLASSIFICATION_EE,
     TARGET_TOXIC,
 )
+from tqdm import tqdm
+
 from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
+from lnp_ml.modeling.encoders.rdkit_encoder import CachedRDKitEncoder
 from lnp_ml.modeling.trainer_balanced import (
     ClassWeights,
     LossWeightsBalanced,
@@ -140,6 +143,20 @@ def build_composite_strata(
     return encoded.astype(np.int64), strata_info
 
 
+# ============ RDKit 缓存预热 ============
+
+def warmup_rdkit_cache(smiles_list: List[str], batch_size: int = 256) -> Dict:
+    """预热 RDKit 特征缓存,返回可跨模型共享的缓存字典。"""
+    encoder = CachedRDKitEncoder()
+    unique_smiles = list(set(smiles_list))
+    logger.info(f"Warming up RDKit cache for {len(unique_smiles)} unique SMILES...")
+    for i in tqdm(range(0, len(unique_smiles), batch_size), desc="Cache warmup"):
+        batch = unique_smiles[i:i + batch_size]
+        encoder(batch)
+    logger.success(f"Cache warmup complete. Cached {len(encoder._cache)} SMILES.")
+    return encoder._cache
+
+
 # ============ 模型创建 ============
 
 def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
@@ -231,6 +248,7 @@ def run_optuna_cv(
     pretrain_state_dict: Optional[Dict] = None,
     pretrain_config: Optional[Dict] = None,
     load_delivery_head: bool = True,
+    rdkit_cache: Optional[Dict] = None,
 ) -> Tuple[Dict, int, optuna.Study]:
     """
     使用全量数据做 3-fold CV Optuna 超参搜索。
@@ -318,6 +336,8 @@ def run_optuna_cv(
             use_mpnn=use_mpnn,
             mpnn_device=device.type,
         )
+        if rdkit_cache is not None:
+            model.rdkit_encoder._cache = rdkit_cache
 
         # 加载预训练权重
         if pretrain_state_dict is not None and pretrain_config is not None:
@@ -464,6 +484,9 @@ def main(
     # 创建完整数据集
     full_dataset = LNPDataset(df)
 
+    # 预热 RDKit 缓存(在整个训练流程中共享)
+    rdkit_cache = warmup_rdkit_cache(full_dataset.smiles)
+
     # 运行 Optuna 调参
     logger.info(f"\nRunning {n_folds}-fold Optuna with {n_trials} trials...")
     study_path = output_dir / "optuna_study.sqlite3"
@@ -483,6 +506,7 @@ def main(
         pretrain_state_dict=pretrain_state_dict,
         pretrain_config=pretrain_config,
         load_delivery_head=load_delivery_head,
+        rdkit_cache=rdkit_cache,
     )
 
     # 保存最佳参数
@@ -541,6 +565,7 @@ def main(
         use_mpnn=use_mpnn,
         mpnn_device=device.type,
     )
+    model.rdkit_encoder._cache = rdkit_cache
 
     # 加载预训练权重
     if pretrain_state_dict is not None and pretrain_config is not None:
diff --git a/lnp_ml/modeling/nested_cv_optuna.py b/lnp_ml/modeling/nested_cv_optuna.py
index df323dc..c2c24e4 100644
--- a/lnp_ml/modeling/nested_cv_optuna.py
+++ b/lnp_ml/modeling/nested_cv_optuna.py
@@ -41,7 +41,10 @@ from lnp_ml.dataset import (
     TARGET_CLASSIFICATION_EE,
     TARGET_TOXIC,
 )
+from tqdm import tqdm
+
 from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
+from lnp_ml.modeling.encoders.rdkit_encoder import CachedRDKitEncoder
 from lnp_ml.modeling.trainer_balanced import (
     ClassWeights,
     LossWeightsBalanced,
@@ -142,6 +145,20 @@ def build_composite_strata(
     return encoded.astype(np.int64), strata_info
 
 
+# ============ RDKit 缓存预热 ============
+
+def warmup_rdkit_cache(smiles_list: List[str], batch_size: int = 256) -> Dict:
+    """预热 RDKit 特征缓存,返回可跨模型共享的缓存字典。"""
+    encoder = CachedRDKitEncoder()
+    unique_smiles = list(set(smiles_list))
+    logger.info(f"Warming up RDKit cache for {len(unique_smiles)} unique SMILES...")
+    for i in tqdm(range(0, len(unique_smiles), batch_size), desc="Cache warmup"):
+        batch = unique_smiles[i:i + batch_size]
+        encoder(batch)
+    logger.success(f"Cache warmup complete. Cached {len(encoder._cache)} SMILES.")
+    return encoder._cache
+
+
 # ============ 模型创建 ============
 
 def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
@@ -344,6 +361,7 @@ def run_inner_optuna(
     pretrain_state_dict: Optional[Dict] = None,
     pretrain_config: Optional[Dict] = None,
     load_delivery_head: bool = True,
+    rdkit_cache: Optional[Dict] = None,
 ) -> Tuple[Dict, int, optuna.Study]:
     """
     在内层数据上运行 Optuna 超参搜索。
@@ -439,6 +457,8 @@ def run_inner_optuna(
             use_mpnn=use_mpnn,
             mpnn_device=device.type,
         )
+        if rdkit_cache is not None:
+            model.rdkit_encoder._cache = rdkit_cache
 
         # 加载预训练权重
         if pretrain_state_dict is not None and pretrain_config is not None:
@@ -540,6 +560,9 @@ def _run_single_outer_fold(
     logger.info(f"{'='*60}")
     logger.info(f"Train: {len(outer_train_idx)}, Test: {len(outer_test_idx)}")
 
+    # 预热 RDKit 缓存(在整个 fold 内共享)
+    rdkit_cache = warmup_rdkit_cache(full_dataset.smiles)
+
     # 保存 split indices
     with open(fold_dir / "splits.json", "w") as f:
         json.dump({
@@ -567,6 +590,7 @@ def _run_single_outer_fold(
         pretrain_state_dict=pretrain_state_dict,
         pretrain_config=pretrain_config,
         load_delivery_head=load_delivery_head,
+        rdkit_cache=rdkit_cache,
     )
 
     # 保存最佳参数
@@ -601,6 +625,7 @@ def _run_single_outer_fold(
         use_mpnn=use_mpnn,
         mpnn_device=device.type,
     )
+    model.rdkit_encoder._cache = rdkit_cache
 
     if pretrain_state_dict is not None and pretrain_config is not None:
         loaded = load_pretrain_weights_to_model(