fix device conflicts

This commit is contained in:
RYDE-WORK 2026-01-19 11:33:57 +08:00
parent 0290649df1
commit cbfbd1a7af

View File

@ -132,6 +132,9 @@ class LNPModel(nn.Module):
Returns: Returns:
stacked: [B, n_tokens, d_model] stacked: [B, n_tokens, d_model]
""" """
# 获取目标设备(从 tabular 数据推断)
device = tabular["comp"].device
# 1. Encode SMILES # 1. Encode SMILES
rdkit_features = self.rdkit_encoder(smiles) # {"morgan", "maccs", "desc"} rdkit_features = self.rdkit_encoder(smiles) # {"morgan", "maccs", "desc"}
@ -141,14 +144,14 @@ class LNPModel(nn.Module):
# MPNN 特征(如果启用) # MPNN 特征(如果启用)
if self.use_mpnn: if self.use_mpnn:
mpnn_features = self.mpnn_encoder(smiles) mpnn_features = self.mpnn_encoder(smiles)
all_features["mpnn"] = mpnn_features["mpnn"] all_features["mpnn"] = mpnn_features["mpnn"].to(device)
# RDKit 特征 # RDKit 特征(移到正确设备)
all_features["morgan"] = rdkit_features["morgan"] all_features["morgan"] = rdkit_features["morgan"].to(device)
all_features["maccs"] = rdkit_features["maccs"] all_features["maccs"] = rdkit_features["maccs"].to(device)
all_features["desc"] = rdkit_features["desc"] all_features["desc"] = rdkit_features["desc"].to(device)
# Tabular 特征 # Tabular 特征(已在正确设备上)
all_features["comp"] = tabular["comp"] all_features["comp"] = tabular["comp"]
all_features["phys"] = tabular["phys"] all_features["phys"] = tabular["phys"]
all_features["help"] = tabular["help"] all_features["help"] = tabular["help"]