diff --git a/demo/hf_based_demo.py b/demo/hf_based_demo.py index 88bda71..12ff697 100644 --- a/demo/hf_based_demo.py +++ b/demo/hf_based_demo.py @@ -14,7 +14,7 @@ from transformers import ( parser = argparse.ArgumentParser() parser.add_argument("--model_path", type=str, default="") -parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"])) +parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"]) parser.add_argument("--server_name", type=str, default="127.0.0.1") parser.add_argument("--server_port", type=int, default=7860) args = parser.parse_args() diff --git a/demo/vllm_based_demo.py b/demo/vllm_based_demo.py index 0180405..c5ce8c0 100644 --- a/demo/vllm_based_demo.py +++ b/demo/vllm_based_demo.py @@ -9,7 +9,7 @@ from vllm import LLM, SamplingParams parser = argparse.ArgumentParser() parser.add_argument("--model_path", type=str, default="") -parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"])) +parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"]) parser.add_argument("--server_name", type=str, default="127.0.0.1") parser.add_argument("--server_port", type=int, default=7860) args = parser.parse_args()