add MiniCPMV in hf_demo

This commit is contained in:
winter 2024-03-27 16:30:09 +08:00
parent a1013b1ad2
commit ce391fb099

View File

@ -1,13 +1,13 @@
from typing import Dict
from typing import List from typing import List
from typing import Tuple
import argparse import argparse
import gradio as gr import gradio as gr
import torch import torch
from threading import Thread from threading import Thread
from PIL import Image
from transformers import ( from transformers import (
AutoModelForCausalLM, AutoModelForCausalLM,
AutoTokenizer, AutoTokenizer,
TextIteratorStreamer TextIteratorStreamer
) )
@ -23,7 +23,7 @@ args = parser.parse_args()
# init model torch dtype # init model torch dtype
torch_dtype = args.torch_dtype torch_dtype = args.torch_dtype
if torch_dtype =="" or torch_dtype == "bfloat16": if torch_dtype == "" or torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16 torch_dtype = torch.bfloat16
elif torch_dtype == "float32": elif torch_dtype == "float32":
torch_dtype = torch.float32 torch_dtype = torch.float32
@ -32,12 +32,36 @@ else:
# init model and tokenizer # init model and tokenizer
path = args.model_path path = args.model_path
tokenizer = AutoTokenizer.from_pretrained(path) tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch_dtype, device_map="auto", trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch_dtype, device_map="auto", trust_remote_code=True)
model_architectures = model.config.architectures[0]
def check_model_v(img_file_path: str = None):
'''
check model is MiniCPMV
Args:
img_file_path (str): Image filepath
Returns:
Ture if model is MiniCPMV else False
'''
if model_architectures == "MiniCPMV":
return True
if isinstance(img_file_path, str):
gr.Warning('Only MiniCPMV model can support Image')
return False
if check_model_v():
model = model.to(dtype=torch.bfloat16)
# init gradio demo host and port # init gradio demo host and port
server_name=args.server_name server_name = args.server_name
server_port=args.server_port server_port = args.server_port
def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int): def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
"""generate model output with huggingface api """generate model output with huggingface api
@ -73,7 +97,40 @@ def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: f
yield answer[4 + len(inputs):] yield answer[4 + len(inputs):]
def generate(chat_history: List, query: str, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int): def hf_v_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int,
img_file_path: str):
"""generate model output with huggingface api
Args:
query (str): actual model input.
top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
temperature (float): Strictly positive float value used to modulate the logits distribution.
max_dec_len (int): The maximum numbers of tokens to generate.
img_file_path (str): Image filepath.
Yields:
str: real-time generation results of hf model
"""
assert isinstance(img_file_path, str), 'Image must not be empty'
img = Image.open(img_file_path).convert('RGB')
generation_kwargs = dict(
image=img,
msgs=dialog,
context=None,
tokenizer=tokenizer,
sampling=True,
temperature=temperature,
top_p=top_p,
repetition_penalty=repetition_penalty,
max_new_tokens=max_dec_len
)
res, context, _ = model.chat(**generation_kwargs)
return res
def generate(chat_history: List, query: str, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int,
img_file_path: str = None):
"""generate after hitting "submit" button """generate after hitting "submit" button
Args: Args:
@ -82,10 +139,11 @@ def generate(chat_history: List, query: str, top_p: float, temperature: float, r
top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
temperature (float): strictly positive float value used to modulate the logits distribution. temperature (float): strictly positive float value used to modulate the logits distribution.
max_dec_len (int): The maximum numbers of tokens to generate. max_dec_len (int): The maximum numbers of tokens to generate.
img_file_path (str): Image filepath.
Yields: Yields:
List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n], [q_n+1, a_n+1]]. chat_history + QA of current round. List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n], [q_n+1, a_n+1]]. chat_history + QA of current round.
""" """
assert query != "", "Input must not be empty!!!" assert query != "", "Input must not be empty!!!"
# apply chat template # apply chat template
model_input = [] model_input = []
@ -95,12 +153,18 @@ def generate(chat_history: List, query: str, top_p: float, temperature: float, r
model_input.append({"role": "user", "content": query}) model_input.append({"role": "user", "content": query})
# yield model generation # yield model generation
chat_history.append([query, ""]) chat_history.append([query, ""])
if check_model_v():
chat_history[-1][1] = hf_v_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len, img_file_path)
yield gr.update(value=""), chat_history
return
for answer in hf_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len): for answer in hf_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len):
chat_history[-1][1] = answer.strip("</s>") chat_history[-1][1] = answer.strip("</s>")
yield gr.update(value=""), chat_history yield gr.update(value=""), chat_history
def regenerate(chat_history: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int): def regenerate(chat_history: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int,
img_file_path: str = None):
"""re-generate the answer of last round's query """re-generate the answer of last round's query
Args: Args:
@ -108,10 +172,11 @@ def regenerate(chat_history: List, top_p: float, temperature: float, repetition_
top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
temperature (float): strictly positive float value used to modulate the logits distribution. temperature (float): strictly positive float value used to modulate the logits distribution.
max_dec_len (int): The maximum numbers of tokens to generate. max_dec_len (int): The maximum numbers of tokens to generate.
img_file_path (str): Image filepath.
Yields: Yields:
List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. chat_history List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. chat_history
""" """
assert len(chat_history) >= 1, "History is empty. Nothing to regenerate!!" assert len(chat_history) >= 1, "History is empty. Nothing to regenerate!!"
# apply chat template # apply chat template
model_input = [] model_input = []
@ -120,6 +185,11 @@ def regenerate(chat_history: List, top_p: float, temperature: float, repetition_
model_input.append({"role": "assistant", "content": a}) model_input.append({"role": "assistant", "content": a})
model_input.append({"role": "user", "content": chat_history[-1][0]}) model_input.append({"role": "user", "content": chat_history[-1][0]})
# yield model generation # yield model generation
if check_model_v():
chat_history[-1][1] = hf_v_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len, img_file_path)
yield gr.update(value=""), chat_history
return
for answer in hf_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len): for answer in hf_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len):
chat_history[-1][1] = answer.strip("</s>") chat_history[-1][1] = answer.strip("</s>")
yield gr.update(value=""), chat_history yield gr.update(value=""), chat_history
@ -157,6 +227,8 @@ with gr.Blocks(theme="soft") as demo:
temperature = gr.Slider(0.1, 2.0, value=0.5, step=0.1, label="temperature") temperature = gr.Slider(0.1, 2.0, value=0.5, step=0.1, label="temperature")
repetition_penalty = gr.Slider(0.1, 2.0, value=1.1, step=0.1, label="repetition_penalty") repetition_penalty = gr.Slider(0.1, 2.0, value=1.1, step=0.1, label="repetition_penalty")
max_dec_len = gr.Slider(1, 1024, value=1024, step=1, label="max_dec_len") max_dec_len = gr.Slider(1, 1024, value=1024, step=1, label="max_dec_len")
img_file_path = gr.Image(label="upload image", type='filepath', show_label=False)
with gr.Column(scale=5): with gr.Column(scale=5):
chatbot = gr.Chatbot(bubble_full_width=False, height=400) chatbot = gr.Chatbot(bubble_full_width=False, height=400)
user_input = gr.Textbox(label="User", placeholder="Input your query here!", lines=8) user_input = gr.Textbox(label="User", placeholder="Input your query here!", lines=8)
@ -166,8 +238,12 @@ with gr.Blocks(theme="soft") as demo:
regen = gr.Button("Regenerate") regen = gr.Button("Regenerate")
reverse = gr.Button("Reverse") reverse = gr.Button("Reverse")
submit.click(generate, inputs=[chatbot, user_input, top_p, temperature, repetition_penalty, max_dec_len], outputs=[user_input, chatbot]) img_file_path.change(check_model_v, inputs=[img_file_path], outputs=[])
regen.click(regenerate, inputs=[chatbot, top_p, temperature, repetition_penalty, max_dec_len], outputs=[user_input, chatbot])
submit.click(generate, inputs=[chatbot, user_input, top_p, temperature, repetition_penalty,
max_dec_len, img_file_path], outputs=[user_input, chatbot])
regen.click(regenerate, inputs=[chatbot, top_p, temperature, repetition_penalty,
max_dec_len, img_file_path], outputs=[user_input, chatbot])
clear.click(clear_history, inputs=[], outputs=[chatbot]) clear.click(clear_history, inputs=[], outputs=[chatbot])
reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot]) reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot])