diff --git a/demo/vllm_based_demo.py b/demo/vllm_based_demo.py index c5ce8c0..efefb37 100644 --- a/demo/vllm_based_demo.py +++ b/demo/vllm_based_demo.py @@ -1,36 +1,46 @@ -from typing import Dict from typing import List -from typing import Tuple - import argparse import gradio as gr from vllm import LLM, SamplingParams - +import torch +from transformers import AutoTokenizer parser = argparse.ArgumentParser() -parser.add_argument("--model_path", type=str, default="") +parser.add_argument("--model_path", type=str, default="openbmb/MiniCPM-1B-sft-bf16") parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"]) parser.add_argument("--server_name", type=str, default="127.0.0.1") parser.add_argument("--server_port", type=int, default=7860) -args = parser.parse_args() +parser.add_argument("--max_tokens", type=int, default=2048) +# for MiniCPM-1B and MiniCPM-2B model, max_tokens should be set to 2048 +args = parser.parse_args() # init model torch dtype torch_dtype = args.torch_dtype -if torch_dtype =="" or torch_dtype == "bfloat16": - torch_dtype = "bfloat16" +if torch_dtype == "" or torch_dtype == "bfloat16": + torch_dtype = torch.bfloat16 elif torch_dtype == "float32": - torch_dtype = "float32" + torch_dtype = torch.float32 +elif torch_dtype == "float16": + torch_dtype = torch.float16 else: raise ValueError(f"Invalid torch dtype: {torch_dtype}") # init model and tokenizer path = args.model_path -llm = LLM(model=path, tensor_parallel_size=1, dtype=torch_dtype) +llm = LLM( + model=path, + tensor_parallel_size=1, + dtype=torch_dtype, + trust_remote_code=True, + gpu_memory_utilization=0.9, + max_model_len=args.max_tokens +) +tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + +server_name = args.server_name +server_port = args.server_port -# init gradio demo host and port -server_name=args.server_name -server_port=args.server_port def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int): """generate model output with huggingface api @@ -43,19 +53,13 @@ def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int): Yields: str: real-time generation results of hf model - """ - prompt = "" + """ assert len(dialog) % 2 == 1 - for info in dialog: - if info["role"] == "user": - prompt += "<用户>" + info["content"] - else: - prompt += "" + info["content"] - prompt += "" + prompt = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False) params_dict = { "n": 1, "best_of": 1, - "presence_penalty": 1.0, + "presence_penalty": 1.0, "frequency_penalty": 0.0, "temperature": temperature, "top_p": top_p, @@ -89,7 +93,7 @@ def generate(chat_history: List, query: str, top_p: float, temperature: float, m Yields: List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n], [q_n+1, a_n+1]]. chat_history + QA of current round. - """ + """ assert query != "", "Input must not be empty!!!" # apply chat template model_input = [] @@ -114,7 +118,7 @@ def regenerate(chat_history: List, top_p: float, temperature: float, max_dec_len Yields: List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. chat_history - """ + """ assert len(chat_history) >= 1, "History is empty. Nothing to regenerate!!" # apply chat template model_input = [] @@ -133,7 +137,7 @@ def clear_history(): Returns: List: empty chat history - """ + """ return [] @@ -145,7 +149,7 @@ def reverse_last_round(chat_history): Returns: List: [[q_1, a_1], [q_2, a_2], ..., [q_n-1, a_n-1]]. chat_history without last round. - """ + """ assert len(chat_history) >= 1, "History is empty. Nothing to reverse!!" return chat_history[:-1] @@ -158,7 +162,7 @@ with gr.Blocks(theme="soft") as demo: with gr.Column(scale=1): top_p = gr.Slider(0, 1, value=0.8, step=0.1, label="top_p") temperature = gr.Slider(0.1, 2.0, value=0.5, step=0.1, label="temperature") - max_dec_len = gr.Slider(1, 1024, value=1024, step=1, label="max_dec_len") + max_dec_len = gr.Slider(1, args.max_tokens, value=args.max_tokens, step=1, label="max_tokens") with gr.Column(scale=5): chatbot = gr.Chatbot(bubble_full_width=False, height=400) user_input = gr.Textbox(label="User", placeholder="Input your query here!", lines=8) diff --git a/requirements.txt b/requirements.txt index ffed218..e6cdf7d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,9 @@ torch>=2.0.0 transformers>=4.36.2 gradio>=4.26.0 +# for vllm inference +# vllm>=0.4.0.post1 + # for openai api inference openai>=1.17.1 tiktoken>=0.6.0