diff --git a/demo/hf_based_demo.py b/demo/hf_based_demo.py
index 3b04cf4..4eaa184 100644
--- a/demo/hf_based_demo.py
+++ b/demo/hf_based_demo.py
@@ -3,6 +3,7 @@ import argparse
 import gradio as gr
 import torch
 from threading import Thread
+from PIL import Image
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -32,8 +33,31 @@ else:
 
 # init model and tokenizer
 path = args.model_path
-tokenizer = AutoTokenizer.from_pretrained(path)
-model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch_dtype, device_map="auto", trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch_dtype, device_map="cuda:0", trust_remote_code=True)
+
+model_architectures = model.config.architectures[0]
+
+
+def check_model_v(img_file_path: str = None):
+    '''
+    check whether the loaded model is MiniCPM-V
+    Args:
+        img_file_path (str): Image filepath
+
+    Returns:
+        True if the model is MiniCPM-V, else False
+    '''
+    if "MiniCPMV" in model_architectures:
+        return True
+    if isinstance(img_file_path, str):
+        gr.Warning('Only the MiniCPM-V model supports image input')
+    return False
+
+
+if check_model_v():
+    model = model.to(dtype=torch.bfloat16)
+
 
 # init gradio demo host and port
 server_name = args.server_name
@@ -73,8 +97,40 @@ def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: f
         yield answer[4 + len(inputs):]
 
 
-def generate(chat_history: List, query: str, top_p: float, temperature: float, repetition_penalty: float,
-             max_dec_len: int):
+def hf_v_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int,
+             img_file_path: str):
+    """generate multimodal model output with the MiniCPM-V chat api
+
+    Args:
+        dialog (List): chat history, the actual model input.
+        top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+        temperature (float): strictly positive float value used to modulate the logits distribution.
+        max_dec_len (int): The maximum number of tokens to generate.
+        img_file_path (str): Image filepath.
+
+    Returns:
+        str: full generation result of the model (not streamed)
+    """
+    assert isinstance(img_file_path, str), 'Image must not be empty'
+    img = Image.open(img_file_path).convert('RGB')
+
+    generation_kwargs = dict(
+        image=img,
+        msgs=dialog,
+        context=None,
+        tokenizer=tokenizer,
+        sampling=True,
+        temperature=temperature,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        max_new_tokens=max_dec_len
+    )
+    res, context, _ = model.chat(**generation_kwargs)
+    return res
+
+
+def generate(chat_history: List, query: str, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int,
+             img_file_path: str = None):
     """generate after hitting "submit" button
 
     Args:
@@ -83,6 +139,7 @@ def generate(chat_history: List, query: str, top_p: float, temperature: float, r
         top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
         temperature (float): strictly positive float value used to modulate the logits distribution.
         max_dec_len (int): The maximum numbers of tokens to generate.
+        img_file_path (str): Image filepath.
 
     Yields:
         List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n], [q_n+1, a_n+1]]. chat_history + QA of current round.
@@ -96,12 +153,18 @@ def generate(chat_history: List, query: str, top_p: float, temperature: float, r
     model_input.append({"role": "user", "content": query})
     # yield model generation
     chat_history.append([query, ""])
+    if check_model_v():
+        chat_history[-1][1] = hf_v_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len, img_file_path)
+        yield gr.update(value=""), chat_history
+        return
+
     for answer in hf_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len):
         chat_history[-1][1] = answer.strip("</s>")
         yield gr.update(value=""), chat_history
 
 
-def regenerate(chat_history: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
+def regenerate(chat_history: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int,
+               img_file_path: str = None):
     """re-generate the answer of last round's query
 
     Args:
@@ -109,6 +172,7 @@ def regenerate(chat_history: List, top_p: float, temperature: float, repetition_
         top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
         temperature (float): strictly positive float value used to modulate the logits distribution.
         max_dec_len (int): The maximum numbers of tokens to generate.
+        img_file_path (str): Image filepath.
 
     Yields:
         List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. chat_history
@@ -121,6 +185,11 @@ def regenerate(chat_history: List, top_p: float, temperature: float, repetition_
         model_input.append({"role": "assistant", "content": a})
     model_input.append({"role": "user", "content": chat_history[-1][0]})
     # yield model generation
+    if check_model_v():
+        chat_history[-1][1] = hf_v_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len, img_file_path)
+        yield gr.update(value=""), chat_history
+        return
+
     for answer in hf_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len):
         chat_history[-1][1] = answer.strip("</s>")
         yield gr.update(value=""), chat_history
@@ -158,6 +227,8 @@ with gr.Blocks(theme="soft") as demo:
             temperature = gr.Slider(0.1, 2.0, value=0.5, step=0.1, label="temperature")
             repetition_penalty = gr.Slider(0.1, 2.0, value=1.1, step=0.1, label="repetition_penalty")
             max_dec_len = gr.Slider(1, 1024, value=1024, step=1, label="max_dec_len")
+            img_file_path = gr.Image(label="upload image", type='filepath', show_label=False)
+
         with gr.Column(scale=5):
             chatbot = gr.Chatbot(bubble_full_width=False, height=400)
             user_input = gr.Textbox(label="User", placeholder="Input your query here!", lines=8)
@@ -167,10 +238,12 @@ with gr.Blocks(theme="soft") as demo:
             regen = gr.Button("Regenerate")
             reverse = gr.Button("Reverse")
 
-    submit.click(generate, inputs=[chatbot, user_input, top_p, temperature, repetition_penalty, max_dec_len],
-                 outputs=[user_input, chatbot])
-    regen.click(regenerate, inputs=[chatbot, top_p, temperature, repetition_penalty, max_dec_len],
-                outputs=[user_input, chatbot])
+    img_file_path.change(check_model_v, inputs=[img_file_path], outputs=[])
+
+    submit.click(generate, inputs=[chatbot, user_input, top_p, temperature, repetition_penalty,
+                                   max_dec_len, img_file_path], outputs=[user_input, chatbot])
+    regen.click(regenerate, inputs=[chatbot, top_p, temperature, repetition_penalty,
+                                    max_dec_len, img_file_path], outputs=[user_input, chatbot])
     clear.click(clear_history, inputs=[], outputs=[chatbot])
     reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot])
 
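Reviewer note: to sanity-check the new MiniCPM-V image path without launching the Gradio UI, the snippet below exercises the same `model.chat(...)` call that `hf_v_gen` wraps. It is a minimal sketch under stated assumptions: the `openbmb/MiniCPM-V` checkpoint name and `test.jpg` image path are placeholders (not part of this patch), and a CUDA device is assumed, matching the `device_map="cuda:0"` used above.

```python
# Minimal standalone sketch of the image path added in this patch.
# Assumptions: "openbmb/MiniCPM-V" and "test.jpg" are placeholders; any local
# MiniCPM-V checkpoint and RGB image should work the same way.
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "openbmb/MiniCPM-V"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    path, torch_dtype=torch.bfloat16, device_map="cuda:0", trust_remote_code=True
)

img = Image.open("test.jpg").convert("RGB")  # placeholder image
msgs = [{"role": "user", "content": "Describe this image."}]

# Same call shape as hf_v_gen: MiniCPM-V's custom chat() returns the full
# response at once (plus a context object) rather than streaming tokens.
res, context, _ = model.chat(
    image=img,
    msgs=msgs,
    context=None,
    tokenizer=tokenizer,
    sampling=True,
    temperature=0.5,
    top_p=0.8,
    repetition_penalty=1.1,
    max_new_tokens=1024,
)
print(res)
```

Because `model.chat` returns the whole answer rather than streaming it, the MiniCPM-V branches in `generate` and `regenerate` yield a single final update and `return` early instead of iterating over `hf_gen`.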