diff --git a/lnp_ml/featurization/smiles.py b/lnp_ml/featurization/smiles.py
index 434c08d..a826a4a 100644
--- a/lnp_ml/featurization/smiles.py
+++ b/lnp_ml/featurization/smiles.py
@@ -1,5 +1,7 @@
 """SMILES molecular featurizers"""
 
+import logging
+import warnings
 from dataclasses import dataclass, field
 from typing import List, Optional, Dict
 import numpy as np
@@ -15,6 +17,9 @@
 import torch
 from chemprop.utils import load_checkpoint
 from chemprop.features import mol2graph
+_chemprop_logger = logging.getLogger("chemprop.load_checkpoint")
+_chemprop_logger.setLevel(logging.WARNING)
+
 
 @dataclass
 class RDKitFeaturizer:
@@ -97,16 +102,19 @@ class MPNNFeaturizer:
         device = torch.device(self.device)
         paths = self.ensemble_paths or [self.checkpoint_path]
 
-        for path in paths:
-            model = load_checkpoint(path, device=device)
-            model.eval()
-            # Extract the MPNEncoder (the D-MPNN core)
-            encoder = model.encoder.encoder[0]
-            # Freeze parameters
-            for param in encoder.parameters():
-                param.requires_grad = False
-            self._encoders.append(encoder)
-            self._hidden_size = encoder.hidden_size
+        # Load checkpoints quietly: route chemprop output to a WARNING-level logger and suppress FutureWarnings
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=FutureWarning)
+            for path in paths:
+                model = load_checkpoint(path, device=device, logger=_chemprop_logger)
+                model.eval()
+                # Extract the MPNEncoder (the D-MPNN core)
+                encoder = model.encoder.encoder[0]
+                # Freeze parameters so the encoder serves as a fixed featurizer
+                for param in encoder.parameters():
+                    param.requires_grad = False
+                self._encoders.append(encoder)
+                self._hidden_size = encoder.hidden_size
 
         self._initialized = True
 
diff --git a/lnp_ml/modeling/pretrain.py b/lnp_ml/modeling/pretrain.py
index 52dd3f1..9a23288 100644
--- a/lnp_ml/modeling/pretrain.py
+++ b/lnp_ml/modeling/pretrain.py
@@ -158,7 +158,7 @@ def pretrain(
     model = model.to(device)
     optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-        optimizer, mode="min", factor=0.5, patience=5, verbose=True
+        optimizer, mode="min", factor=0.5, patience=5
     )
     early_stopping = EarlyStopping(patience=patience)
 
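For context, the encoders frozen above are meant to be used as fixed molecule embedders. The snippet below is a minimal sketch of how an embedding could be computed from them; it assumes a chemprop 1.x mol2graph that accepts a list of SMILES strings, and the embed_smiles helper and ensemble averaging are illustrative only, not part of this change.

import numpy as np
import torch
from chemprop.features import mol2graph

def embed_smiles(featurizer, smiles):
    # Hypothetical helper: average frozen D-MPNN embeddings over the loaded ensemble.
    batch = mol2graph([smiles])  # build a BatchMolGraph for a single molecule
    with torch.no_grad():
        vecs = [encoder(batch).cpu().numpy() for encoder in featurizer._encoders]
    return np.mean(vecs, axis=0).squeeze(0)  # shape: (hidden_size,)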