mirror of
https://github.com/RYDE-WORK/lnp_ml.git
synced 2026-03-21 09:36:32 +08:00
加载模型时显式weight_only=False
This commit is contained in:
parent
c8c33c1c75
commit
985f3a1bb0
@ -453,7 +453,7 @@ def main(
|
||||
if init_from_pretrain is not None:
|
||||
if init_from_pretrain.exists():
|
||||
logger.info(f"Loading pretrain weights from {init_from_pretrain}")
|
||||
checkpoint = torch.load(init_from_pretrain, map_location="cpu")
|
||||
checkpoint = torch.load(init_from_pretrain, map_location="cpu", weights_only=False)
|
||||
pretrain_state_dict = checkpoint["model_state_dict"]
|
||||
pretrain_config = checkpoint.get("config", {})
|
||||
logger.success(f"Loaded pretrain checkpoint (d_model={pretrain_config.get('d_model')})")
|
||||
|
||||
@ -744,7 +744,7 @@ def main(
|
||||
if init_from_pretrain is not None:
|
||||
if init_from_pretrain.exists():
|
||||
logger.info(f"Loading pretrain weights from {init_from_pretrain}")
|
||||
checkpoint = torch.load(init_from_pretrain, map_location="cpu")
|
||||
checkpoint = torch.load(init_from_pretrain, map_location="cpu", weights_only=False)
|
||||
pretrain_state_dict = checkpoint["model_state_dict"]
|
||||
pretrain_config = checkpoint.get("config", {})
|
||||
logger.success(f"Loaded pretrain checkpoint (d_model={pretrain_config.get('d_model')})")
|
||||
|
||||
@ -38,7 +38,7 @@ def load_model(
|
||||
|
||||
自动根据 checkpoint 的 config.use_mpnn 选择模型类型。
|
||||
"""
|
||||
checkpoint = torch.load(model_path, map_location=device)
|
||||
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
|
||||
config = checkpoint["config"]
|
||||
use_mpnn = config.get("use_mpnn", False)
|
||||
|
||||
|
||||
@ -392,7 +392,7 @@ def test(
|
||||
|
||||
# 加载模型
|
||||
logger.info(f"Loading pretrain model from {model_path}")
|
||||
checkpoint = torch.load(model_path, map_location=device_obj)
|
||||
checkpoint = torch.load(model_path, map_location=device_obj, weights_only=False)
|
||||
config = checkpoint["config"]
|
||||
|
||||
# 解析 MPNN 配置
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user