From ce391fb0991ae272098b2ab6216e98dc7a06cf71 Mon Sep 17 00:00:00 2001
From: winter <2453101190@qq.com>
Date: Wed, 27 Mar 2024 16:30:09 +0800
Subject: [PATCH 1/2] add MiniCPMV in hf_demo

---
 demo/hf_based_demo.py | 102 ++++++++++++++++++++++++++++++++++++------
 1 file changed, 89 insertions(+), 13 deletions(-)

diff --git a/demo/hf_based_demo.py b/demo/hf_based_demo.py
index 6677d9b..1441340 100644
--- a/demo/hf_based_demo.py
+++ b/demo/hf_based_demo.py
@@ -1,13 +1,13 @@
-from typing import Dict
 from typing import List
-from typing import Tuple
 
 import argparse
 import gradio as gr
 import torch
 from threading import Thread
+
+from PIL import Image
 from transformers import (
-    AutoModelForCausalLM,
+    AutoModelForCausalLM,
     AutoTokenizer,
     TextIteratorStreamer
 )
@@ -23,7 +23,7 @@ args = parser.parse_args()
 
 # init model torch dtype
 torch_dtype = args.torch_dtype
-if torch_dtype =="" or torch_dtype == "bfloat16":
+if torch_dtype == "" or torch_dtype == "bfloat16":
     torch_dtype = torch.bfloat16
 elif torch_dtype == "float32":
     torch_dtype = torch.float32
@@ -32,12 +32,36 @@ else:
 
 # init model and tokenizer
 path = args.model_path
-tokenizer = AutoTokenizer.from_pretrained(path)
+tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
 model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch_dtype, device_map="auto", trust_remote_code=True)
+model_architectures = model.config.architectures[0]
+
+
+def check_model_v(img_file_path: str = None):
+    '''
+    Check whether the loaded model is MiniCPMV.
+    Args:
+        img_file_path (str): Image filepath
+
+    Returns:
+        True if the model is MiniCPMV, else False
+    '''
+    if model_architectures == "MiniCPMV":
+        return True
+    if isinstance(img_file_path, str):
+        gr.Warning('Only the MiniCPMV model supports image input')
+    return False
+
+
+if check_model_v():
+    model = model.to(dtype=torch.bfloat16)
+
+
 
 # init gradio demo host and port
-server_name=args.server_name
-server_port=args.server_port
+server_name = args.server_name
+server_port = args.server_port
+
 
 def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
     """generate model output with huggingface api
@@ -73,7 +97,40 @@ def hf_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: f
         yield answer[4 + len(inputs):]
 
 
-def generate(chat_history: List, query: str, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
+def hf_v_gen(dialog: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int,
+             img_file_path: str):
+    """generate model output with the MiniCPMV chat api
+
+    Args:
+        dialog (List): chat history used as the actual model input.
+        top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+        temperature (float): strictly positive float value used to modulate the logits distribution.
+        max_dec_len (int): The maximum numbers of tokens to generate.
+        img_file_path (str): Image filepath.
+
+    Returns:
+        str: the complete generation result of the MiniCPMV model
+    """
+    assert isinstance(img_file_path, str), 'Image must not be empty'
+    img = Image.open(img_file_path).convert('RGB')
+
+    generation_kwargs = dict(
+        image=img,
+        msgs=dialog,
+        context=None,
+        tokenizer=tokenizer,
+        sampling=True,
+        temperature=temperature,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        max_new_tokens=max_dec_len
+    )
+    res, context, _ = model.chat(**generation_kwargs)
+    return res
+
+
+def generate(chat_history: List, query: str, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int,
+             img_file_path: str = None):
     """generate after hitting "submit" button
 
     Args:
@@ -82,10 +139,11 @@ def generate(chat_history: List, query: str, top_p: float, temperature: float, r
         top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
         temperature (float): strictly positive float value used to modulate the logits distribution.
         max_dec_len (int): The maximum numbers of tokens to generate.
+        img_file_path (str): Image filepath.
 
     Yields:
         List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n], [q_n+1, a_n+1]]. chat_history + QA of current round.
-    """
+    """ 
     assert query != "", "Input must not be empty!!!"
     # apply chat template
     model_input = []
@@ -95,12 +153,18 @@ def generate(chat_history: List, query: str, top_p: float, temperature: float, r
     model_input.append({"role": "user", "content": query})
     # yield model generation
     chat_history.append([query, ""])
+    if check_model_v():
+        chat_history[-1][1] = hf_v_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len, img_file_path)
+        yield gr.update(value=""), chat_history
+        return
+
     for answer in hf_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len):
         chat_history[-1][1] = answer.strip("")
         yield gr.update(value=""), chat_history
 
 
-def regenerate(chat_history: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int):
+def regenerate(chat_history: List, top_p: float, temperature: float, repetition_penalty: float, max_dec_len: int,
+               img_file_path: str = None):
     """re-generate the answer of last round's query
 
     Args:
@@ -108,10 +172,11 @@ def regenerate(chat_history: List, top_p: float, temperature: float, repetition_
         top_p (float): only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
        temperature (float): strictly positive float value used to modulate the logits distribution.
         max_dec_len (int): The maximum numbers of tokens to generate.
+        img_file_path (str): Image filepath.
 
     Yields:
         List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. chat_history
-    """
+    """ 
     assert len(chat_history) >= 1, "History is empty. Nothing to regenerate!!"
     # apply chat template
     model_input = []
@@ -120,6 +185,11 @@ def regenerate(chat_history: List, top_p: float, temperature: float, repetition_
         model_input.append({"role": "assistant", "content": a})
     model_input.append({"role": "user", "content": chat_history[-1][0]})
     # yield model generation
+    if check_model_v():
+        chat_history[-1][1] = hf_v_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len, img_file_path)
+        yield gr.update(value=""), chat_history
+        return
+
     for answer in hf_gen(model_input, top_p, temperature, repetition_penalty, max_dec_len):
         chat_history[-1][1] = answer.strip("")
         yield gr.update(value=""), chat_history
@@ -157,6 +227,8 @@ with gr.Blocks(theme="soft") as demo:
             temperature = gr.Slider(0.1, 2.0, value=0.5, step=0.1, label="temperature")
             repetition_penalty = gr.Slider(0.1, 2.0, value=1.1, step=0.1, label="repetition_penalty")
             max_dec_len = gr.Slider(1, 1024, value=1024, step=1, label="max_dec_len")
+            img_file_path = gr.Image(label="upload image", type='filepath', show_label=False)
+
         with gr.Column(scale=5):
             chatbot = gr.Chatbot(bubble_full_width=False, height=400)
             user_input = gr.Textbox(label="User", placeholder="Input your query here!", lines=8)
@@ -166,8 +238,12 @@ with gr.Blocks(theme="soft") as demo:
             regen = gr.Button("Regenerate")
             reverse = gr.Button("Reverse")
 
-    submit.click(generate, inputs=[chatbot, user_input, top_p, temperature, repetition_penalty, max_dec_len], outputs=[user_input, chatbot])
-    regen.click(regenerate, inputs=[chatbot, top_p, temperature, repetition_penalty, max_dec_len], outputs=[user_input, chatbot])
+    img_file_path.change(check_model_v, inputs=[img_file_path], outputs=[])
+
+    submit.click(generate, inputs=[chatbot, user_input, top_p, temperature, repetition_penalty,
+                                   max_dec_len, img_file_path], outputs=[user_input, chatbot])
+    regen.click(regenerate, inputs=[chatbot, top_p, temperature, repetition_penalty,
+                                    max_dec_len, img_file_path], outputs=[user_input, chatbot])
     clear.click(clear_history, inputs=[], outputs=[chatbot])
     reverse.click(reverse_last_round, inputs=[chatbot], outputs=[chatbot])
 

From a77a34759fc974b19feba5f2e0af9400a3830f20 Mon Sep 17 00:00:00 2001
From: winter <2453101190@qq.com>
Date: Thu, 4 Jul 2024 19:54:18 +0800
Subject: [PATCH 2/2] use default gpu 0

---
 demo/hf_based_demo.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/demo/hf_based_demo.py b/demo/hf_based_demo.py
index 75df156..4eaa184 100644
--- a/demo/hf_based_demo.py
+++ b/demo/hf_based_demo.py
@@ -34,7 +34,7 @@ else:
 # init model and tokenizer
 path = args.model_path
 tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
-model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch_dtype, device_map="auto", trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch_dtype, device_map="cuda:0", trust_remote_code=True)
 model_architectures = model.config.architectures[0]
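
For reference, below is a minimal standalone sketch (not part of the patches) of the MiniCPMV code path these patches wire into demo/hf_based_demo.py. It uses the same loading call as the patched demo and the same model.chat keyword arguments that hf_v_gen builds. The checkpoint name, image path, and prompt are illustrative placeholders, and the sampling values mirror the demo's slider defaults where those are visible in the diff.

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint: any model whose config reports architectures[0] == "MiniCPMV".
path = "openbmb/MiniCPM-V"
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    path, torch_dtype=torch.bfloat16, device_map="cuda:0", trust_remote_code=True
)

# Placeholder image path; the demo takes this from the gr.Image(type='filepath') component.
img = Image.open("example.jpg").convert("RGB")
msgs = [{"role": "user", "content": "Describe this image."}]

# Same keyword arguments that hf_v_gen assembles before calling model.chat.
res, context, _ = model.chat(
    image=img,
    msgs=msgs,
    context=None,
    tokenizer=tokenizer,
    sampling=True,
    temperature=0.5,         # demo slider default
    top_p=0.8,               # illustrative; the top_p slider default is outside this diff
    repetition_penalty=1.1,  # demo slider default
    max_new_tokens=1024,     # demo slider default
)
print(res)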