加载模型时显式weight_only=False

This commit is contained in:
RYDE-WORK 2026-02-28 16:49:34 +08:00
parent c8c33c1c75
commit 985f3a1bb0
4 changed files with 4 additions and 4 deletions

View File

@ -453,7 +453,7 @@ def main(
if init_from_pretrain is not None: if init_from_pretrain is not None:
if init_from_pretrain.exists(): if init_from_pretrain.exists():
logger.info(f"Loading pretrain weights from {init_from_pretrain}") logger.info(f"Loading pretrain weights from {init_from_pretrain}")
checkpoint = torch.load(init_from_pretrain, map_location="cpu") checkpoint = torch.load(init_from_pretrain, map_location="cpu", weights_only=False)
pretrain_state_dict = checkpoint["model_state_dict"] pretrain_state_dict = checkpoint["model_state_dict"]
pretrain_config = checkpoint.get("config", {}) pretrain_config = checkpoint.get("config", {})
logger.success(f"Loaded pretrain checkpoint (d_model={pretrain_config.get('d_model')})") logger.success(f"Loaded pretrain checkpoint (d_model={pretrain_config.get('d_model')})")

View File

@ -744,7 +744,7 @@ def main(
if init_from_pretrain is not None: if init_from_pretrain is not None:
if init_from_pretrain.exists(): if init_from_pretrain.exists():
logger.info(f"Loading pretrain weights from {init_from_pretrain}") logger.info(f"Loading pretrain weights from {init_from_pretrain}")
checkpoint = torch.load(init_from_pretrain, map_location="cpu") checkpoint = torch.load(init_from_pretrain, map_location="cpu", weights_only=False)
pretrain_state_dict = checkpoint["model_state_dict"] pretrain_state_dict = checkpoint["model_state_dict"]
pretrain_config = checkpoint.get("config", {}) pretrain_config = checkpoint.get("config", {})
logger.success(f"Loaded pretrain checkpoint (d_model={pretrain_config.get('d_model')})") logger.success(f"Loaded pretrain checkpoint (d_model={pretrain_config.get('d_model')})")

View File

@ -38,7 +38,7 @@ def load_model(
自动根据 checkpoint config.use_mpnn 选择模型类型 自动根据 checkpoint config.use_mpnn 选择模型类型
""" """
checkpoint = torch.load(model_path, map_location=device) checkpoint = torch.load(model_path, map_location=device, weights_only=False)
config = checkpoint["config"] config = checkpoint["config"]
use_mpnn = config.get("use_mpnn", False) use_mpnn = config.get("use_mpnn", False)

View File

@ -392,7 +392,7 @@ def test(
# 加载模型 # 加载模型
logger.info(f"Loading pretrain model from {model_path}") logger.info(f"Loading pretrain model from {model_path}")
checkpoint = torch.load(model_path, map_location=device_obj) checkpoint = torch.load(model_path, map_location=device_obj, weights_only=False)
config = checkpoint["config"] config = checkpoint["config"]
# 解析 MPNN 配置 # 解析 MPNN 配置