diff --git a/demo/hf_based_demo.py b/demo/hf_based_demo.py
index 1cd9289..ebdd86a 100644
--- a/demo/hf_based_demo.py
+++ b/demo/hf_based_demo.py
@@ -14,6 +14,9 @@ from transformers import (
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--model_path", type=str, default="")
+parser.add_argument("--server_name", type=str, default="127.0.0.1")
+parser.add_argument("--server_port", type=int, default=7860)
+
 args = parser.parse_args()
 
 # init model and tokenizer
@@ -21,6 +24,9 @@ path = args.model_path
 tokenizer = AutoTokenizer.from_pretrained(path)
 model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
+# init gradio demo host and port
+server_name = args.server_name
+server_port = args.server_port
 
 
 def hf_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):
     """generate model output with huggingface api
@@ -151,4 +157,4 @@ with gr.Blocks(theme="soft") as demo:
     reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot])
 
 demo.queue()
-demo.launch(server_name="127.0.0.1", show_error=True)
+demo.launch(server_name=server_name, server_port=server_port, show_error=True)
diff --git a/demo/vllm_based_demo.py b/demo/vllm_based_demo.py
index 3789380..b804f01 100644
--- a/demo/vllm_based_demo.py
+++ b/demo/vllm_based_demo.py
@@ -9,12 +9,18 @@ from vllm import LLM, SamplingParams
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--model_path", type=str, default="")
+parser.add_argument("--server_name", type=str, default="127.0.0.1")
+parser.add_argument("--server_port", type=int, default=7860)
+
 args = parser.parse_args()
 
 # init model and tokenizer
 path = args.model_path
 llm = LLM(model=path, tensor_parallel_size=1, dtype="bfloat16")
+# init gradio demo host and port
+server_name = args.server_name
+server_port = args.server_port
 
 
 def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):
     """generate model output with huggingface api
@@ -158,4 +164,4 @@ with gr.Blocks(theme="soft") as demo:
     reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot])
 
 demo.queue()
-demo.launch(server_name="127.0.0.1", show_error=True)
+demo.launch(server_name=server_name, server_port=server_port, show_error=True)
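
With these changes applied, both demos read the Gradio bind address and port from the command line instead of hard-coding 127.0.0.1 and the default port. A hypothetical invocation (the model path below is a placeholder, not part of the diff) might look like:

    # expose the HF-based demo on all interfaces, port 8080
    python demo/hf_based_demo.py --model_path /path/to/model --server_name 0.0.0.0 --server_port 8080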