Update the vLLM code style

zR 2024-04-14 14:23:11 +08:00
parent e58d99f8ca
commit 8272667430
2 changed files with 34 additions and 27 deletions


@@ -1,36 +1,46 @@
-from typing import Dict
 from typing import List
-from typing import Tuple
 import argparse
 import gradio as gr
 from vllm import LLM, SamplingParams
+import torch
+from transformers import AutoTokenizer

 parser = argparse.ArgumentParser()
-parser.add_argument("--model_path", type=str, default="")
+parser.add_argument("--model_path", type=str, default="openbmb/MiniCPM-1B-sft-bf16")
 parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"])
 parser.add_argument("--server_name", type=str, default="127.0.0.1")
 parser.add_argument("--server_port", type=int, default=7860)
+parser.add_argument("--max_tokens", type=int, default=2048)
+# for MiniCPM-1B and MiniCPM-2B model, max_tokens should be set to 2048
 args = parser.parse_args()

 # init model torch dtype
 torch_dtype = args.torch_dtype
-if torch_dtype =="" or torch_dtype == "bfloat16":
-    torch_dtype = "bfloat16"
+if torch_dtype == "" or torch_dtype == "bfloat16":
+    torch_dtype = torch.bfloat16
 elif torch_dtype == "float32":
-    torch_dtype = "float32"
+    torch_dtype = torch.float32
+elif torch_dtype == "float16":
+    torch_dtype = torch.float16
 else:
     raise ValueError(f"Invalid torch dtype: {torch_dtype}")

 # init model and tokenizer
 path = args.model_path
-llm = LLM(model=path, tensor_parallel_size=1, dtype=torch_dtype)
+llm = LLM(
+    model=path,
+    tensor_parallel_size=1,
+    dtype=torch_dtype,
+    trust_remote_code=True,
+    gpu_memory_utilization=0.9,
+    max_model_len=args.max_tokens
+)
+tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

-# init gradio demo host and port
-server_name=args.server_name
-server_port=args.server_port
+server_name = args.server_name
+server_port = args.server_port

 def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):
     """generate model output with huggingface api
@@ -43,19 +53,13 @@ def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):
     Yields:
         str: real-time generation results of hf model
     """
-    prompt = ""
     assert len(dialog) % 2 == 1
-    for info in dialog:
-        if info["role"] == "user":
-            prompt += "<用户>" + info["content"]
-        else:
-            prompt += "<AI>" + info["content"]
-    prompt += "<AI>"
+    prompt = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)
     params_dict = {
         "n": 1,
         "best_of": 1,
         "presence_penalty": 1.0,
         "frequency_penalty": 0.0,
         "temperature": temperature,
         "top_p": top_p,
@@ -89,7 +93,7 @@ def generate(chat_history: List, query: str, top_p: float, temperature: float, m
     Yields:
         List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n], [q_n+1, a_n+1]]. chat_history + QA of current round.
     """
     assert query != "", "Input must not be empty!!!"
     # apply chat template
     model_input = []
@@ -114,7 +118,7 @@ def regenerate(chat_history: List, top_p: float, temperature: float, max_dec_len
     Yields:
         List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. chat_history
     """
     assert len(chat_history) >= 1, "History is empty. Nothing to regenerate!!"
     # apply chat template
     model_input = []
@@ -133,7 +137,7 @@ def clear_history():
     Returns:
         List: empty chat history
     """
     return []
@@ -145,7 +149,7 @@ def reverse_last_round(chat_history):
     Returns:
         List: [[q_1, a_1], [q_2, a_2], ..., [q_n-1, a_n-1]]. chat_history without last round.
     """
     assert len(chat_history) >= 1, "History is empty. Nothing to reverse!!"
     return chat_history[:-1]
@@ -158,7 +162,7 @@ with gr.Blocks(theme="soft") as demo:
         with gr.Column(scale=1):
             top_p = gr.Slider(0, 1, value=0.8, step=0.1, label="top_p")
             temperature = gr.Slider(0.1, 2.0, value=0.5, step=0.1, label="temperature")
-            max_dec_len = gr.Slider(1, 1024, value=1024, step=1, label="max_dec_len")
+            max_dec_len = gr.Slider(1, args.max_tokens, value=args.max_tokens, step=1, label="max_tokens")
         with gr.Column(scale=5):
             chatbot = gr.Chatbot(bubble_full_width=False, height=400)
             user_input = gr.Textbox(label="User", placeholder="Input your query here!", lines=8)
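The UI hunk ties the decode-length slider to --max_tokens so the interface can never ask for more tokens than the engine's max_model_len. A minimal sketch of that wiring (presumably max_dec_len is forwarded as SamplingParams' max_tokens in the part of params_dict this diff doesn't show):

import gradio as gr

MAX_TOKENS = 2048  # stand-in for args.max_tokens

with gr.Blocks(theme="soft") as demo:
    # slider range and default both follow the CLI flag instead of a fixed 1024,
    # so the UI cannot request more tokens than the engine's context length
    max_dec_len = gr.Slider(
        1, MAX_TOKENS, value=MAX_TOKENS, step=1, label="max_tokens"
    )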


@@ -3,6 +3,9 @@ torch>=2.0.0
 transformers>=4.36.2
 gradio>=4.26.0
+
+# for vllm inference
+# vllm>=0.4.0.post1
 # for openai api inference
 openai>=1.17.1
 tiktoken>=0.6.0
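The vLLM pin stays commented out, so it remains an opt-in extra. One way to check for it before launching the demo (a sketch, not part of the commit):

# minimal check that the optional vLLM dependency is present
from importlib.metadata import PackageNotFoundError, version

try:
    print("vllm", version("vllm"))
except PackageNotFoundError:
    print('vllm missing: uncomment "vllm>=0.4.0.post1" and reinstall requirements')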