Remove a legacy function argument, and silence the excessive chemprop/torch warnings that clutter the terminal output

This commit is contained in:
RYDE-WORK 2026-02-28 15:55:47 +08:00
parent 7c69d47238
commit 70b7a4c62a
2 changed files with 16 additions and 11 deletions


@@ -1,5 +1,7 @@
 """SMILES molecular featurizer"""
+import logging
+import warnings
 from dataclasses import dataclass, field
 from typing import List, Optional, Dict
 import numpy as np
@@ -15,6 +17,9 @@ import torch
 from chemprop.utils import load_checkpoint
 from chemprop.features import mol2graph
+
+_chemprop_logger = logging.getLogger("chemprop.load_checkpoint")
+_chemprop_logger.setLevel(logging.WARNING)
 
 @dataclass
 class RDKitFeaturizer:
@@ -97,16 +102,16 @@ class MPNNFeaturizer:
         device = torch.device(self.device)
         paths = self.ensemble_paths or [self.checkpoint_path]
-        for path in paths:
-            model = load_checkpoint(path, device=device)
-            model.eval()
-            # Extract the MPNEncoder (the D-MPNN core)
-            encoder = model.encoder.encoder[0]
-            # Freeze the encoder parameters
-            for param in encoder.parameters():
-                param.requires_grad = False
-            self._encoders.append(encoder)
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=FutureWarning)
+            for path in paths:
+                model = load_checkpoint(path, device=device, logger=_chemprop_logger)
+                model.eval()
+                encoder = model.encoder.encoder[0]
+                for param in encoder.parameters():
+                    param.requires_grad = False
+                self._encoders.append(encoder)
         self._hidden_size = encoder.hidden_size
         self._initialized = True
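Two scoped mechanisms do the silencing here, rather than a global filter: chemprop's load_checkpoint routes its per-parameter loading messages through the logger it is given (falling back to print when none is passed), so handing it a logger raised to WARNING drops that chatter, while torch's FutureWarnings are ignored only inside the warnings.catch_warnings() block. A minimal standalone sketch of the same pattern; noisy_load is a hypothetical stand-in for the library call, not project code:

import logging
import warnings

# Raised to WARNING so DEBUG/INFO messages routed through this logger are dropped.
quiet_logger = logging.getLogger("chemprop.load_checkpoint")
quiet_logger.setLevel(logging.WARNING)

def noisy_load() -> str:
    # Stand-in for a checkpoint loader that emits deprecation warnings
    # and per-parameter debug messages.
    warnings.warn("this loader's defaults will change", FutureWarning)
    quiet_logger.debug("Loading pretrained parameter ...")  # dropped at WARNING level
    return "model"

with warnings.catch_warnings():
    # The filter lives only inside this block, so unrelated FutureWarnings
    # elsewhere in the program still reach the terminal.
    warnings.filterwarnings("ignore", category=FutureWarning)
    model = noisy_load()
print(model)  # loaded silently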


@@ -158,7 +158,7 @@ def pretrain(
     model = model.to(device)
     optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-        optimizer, mode="min", factor=0.5, patience=5, verbose=True
+        optimizer, mode="min", factor=0.5, patience=5
     )
     early_stopping = EarlyStopping(patience=patience)
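The dropped verbose=True is the legacy argument from the commit title: recent PyTorch 2.x releases deprecate the verbose flag on learning-rate schedulers, so every construction of ReduceLROnPlateau was emitting a deprecation warning. The LR-change message it used to print can be recovered by comparing the optimizer's learning rate around each step. A minimal sketch, with a toy model and a constant val_loss standing in for the real training loop:

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=5
)

for epoch in range(20):
    val_loss = 1.0  # constant plateau, so the scheduler must eventually cut the LR
    lr_before = optimizer.param_groups[0]["lr"]
    scheduler.step(val_loss)
    lr_after = optimizer.param_groups[0]["lr"]
    if lr_after != lr_before:
        # Replaces the message that verbose=True used to print.
        print(f"epoch {epoch}: reducing learning rate {lr_before:.2e} -> {lr_after:.2e}")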