diff --git a/README.md b/README.md index 928c564..d914982 100644 --- a/README.md +++ b/README.md @@ -109,7 +109,7 @@ python inference.py --model_path --prompt_path prompts/promp ``` #### Huggingface 模型 -(注:我们发现当前Huggingface的推理代码推理效果差于Vllm的推理代码,我们正在对齐中,请耐心等待) +(注:我们发现当前Huggingface的推理代码推理效果差于Vllm的推理代码,我们正在对齐中,目前已定为到PageAttention和普通attention的区别,请耐心等待) ##### MiniCPM-2B * 安装`transformers>=4.36.0`以及`accelerate`后,运行以下代码 ```python diff --git a/demo/hf_based_demo.py b/demo/hf_based_demo.py index 38c5457..88bda71 100644 --- a/demo/hf_based_demo.py +++ b/demo/hf_based_demo.py @@ -14,7 +14,9 @@ from transformers import ( parser = argparse.ArgumentParser() parser.add_argument("--model_path", type=str, default="") -parser.add_argument("--torch_dtype", type=str, default="bfloat16") +parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"])) +parser.add_argument("--server_name", type=str, default="127.0.0.1") +parser.add_argument("--server_port", type=int, default=7860) args = parser.parse_args() # init model torch dtype @@ -31,6 +33,9 @@ path = args.model_path tokenizer = AutoTokenizer.from_pretrained(path) model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch_dtype, device_map="auto", trust_remote_code=True) +# init gradio demo host and port +server_name=args.server_name +server_port=args.server_port def hf_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int): """generate model output with huggingface api @@ -161,4 +166,4 @@ with gr.Blocks(theme="soft") as demo: reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot]) demo.queue() -demo.launch(server_name="127.0.0.1", show_error=True) +demo.launch(server_name=server_name, server_port=server_port, show_error=True) diff --git a/demo/vllm_based_demo.py b/demo/vllm_based_demo.py index 802c79c..0180405 100644 --- a/demo/vllm_based_demo.py +++ b/demo/vllm_based_demo.py @@ -9,7 +9,9 @@ from vllm import LLM, SamplingParams parser = argparse.ArgumentParser() parser.add_argument("--model_path", type=str, default="") -parser.add_argument("--torch_dtype", type=str, default="bfloat16") +parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"])) +parser.add_argument("--server_name", type=str, default="127.0.0.1") +parser.add_argument("--server_port", type=int, default=7860) args = parser.parse_args() @@ -26,6 +28,9 @@ else: path = args.model_path llm = LLM(model=path, tensor_parallel_size=1, dtype=torch_dtype) +# init gradio demo host and port +server_name=args.server_name +server_port=args.server_port def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int): """generate model output with huggingface api @@ -169,4 +174,4 @@ with gr.Blocks(theme="soft") as demo: reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot]) demo.queue() -demo.launch(server_name="127.0.0.1", show_error=True) +demo.launch(server_name=server_name, server_port=server_port, show_error=True)