diff --git a/lnp_ml/modeling/final_train_optuna_cv.py b/lnp_ml/modeling/final_train_optuna_cv.py index 9e36152..9e2173f 100644 --- a/lnp_ml/modeling/final_train_optuna_cv.py +++ b/lnp_ml/modeling/final_train_optuna_cv.py @@ -453,7 +453,7 @@ def main( if init_from_pretrain is not None: if init_from_pretrain.exists(): logger.info(f"Loading pretrain weights from {init_from_pretrain}") - checkpoint = torch.load(init_from_pretrain, map_location="cpu") + checkpoint = torch.load(init_from_pretrain, map_location="cpu", weights_only=False) pretrain_state_dict = checkpoint["model_state_dict"] pretrain_config = checkpoint.get("config", {}) logger.success(f"Loaded pretrain checkpoint (d_model={pretrain_config.get('d_model')})") diff --git a/lnp_ml/modeling/nested_cv_optuna.py b/lnp_ml/modeling/nested_cv_optuna.py index c2c24e4..8cc26d0 100644 --- a/lnp_ml/modeling/nested_cv_optuna.py +++ b/lnp_ml/modeling/nested_cv_optuna.py @@ -744,7 +744,7 @@ def main( if init_from_pretrain is not None: if init_from_pretrain.exists(): logger.info(f"Loading pretrain weights from {init_from_pretrain}") - checkpoint = torch.load(init_from_pretrain, map_location="cpu") + checkpoint = torch.load(init_from_pretrain, map_location="cpu", weights_only=False) pretrain_state_dict = checkpoint["model_state_dict"] pretrain_config = checkpoint.get("config", {}) logger.success(f"Loaded pretrain checkpoint (d_model={pretrain_config.get('d_model')})") diff --git a/lnp_ml/modeling/predict.py b/lnp_ml/modeling/predict.py index 88facd5..6722da8 100644 --- a/lnp_ml/modeling/predict.py +++ b/lnp_ml/modeling/predict.py @@ -38,7 +38,7 @@ def load_model( 自动根据 checkpoint 的 config.use_mpnn 选择模型类型。 """ - checkpoint = torch.load(model_path, map_location=device) + checkpoint = torch.load(model_path, map_location=device, weights_only=False) config = checkpoint["config"] use_mpnn = config.get("use_mpnn", False) diff --git a/lnp_ml/modeling/pretrain.py b/lnp_ml/modeling/pretrain.py index 9a23288..abb702d 100644 --- a/lnp_ml/modeling/pretrain.py +++ b/lnp_ml/modeling/pretrain.py @@ -392,7 +392,7 @@ def test( # 加载模型 logger.info(f"Loading pretrain model from {model_path}") - checkpoint = torch.load(model_path, map_location=device_obj) + checkpoint = torch.load(model_path, map_location=device_obj, weights_only=False) config = checkpoint["config"] # 解析 MPNN 配置