diff --git a/demo/hf_based_demo.py b/demo/hf_based_demo.py index 75df156..4eaa184 100644 --- a/demo/hf_based_demo.py +++ b/demo/hf_based_demo.py @@ -34,7 +34,7 @@ else: # init model and tokenizer path = args.model_path tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) -model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch_dtype, device_map="auto", trust_remote_code=True) +model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch_dtype, device_map="cuda:0", trust_remote_code=True) model_architectures = model.config.architectures[0]