Update the way vLLM is invoked

zR 2024-04-14 14:23:11 +08:00
parent e58d99f8ca
commit 8272667430
2 changed files with 34 additions and 27 deletions

View File

@@ -1,37 +1,47 @@
 from typing import Dict
 from typing import List
 from typing import Tuple
 import argparse
 import gradio as gr
 from vllm import LLM, SamplingParams
+import torch
+from transformers import AutoTokenizer
 parser = argparse.ArgumentParser()
-parser.add_argument("--model_path", type=str, default="")
+parser.add_argument("--model_path", type=str, default="openbmb/MiniCPM-1B-sft-bf16")
 parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "bfloat16"])
 parser.add_argument("--server_name", type=str, default="127.0.0.1")
 parser.add_argument("--server_port", type=int, default=7860)
-args = parser.parse_args()
+parser.add_argument("--max_tokens", type=int, default=2048)
+# for MiniCPM-1B and MiniCPM-2B model, max_tokens should be set to 2048
+args = parser.parse_args()
 # init model torch dtype
 torch_dtype = args.torch_dtype
 if torch_dtype == "" or torch_dtype == "bfloat16":
-    torch_dtype = "bfloat16"
+    torch_dtype = torch.bfloat16
 elif torch_dtype == "float32":
-    torch_dtype = "float32"
+    torch_dtype = torch.float32
+elif torch_dtype == "float16":
+    torch_dtype = torch.float16
 else:
     raise ValueError(f"Invalid torch dtype: {torch_dtype}")
 # init model and tokenizer
 path = args.model_path
-llm = LLM(model=path, tensor_parallel_size=1, dtype=torch_dtype)
+llm = LLM(
+    model=path,
+    tensor_parallel_size=1,
+    dtype=torch_dtype,
+    trust_remote_code=True,
+    gpu_memory_utilization=0.9,
+    max_model_len=args.max_tokens
+)
+tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
+# init gradio demo host and port
 server_name = args.server_name
 server_port = args.server_port
 def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):
     """generate model output with huggingface api
@@ -44,14 +54,8 @@ def vllm_gen(dialog: List, top_p: float, temperature: float, max_dec_len: int):
     Yields:
         str: real-time generation results of hf model
     """
-    prompt = ""
     assert len(dialog) % 2 == 1
-    for info in dialog:
-        if info["role"] == "user":
-            prompt += "<用户>" + info["content"]
-        else:
-            prompt += "<AI>" + info["content"]
-    prompt += "<AI>"
+    prompt = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)
     params_dict = {
         "n": 1,
         "best_of": 1,
@@ -158,7 +162,7 @@ with gr.Blocks(theme="soft") as demo:
     with gr.Column(scale=1):
         top_p = gr.Slider(0, 1, value=0.8, step=0.1, label="top_p")
         temperature = gr.Slider(0.1, 2.0, value=0.5, step=0.1, label="temperature")
-        max_dec_len = gr.Slider(1, 1024, value=1024, step=1, label="max_dec_len")
+        max_dec_len = gr.Slider(1, args.max_tokens, value=args.max_tokens, step=1, label="max_tokens")
     with gr.Column(scale=5):
         chatbot = gr.Chatbot(bubble_full_width=False, height=400)
         user_input = gr.Textbox(label="User", placeholder="Input your query here!", lines=8)
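
The net effect of this file's changes: the prompt is now rendered by the model's own chat template via tokenizer.apply_chat_template instead of hand-concatenated "<用户>"/"<AI>" tags, and the context length is capped by the new --max_tokens flag through max_model_len. Below is a minimal sketch of the updated flow for a single offline generation, not the demo itself; the model name is the diff's new default, while the dialog content and sampling values are illustrative.

# A minimal sketch (assumptions noted in comments): one offline generation
# using the same vLLM and tokenizer calls that the diff introduces.
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

model_path = "openbmb/MiniCPM-1B-sft-bf16"  # default from the diff above
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
llm = LLM(model=model_path, dtype="bfloat16", trust_remote_code=True, max_model_len=2048)

# Same call as in the diff: the tokenizer's chat template replaces the old
# hand-built "<用户>"/"<AI>" prompt string. Dialog content is illustrative.
dialog = [{"role": "user", "content": "你好"}]
prompt = tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=False)

sampling = SamplingParams(n=1, top_p=0.8, temperature=0.5, max_tokens=1024)
outputs = llm.generate([prompt], sampling)
print(outputs[0].outputs[0].text)

Note that add_generation_prompt=False mirrors the diff; whether the template itself appends the assistant prefix depends on the chat template shipped in the model repo.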

View File

@@ -3,6 +3,9 @@
 torch>=2.0.0
 transformers>=4.36.2
 gradio>=4.26.0
+# for vllm inference
+# vllm>=0.4.0.post1
 # for openai api inference
 openai>=1.17.1
 tiktoken>=0.6.0
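
Since the vllm pin is left commented out here, presumably so it can be installed separately to match the local CUDA/PyTorch build, a quick runtime check against the commented pin can catch a mismatched environment. A small sketch, assuming the packaging package is available:

# Sketch: verify a separately-installed vLLM satisfies the commented pin above.
from importlib.metadata import version
from packaging.version import Version

installed = Version(version("vllm"))
assert installed >= Version("0.4.0.post1"), f"vLLM {installed} < 0.4.0.post1"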