From cbfbd1a7afaabc1bcd1489a196096130b9af78d3 Mon Sep 17 00:00:00 2001 From: RYDE-WORK Date: Mon, 19 Jan 2026 11:33:57 +0800 Subject: [PATCH] fix device conflicts --- lnp_ml/modeling/models.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/lnp_ml/modeling/models.py b/lnp_ml/modeling/models.py index 4cd60a7..8b64333 100644 --- a/lnp_ml/modeling/models.py +++ b/lnp_ml/modeling/models.py @@ -132,6 +132,9 @@ class LNPModel(nn.Module): Returns: stacked: [B, n_tokens, d_model] """ + # 获取目标设备(从 tabular 数据推断) + device = tabular["comp"].device + # 1. Encode SMILES rdkit_features = self.rdkit_encoder(smiles) # {"morgan", "maccs", "desc"} @@ -141,14 +144,14 @@ class LNPModel(nn.Module): # MPNN 特征(如果启用) if self.use_mpnn: mpnn_features = self.mpnn_encoder(smiles) - all_features["mpnn"] = mpnn_features["mpnn"] + all_features["mpnn"] = mpnn_features["mpnn"].to(device) - # RDKit 特征 - all_features["morgan"] = rdkit_features["morgan"] - all_features["maccs"] = rdkit_features["maccs"] - all_features["desc"] = rdkit_features["desc"] + # RDKit 特征(移到正确设备) + all_features["morgan"] = rdkit_features["morgan"].to(device) + all_features["maccs"] = rdkit_features["maccs"].to(device) + all_features["desc"] = rdkit_features["desc"].to(device) - # Tabular 特征 + # Tabular 特征(已在正确设备上) all_features["comp"] = tabular["comp"] all_features["phys"] = tabular["phys"] all_features["help"] = tabular["help"]