add MiniCPMV in hf_demo

winter 2024-03-27 16:30:09 +08:00
parent a1013b1ad2
commit ce391fb099


@@ -1,13 +1,13 @@
from typing import Dict
from typing import List
from typing import Tuple
import argparse
import gradio as gr
import torch
from threading import Thread
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer
)
@@ -23,7 +23,7 @@ args = parser.parse_args()
# init model torch dtype
torch_dtype = args.torch_dtype
if torch_dtype =="" or torch_dtype == "bfloat16":
if torch_dtype == "" or torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
@@ -32,12 +32,36 @@ else:
# init model and tokenizer
path = args.model_path
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
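# trust_remote_code lets transformers load the custom MiniCPMV modeling code shipped with the checkpoint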
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch_dtype, device_map="auto", trust_remote_code=True)
model_architectures = model.config.architectures[0]
def check_model_v(img_file_path: str = None):
    '''
    Check whether the loaded model is MiniCPMV.
    Args:
        img_file_path (str): image filepath
    Returns:
        True if the model is MiniCPMV, else False
    '''
    if model_architectures == "MiniCPMV":
        return True
    if isinstance(img_file_path, str):
        gr.Warning('Only the MiniCPMV model supports image input')
    return False
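# the vision model is always cast to bfloat16, overriding the CLI dtype choice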
if check_model_v():
    model = model.to(dtype=torch.bfloat16)
# init gradio demo host and port
server_name = args.server_name
server_port = args.server_port
def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
"""generate model output with huggingface api
@@ -73,7 +97,40 @@ def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: f
        yield answer[4 + len(inputs):]
def hf_v_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int,
             img_file_path: str):
"""generate model output with huggingface api
Args:
query (str): actual model input.
top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
temperature (float): Strictly positive float value used to modulate the logits distribution.
max_dec_len (int): The maximum numbers of tokens to generate.
img_file_path (str): Image filepath.
Yields:
str: real-time generation results of hf model
"""
assert isinstance(img_file_path, str), 'Image must not be empty'
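    # MiniCPM-V expects an RGB PIL image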
    img = Image.open(img_file_path).convert('RGB')
    generation_kwargs = dict(
        image=img,
        msgs=dialog,
        context=None,
        tokenizer=tokenizer,
        sampling=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        max_new_tokens=max_dec_len
    )
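    # chat() is provided by the checkpoint's remote code; it returns the full answer at once rather than streaming tokens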
    res, context, _ = model.chat(**generation_kwargs)
    return res
def generate(chat_history: List, query: str, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int,
             img_file_path: str = None):
"""generate after hitting "submit" button
Args:
@@ -82,10 +139,11 @@ def generate(chat_history: List, query: str, top_p: float, temperature: float, r
        top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
        temperature (float): strictly positive float value used to modulate the logits distribution.
        max_dec_len (int): the maximum number of tokens to generate.
        img_file_path (str): image filepath.
    Yields:
        List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n], [q_n+1, a_n+1]]. chat_history + QA of current round.
    """
    assert query != "", "Input must not be empty!!!"
    # apply chat template
    model_input = []
@@ -95,12 +153,18 @@ def generate(chat_history: List, query: str, top_p: float, temperature: float, r
model_input.append({"role": "user", "content": query})
# yield model generation
chat_history.append([query, ""])
    if check_model_v():
        chat_history[-1][1] = hf_v_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len, img_file_path)
        yield gr.update(value=""), chat_history
        return
    for answer in hf_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len):
        chat_history[-1][1] = answer.removesuffix("</s>")  # drop the trailing end-of-sequence token
        yield gr.update(value=""), chat_history
def regenerate(chat_history: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int,
               img_file_path: str = None):
"""re-generate the answer of last round's query
Args:
@@ -108,10 +172,11 @@ def regenerate(chat_history: List, top_p: float, temperature: float, repetition_
        top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
        temperature (float): strictly positive float value used to modulate the logits distribution.
        max_dec_len (int): the maximum number of tokens to generate.
        img_file_path (str): image filepath.
    Yields:
        List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. chat_history
    """
    assert len(chat_history) >= 1, "History is empty. Nothing to regenerate!!"
    # apply chat template
    model_input = []
@@ -120,6 +185,11 @@ def regenerate(chat_history: List, top_p: float, temperature: float, repetition_
model_input.append({"role": "assistant", "content": a})
model_input.append({"role": "user", "content": chat_history[-1][0]})
# yield model generation
    if check_model_v():
        chat_history[-1][1] = hf_v_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len, img_file_path)
        yield gr.update(value=""), chat_history
        return
    for answer in hf_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len):
        chat_history[-1][1] = answer.removesuffix("</s>")  # drop the trailing end-of-sequence token
        yield gr.update(value=""), chat_history
@ -157,6 +227,8 @@ with gr.Blocks(theme="soft") as demo:
            temperature = gr.Slider(0.1, 2.0, value=0.5, step=0.1, label="temperature")
            repetition_penalty = gr.Slider(0.1, 2.0, value=1.1, step=0.1, label="repetition_penalty")
            max_dec_len = gr.Slider(1, 1024, value=1024, step=1, label="max_dec_len")
            img_file_path = gr.Image(label="upload image", type='filepath', show_label=False)
        with gr.Column(scale=5):
            chatbot = gr.Chatbot(bubble_full_width=False, height=400)
            user_input = gr.Textbox(label="User", placeholder="Input your query here!", lines=8)
@ -166,8 +238,12 @@ with gr.Blocks(theme="soft") as demo:
regen = gr.Button("Regenerate")
reverse = gr.Button("Reverse")
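    # an image upload on a non-MiniCPMV model triggers the warning in check_model_v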
    img_file_path.change(check_model_v, inputs=[img_file_path], outputs=[])
    submit.click(generate, inputs=[chatbot, user_input, top_p, temperature, repetition_penalty,
                                   max_dec_len, img_file_path], outputs=[user_input, chatbot])
    regen.click(regenerate, inputs=[chatbot, top_p, temperature, repetition_penalty,
                                    max_dec_len, img_file_path], outputs=[user_input, chatbot])
    clear.click(clear_history, inputs=[], outputs=[chatbot])
    reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot])
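
For reference, the vision branch added in this commit boils down to the checkpoint's remote-code chat() API. Below is a minimal standalone sketch of that call; the checkpoint path and image file are placeholders to substitute with your own, and top_p is an example value (the demo's top_p slider default is not shown in this diff):

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "path/to/MiniCPM-V"  # placeholder: local or hub checkpoint directory
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16,
                                             device_map="auto", trust_remote_code=True)

img = Image.open("example.jpg").convert('RGB')  # placeholder image file
msgs = [{"role": "user", "content": "Describe this image."}]
# the same non-streaming call that hf_v_gen wraps above
res, context, _ = model.chat(image=img, msgs=msgs, context=None, tokenizer=tokenizer,
                             sampling=True, temperature=0.5, top_p=0.8,
                             repetition_penalty=1.1, max_new_tokens=1024)
print(res)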