mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 09:36:32 +08:00
移除旧版函数参数,以及抑制chemprop/torch影响终端输出查看的过多warning
This commit is contained in:
parent
7c69d47238
commit
70b7a4c62a
@ -1,5 +1,7 @@
|
|||||||
"""SMILES 分子特征提取器"""
|
"""SMILES 分子特征提取器"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List, Optional, Dict
|
from typing import List, Optional, Dict
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -15,6 +17,9 @@ import torch
|
|||||||
from chemprop.utils import load_checkpoint
|
from chemprop.utils import load_checkpoint
|
||||||
from chemprop.features import mol2graph
|
from chemprop.features import mol2graph
|
||||||
|
|
||||||
|
_chemprop_logger = logging.getLogger("chemprop.load_checkpoint")
|
||||||
|
_chemprop_logger.setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RDKitFeaturizer:
|
class RDKitFeaturizer:
|
||||||
@ -97,12 +102,12 @@ class MPNNFeaturizer:
|
|||||||
device = torch.device(self.device)
|
device = torch.device(self.device)
|
||||||
paths = self.ensemble_paths or [self.checkpoint_path]
|
paths = self.ensemble_paths or [self.checkpoint_path]
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||||
for path in paths:
|
for path in paths:
|
||||||
model = load_checkpoint(path, device=device)
|
model = load_checkpoint(path, device=device, logger=_chemprop_logger)
|
||||||
model.eval()
|
model.eval()
|
||||||
# 提取 MPNEncoder(D-MPNN 核心部分)
|
|
||||||
encoder = model.encoder.encoder[0]
|
encoder = model.encoder.encoder[0]
|
||||||
# 冻结参数
|
|
||||||
for param in encoder.parameters():
|
for param in encoder.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
self._encoders.append(encoder)
|
self._encoders.append(encoder)
|
||||||
|
|||||||
@ -158,7 +158,7 @@ def pretrain(
|
|||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
|
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
|
||||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||||
optimizer, mode="min", factor=0.5, patience=5, verbose=True
|
optimizer, mode="min", factor=0.5, patience=5
|
||||||
)
|
)
|
||||||
early_stopping = EarlyStopping(patience=patience)
|
early_stopping = EarlyStopping(patience=patience)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user