mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 09:36:32 +08:00
移除旧版函数参数,以及抑制chemprop/torch影响终端输出查看的过多warning
This commit is contained in:
parent
7c69d47238
commit
70b7a4c62a
@ -1,5 +1,7 @@
|
||||
"""SMILES 分子特征提取器"""
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Dict
|
||||
import numpy as np
|
||||
@ -15,6 +17,9 @@ import torch
|
||||
from chemprop.utils import load_checkpoint
|
||||
from chemprop.features import mol2graph
|
||||
|
||||
_chemprop_logger = logging.getLogger("chemprop.load_checkpoint")
|
||||
_chemprop_logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RDKitFeaturizer:
|
||||
@ -97,12 +102,12 @@ class MPNNFeaturizer:
|
||||
device = torch.device(self.device)
|
||||
paths = self.ensemble_paths or [self.checkpoint_path]
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||
for path in paths:
|
||||
model = load_checkpoint(path, device=device)
|
||||
model = load_checkpoint(path, device=device, logger=_chemprop_logger)
|
||||
model.eval()
|
||||
# 提取 MPNEncoder(D-MPNN 核心部分)
|
||||
encoder = model.encoder.encoder[0]
|
||||
# 冻结参数
|
||||
for param in encoder.parameters():
|
||||
param.requires_grad = False
|
||||
self._encoders.append(encoder)
|
||||
|
||||
@ -158,7 +158,7 @@ def pretrain(
|
||||
model = model.to(device)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
optimizer, mode="min", factor=0.5, patience=5, verbose=True
|
||||
optimizer, mode="min", factor=0.5, patience=5
|
||||
)
|
||||
early_stopping = EarlyStopping(patience=patience)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user