mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-01-19 11:53:13 +08:00
fix device conflicts
This commit is contained in:
parent
0290649df1
commit
cbfbd1a7af
@ -132,6 +132,9 @@ class LNPModel(nn.Module):
|
||||
Returns:
|
||||
stacked: [B, n_tokens, d_model]
|
||||
"""
|
||||
# 获取目标设备(从 tabular 数据推断)
|
||||
device = tabular["comp"].device
|
||||
|
||||
# 1. Encode SMILES
|
||||
rdkit_features = self.rdkit_encoder(smiles) # {"morgan", "maccs", "desc"}
|
||||
|
||||
@ -141,14 +144,14 @@ class LNPModel(nn.Module):
|
||||
# MPNN 特征(如果启用)
|
||||
if self.use_mpnn:
|
||||
mpnn_features = self.mpnn_encoder(smiles)
|
||||
all_features["mpnn"] = mpnn_features["mpnn"]
|
||||
all_features["mpnn"] = mpnn_features["mpnn"].to(device)
|
||||
|
||||
# RDKit 特征
|
||||
all_features["morgan"] = rdkit_features["morgan"]
|
||||
all_features["maccs"] = rdkit_features["maccs"]
|
||||
all_features["desc"] = rdkit_features["desc"]
|
||||
# RDKit 特征(移到正确设备)
|
||||
all_features["morgan"] = rdkit_features["morgan"].to(device)
|
||||
all_features["maccs"] = rdkit_features["maccs"].to(device)
|
||||
all_features["desc"] = rdkit_features["desc"].to(device)
|
||||
|
||||
# Tabular 特征
|
||||
# Tabular 特征(已在正确设备上)
|
||||
all_features["comp"] = tabular["comp"]
|
||||
all_features["phys"] = tabular["phys"]
|
||||
all_features["help"] = tabular["help"]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user