mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 09:36:32 +08:00
微调阶段rdkit缓存共享
This commit is contained in:
parent
f952033b09
commit
7c69d47238
@ -41,7 +41,10 @@ from lnp_ml.dataset import (
|
||||
TARGET_CLASSIFICATION_EE,
|
||||
TARGET_TOXIC,
|
||||
)
|
||||
from tqdm import tqdm
|
||||
|
||||
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
|
||||
from lnp_ml.modeling.encoders.rdkit_encoder import CachedRDKitEncoder
|
||||
from lnp_ml.modeling.trainer_balanced import (
|
||||
ClassWeights,
|
||||
LossWeightsBalanced,
|
||||
@ -140,6 +143,20 @@ def build_composite_strata(
|
||||
return encoded.astype(np.int64), strata_info
|
||||
|
||||
|
||||
# ============ RDKit 缓存预热 ============
|
||||
|
||||
def warmup_rdkit_cache(smiles_list: List[str], batch_size: int = 256) -> Dict:
|
||||
"""预热 RDKit 特征缓存,返回可跨模型共享的缓存字典。"""
|
||||
encoder = CachedRDKitEncoder()
|
||||
unique_smiles = list(set(smiles_list))
|
||||
logger.info(f"Warming up RDKit cache for {len(unique_smiles)} unique SMILES...")
|
||||
for i in tqdm(range(0, len(unique_smiles), batch_size), desc="Cache warmup"):
|
||||
batch = unique_smiles[i:i + batch_size]
|
||||
encoder(batch)
|
||||
logger.success(f"Cache warmup complete. Cached {len(encoder._cache)} SMILES.")
|
||||
return encoder._cache
|
||||
|
||||
|
||||
# ============ 模型创建 ============
|
||||
|
||||
def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
|
||||
@ -231,6 +248,7 @@ def run_optuna_cv(
|
||||
pretrain_state_dict: Optional[Dict] = None,
|
||||
pretrain_config: Optional[Dict] = None,
|
||||
load_delivery_head: bool = True,
|
||||
rdkit_cache: Optional[Dict] = None,
|
||||
) -> Tuple[Dict, int, optuna.Study]:
|
||||
"""
|
||||
使用全量数据做 3-fold CV Optuna 超参搜索。
|
||||
@ -318,6 +336,8 @@ def run_optuna_cv(
|
||||
use_mpnn=use_mpnn,
|
||||
mpnn_device=device.type,
|
||||
)
|
||||
if rdkit_cache is not None:
|
||||
model.rdkit_encoder._cache = rdkit_cache
|
||||
|
||||
# 加载预训练权重
|
||||
if pretrain_state_dict is not None and pretrain_config is not None:
|
||||
@ -464,6 +484,9 @@ def main(
|
||||
# 创建完整数据集
|
||||
full_dataset = LNPDataset(df)
|
||||
|
||||
# 预热 RDKit 缓存(在整个训练流程中共享)
|
||||
rdkit_cache = warmup_rdkit_cache(full_dataset.smiles)
|
||||
|
||||
# 运行 Optuna 调参
|
||||
logger.info(f"\nRunning {n_folds}-fold Optuna with {n_trials} trials...")
|
||||
study_path = output_dir / "optuna_study.sqlite3"
|
||||
@ -483,6 +506,7 @@ def main(
|
||||
pretrain_state_dict=pretrain_state_dict,
|
||||
pretrain_config=pretrain_config,
|
||||
load_delivery_head=load_delivery_head,
|
||||
rdkit_cache=rdkit_cache,
|
||||
)
|
||||
|
||||
# 保存最佳参数
|
||||
@ -541,6 +565,7 @@ def main(
|
||||
use_mpnn=use_mpnn,
|
||||
mpnn_device=device.type,
|
||||
)
|
||||
model.rdkit_encoder._cache = rdkit_cache
|
||||
|
||||
# 加载预训练权重
|
||||
if pretrain_state_dict is not None and pretrain_config is not None:
|
||||
|
||||
@ -41,7 +41,10 @@ from lnp_ml.dataset import (
|
||||
TARGET_CLASSIFICATION_EE,
|
||||
TARGET_TOXIC,
|
||||
)
|
||||
from tqdm import tqdm
|
||||
|
||||
from lnp_ml.modeling.models import LNPModel, LNPModelWithoutMPNN
|
||||
from lnp_ml.modeling.encoders.rdkit_encoder import CachedRDKitEncoder
|
||||
from lnp_ml.modeling.trainer_balanced import (
|
||||
ClassWeights,
|
||||
LossWeightsBalanced,
|
||||
@ -142,6 +145,20 @@ def build_composite_strata(
|
||||
return encoded.astype(np.int64), strata_info
|
||||
|
||||
|
||||
# ============ RDKit 缓存预热 ============
|
||||
|
||||
def warmup_rdkit_cache(smiles_list: List[str], batch_size: int = 256) -> Dict:
|
||||
"""预热 RDKit 特征缓存,返回可跨模型共享的缓存字典。"""
|
||||
encoder = CachedRDKitEncoder()
|
||||
unique_smiles = list(set(smiles_list))
|
||||
logger.info(f"Warming up RDKit cache for {len(unique_smiles)} unique SMILES...")
|
||||
for i in tqdm(range(0, len(unique_smiles), batch_size), desc="Cache warmup"):
|
||||
batch = unique_smiles[i:i + batch_size]
|
||||
encoder(batch)
|
||||
logger.success(f"Cache warmup complete. Cached {len(encoder._cache)} SMILES.")
|
||||
return encoder._cache
|
||||
|
||||
|
||||
# ============ 模型创建 ============
|
||||
|
||||
def find_mpnn_ensemble_paths(base_dir: Path = DEFAULT_MPNN_ENSEMBLE_DIR) -> List[str]:
|
||||
@ -344,6 +361,7 @@ def run_inner_optuna(
|
||||
pretrain_state_dict: Optional[Dict] = None,
|
||||
pretrain_config: Optional[Dict] = None,
|
||||
load_delivery_head: bool = True,
|
||||
rdkit_cache: Optional[Dict] = None,
|
||||
) -> Tuple[Dict, int, optuna.Study]:
|
||||
"""
|
||||
在内层数据上运行 Optuna 超参搜索。
|
||||
@ -439,6 +457,8 @@ def run_inner_optuna(
|
||||
use_mpnn=use_mpnn,
|
||||
mpnn_device=device.type,
|
||||
)
|
||||
if rdkit_cache is not None:
|
||||
model.rdkit_encoder._cache = rdkit_cache
|
||||
|
||||
# 加载预训练权重
|
||||
if pretrain_state_dict is not None and pretrain_config is not None:
|
||||
@ -540,6 +560,9 @@ def _run_single_outer_fold(
|
||||
logger.info(f"{'='*60}")
|
||||
logger.info(f"Train: {len(outer_train_idx)}, Test: {len(outer_test_idx)}")
|
||||
|
||||
# 预热 RDKit 缓存(在整个 fold 内共享)
|
||||
rdkit_cache = warmup_rdkit_cache(full_dataset.smiles)
|
||||
|
||||
# 保存 split indices
|
||||
with open(fold_dir / "splits.json", "w") as f:
|
||||
json.dump({
|
||||
@ -567,6 +590,7 @@ def _run_single_outer_fold(
|
||||
pretrain_state_dict=pretrain_state_dict,
|
||||
pretrain_config=pretrain_config,
|
||||
load_delivery_head=load_delivery_head,
|
||||
rdkit_cache=rdkit_cache,
|
||||
)
|
||||
|
||||
# 保存最佳参数
|
||||
@ -601,6 +625,7 @@ def _run_single_outer_fold(
|
||||
use_mpnn=use_mpnn,
|
||||
mpnn_device=device.type,
|
||||
)
|
||||
model.rdkit_encoder._cache = rdkit_cache
|
||||
|
||||
if pretrain_state_dict is not None and pretrain_config is not None:
|
||||
loaded = load_pretrain_weights_to_model(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user