移除旧版函数参数,以及抑制chemprop/torch影响终端输出查看的过多warning

This commit is contained in:
RYDE-WORK 2026-02-28 15:55:47 +08:00
parent 7c69d47238
commit 70b7a4c62a
2 changed files with 16 additions and 11 deletions

View File

@@ -1,5 +1,7 @@
"""SMILES 分子特征提取器"""
import logging
import warnings
from dataclasses import dataclass, field
from typing import List, Optional, Dict
import numpy as np
@@ -15,6 +17,9 @@ import torch
from chemprop.utils import load_checkpoint
from chemprop.features import mol2graph
_chemprop_logger = logging.getLogger("chemprop.load_checkpoint")
_chemprop_logger.setLevel(logging.WARNING)
@dataclass
class RDKitFeaturizer:
@@ -97,16 +102,16 @@ class MPNNFeaturizer:
device = torch.device(self.device)
paths = self.ensemble_paths or [self.checkpoint_path]
for path in paths:
model = load_checkpoint(path, device=device)
model.eval()
# 提取 MPNEncoder(D-MPNN 核心部分)
encoder = model.encoder.encoder[0]
# 冻结参数
for param in encoder.parameters():
param.requires_grad = False
self._encoders.append(encoder)
self._hidden_size = encoder.hidden_size
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
for path in paths:
model = load_checkpoint(path, device=device, logger=_chemprop_logger)
model.eval()
encoder = model.encoder.encoder[0]
for param in encoder.parameters():
param.requires_grad = False
self._encoders.append(encoder)
self._hidden_size = encoder.hidden_size
self._initialized = True

View File

@@ -158,7 +158,7 @@ def pretrain(
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", factor=0.5, patience=5, verbose=True
optimizer, mode="min", factor=0.5, patience=5
)
early_stopping = EarlyStopping(patience=patience)