mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-01-19 11:53:13 +08:00
fix device conflicts
This commit is contained in:
parent
0290649df1
commit
cbfbd1a7af
@ -132,6 +132,9 @@ class LNPModel(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
stacked: [B, n_tokens, d_model]
|
stacked: [B, n_tokens, d_model]
|
||||||
"""
|
"""
|
||||||
|
# 获取目标设备(从 tabular 数据推断)
|
||||||
|
device = tabular["comp"].device
|
||||||
|
|
||||||
# 1. Encode SMILES
|
# 1. Encode SMILES
|
||||||
rdkit_features = self.rdkit_encoder(smiles) # {"morgan", "maccs", "desc"}
|
rdkit_features = self.rdkit_encoder(smiles) # {"morgan", "maccs", "desc"}
|
||||||
|
|
||||||
@ -141,14 +144,14 @@ class LNPModel(nn.Module):
|
|||||||
# MPNN 特征(如果启用)
|
# MPNN 特征(如果启用)
|
||||||
if self.use_mpnn:
|
if self.use_mpnn:
|
||||||
mpnn_features = self.mpnn_encoder(smiles)
|
mpnn_features = self.mpnn_encoder(smiles)
|
||||||
all_features["mpnn"] = mpnn_features["mpnn"]
|
all_features["mpnn"] = mpnn_features["mpnn"].to(device)
|
||||||
|
|
||||||
# RDKit 特征
|
# RDKit 特征(移到正确设备)
|
||||||
all_features["morgan"] = rdkit_features["morgan"]
|
all_features["morgan"] = rdkit_features["morgan"].to(device)
|
||||||
all_features["maccs"] = rdkit_features["maccs"]
|
all_features["maccs"] = rdkit_features["maccs"].to(device)
|
||||||
all_features["desc"] = rdkit_features["desc"]
|
all_features["desc"] = rdkit_features["desc"].to(device)
|
||||||
|
|
||||||
# Tabular 特征
|
# Tabular 特征(已在正确设备上)
|
||||||
all_features["comp"] = tabular["comp"]
|
all_features["comp"] = tabular["comp"]
|
||||||
all_features["phys"] = tabular["phys"]
|
all_features["phys"] = tabular["phys"]
|
||||||
all_features["help"] = tabular["help"]
|
all_features["help"] = tabular["help"]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user