diff --git a/demo/hf_based_demo.py b/demo/hf_based_demo.py
index 1cd9289..38c5457 100644
--- a/demo/hf_based_demo.py
+++ b/demo/hf_based_demo.py
@@ -14,12 +14,22 @@ from transformers import (
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--model_path", type=str, default="")
+parser.add_argument("--torch_dtype", type=str, default="bfloat16")
 args = parser.parse_args()
 
+# init model torch dtype
+torch_dtype = args.torch_dtype
+if torch_dtype == "" or torch_dtype == "bfloat16":
+    torch_dtype = torch.bfloat16
+elif torch_dtype == "float32":
+    torch_dtype = torch.float32
+else:
+    raise ValueError(f"Invalid torch dtype: {torch_dtype}")
+
 # init model and tokenizer
 path = args.model_path
 tokenizer = AutoTokenizer.from_pretrained(path)
-model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch_dtype, device_map="auto", trust_remote_code=True)
 
 
 def hf_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):
diff --git a/demo/vllm_based_demo.py b/demo/vllm_based_demo.py
index 3789380..802c79c 100644
--- a/demo/vllm_based_demo.py
+++ b/demo/vllm_based_demo.py
@@ -9,11 +9,22 @@ from vllm import LLM, SamplingParams
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--model_path", type=str, default="")
+parser.add_argument("--torch_dtype", type=str, default="bfloat16")
 args = parser.parse_args()
 
+
+# init model torch dtype
+torch_dtype = args.torch_dtype
+if torch_dtype == "" or torch_dtype == "bfloat16":
+    torch_dtype = "bfloat16"
+elif torch_dtype == "float32":
+    torch_dtype = "float32"
+else:
+    raise ValueError(f"Invalid torch dtype: {torch_dtype}")
+
 # init model and tokenizer
 path = args.model_path
-llm = LLM(model=path, tensor_parallel_size=1, dtype="bfloat16")
+llm = LLM(model=path, tensor_parallel_size=1, dtype=torch_dtype)
 
 
 def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):
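Both hunks add the same if/elif chain; a dict-based lookup is a more compact alternative. The
sketch below is hypothetical (resolve_torch_dtype and DTYPE_MAP are illustrative names, not part
of the patch) and mirrors the diff's behavior of treating an empty string as bfloat16:

    import torch

    # Hypothetical helper: table-driven dtype selection; "" falls back to
    # bfloat16, matching the if/elif chain added in the diff above.
    DTYPE_MAP = {"": torch.bfloat16, "bfloat16": torch.bfloat16, "float32": torch.float32}

    def resolve_torch_dtype(name: str) -> torch.dtype:
        if name not in DTYPE_MAP:
            raise ValueError(f"Invalid torch dtype: {name}")
        return DTYPE_MAP[name]

    # e.g. in hf_based_demo.py:
    # model = AutoModelForCausalLM.from_pretrained(
    #     path, torch_dtype=resolve_torch_dtype(args.torch_dtype),
    #     device_map="auto", trust_remote_code=True)

The vLLM demo passes the dtype name through as a string rather than a torch.dtype, so there only
the membership check against the allowed names would be needed.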