From 6c4bbad9edbc043e12567a95b346d8edfe0fc64a Mon Sep 17 00:00:00 2001 From: Su Yang Date: Fri, 2 Feb 2024 14:32:09 +0800 Subject: [PATCH 1/2] feat: allow user change the demo host and port --- demo/hf_based_demo.py | 8 +++++++- demo/vllm_based_demo.py | 8 +++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/demo/hf_based_demo.py b/demo/hf_based_demo.py index 1cd9289..ebdd86a 100644 --- a/demo/hf_based_demo.py +++ b/demo/hf_based_demo.py @@ -14,6 +14,9 @@ from transformers import ( parser = argparse.ArgumentParser() parser.add_argument("--model_path", type=str, default="") +parser.add_argument("--server_name", type=str, default="127.0.0.1") +parser.add_argument("--server_port", type=int, default=7860) + args = parser.parse_args() # init model and tokenizer @@ -21,6 +24,9 @@ path = args.model_path tokenizer = AutoTokenizer.from_pretrained(path) model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True) +# init gradio demo host and port +server_name=args.server_name +server_port=args.server_port def hf_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int): """generate model output with huggingface api @@ -151,4 +157,4 @@ with gr.Blocks(theme="soft") as demo: reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot]) demo.queue() -demo.launch(server_name="127.0.0.1", show_error=True) +demo.launch(server_name=server_name, server_port=server_port, show_error=True) diff --git a/demo/vllm_based_demo.py b/demo/vllm_based_demo.py index 3789380..b804f01 100644 --- a/demo/vllm_based_demo.py +++ b/demo/vllm_based_demo.py @@ -9,12 +9,18 @@ from vllm import LLM, SamplingParams parser = argparse.ArgumentParser() parser.add_argument("--model_path", type=str, default="") +parser.add_argument("--server_name", type=str, default="127.0.0.1") +parser.add_argument("--server_port", type=int, default=7860) + args = parser.parse_args() # init model and tokenizer path = 
args.model_path llm = LLM(model=path, tensor_parallel_size=1, dtype="bfloat16") +# init gradio demo host and port +server_name=args.server_name +server_port=args.server_port def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int): """generate model output with huggingface api @@ -158,4 +164,4 @@ with gr.Blocks(theme="soft") as demo: reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot]) demo.queue() -demo.launch(server_name="127.0.0.1", show_error=True) +demo.launch(server_name=server_name, server_port=server_port, show_error=True) From 2bf9e80bf8ffdd10160a878f34abf32eec1dac30 Mon Sep 17 00:00:00 2001 From: DingDing Date: Fri, 2 Feb 2024 16:54:13 +0800 Subject: [PATCH 2/2] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 928c564..d914982 100644 --- a/README.md +++ b/README.md @@ -109,7 +109,7 @@ python inference.py --model_path --prompt_path prompts/promp ``` #### Huggingface 模型 -(注:我们发现当前Huggingface的推理代码推理效果差于Vllm的推理代码,我们正在对齐中,请耐心等待) +(注:我们发现当前Huggingface的推理代码推理效果差于Vllm的推理代码,我们正在对齐中,目前已定位到PageAttention和普通attention的区别,请耐心等待) ##### MiniCPM-2B * 安装`transformers>=4.36.0`以及`accelerate`后,运行以下代码 ```python