Update multimodal speech and vision content

1. Update the local-model speech and vision multimodal features and set up the corresponding tools
This commit is contained in:
zR 2023-12-10 13:50:02 +08:00 committed by liunux4odoo
parent bc225bf9f5
commit 03891cc27a
13 changed files with 194 additions and 19 deletions

View File

@@ -51,9 +51,8 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
if "tool_call" in text:
action_end = text.find("```")
action = text[:action_end].strip()
params_str_start = text.find("(") + 1
params_str_end = text.find(")")
params_str_end = text.rfind(")")
params_str = text[params_str_start:params_str_end]
params_pairs = [param.split("=") for param in params_str.split(",") if "=" in param]
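A hypothetical model output shows why rfind is used here: when a parameter value itself contains a closing parenthesis, find(")") truncates the parameter string, while rfind(")") captures it in full.

text = 'vqa_processor\ntool_call(query="What is the person (on the left) holding?")\n```'
# text.find(")") stops at the ")" after "left", cutting the value short;
# text.rfind(")") matches the closing parenthesis of the call, so params_str keeps the whole argument.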

server/agent/container.py Normal file
View File

@@ -0,0 +1,40 @@
from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer
from configs import TOOL_CONFIG
import torch


class ModelContainer:
    def __init__(self):
        self.model = None
        self.metadata = None
        self.metadata_response = None
        self.vision_model = None
        self.vision_tokenizer = None
        self.audio_tokenizer = None
        self.audio_model = None
        if TOOL_CONFIG["vqa_processor"]["use"]:
            self.vision_tokenizer = LlamaTokenizer.from_pretrained(
                TOOL_CONFIG["vqa_processor"]["tokenizer_path"],
                trust_remote_code=True)
            self.vision_model = AutoModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path=TOOL_CONFIG["vqa_processor"]["model_path"],
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True
            ).to(TOOL_CONFIG["vqa_processor"]["device"]).eval()
        if TOOL_CONFIG["aqa_processor"]["use"]:
            self.audio_tokenizer = AutoTokenizer.from_pretrained(
                TOOL_CONFIG["aqa_processor"]["tokenizer_path"],
                trust_remote_code=True)
            self.audio_model = AutoModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path=TOOL_CONFIG["aqa_processor"]["model_path"],
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True).to(TOOL_CONFIG["aqa_processor"]["device"]).eval()

container = ModelContainer()
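ModelContainer assumes that TOOL_CONFIG (imported from configs) carries "vqa_processor" and "aqa_processor" entries with "use", "model_path", "tokenizer_path" and "device" keys. A minimal sketch of just those two entries, with placeholder paths rather than values taken from this commit:

TOOL_CONFIG = {
    "vqa_processor": {
        "use": True,
        "model_path": "path/to/vision-chat-model",     # placeholder checkpoint path
        "tokenizer_path": "path/to/llama-tokenizer",   # placeholder, loaded via LlamaTokenizer
        "device": "cuda:0",
    },
    "aqa_processor": {
        "use": True,
        "model_path": "path/to/audio-chat-model",      # placeholder checkpoint path
        "tokenizer_path": "path/to/audio-chat-model",  # placeholder tokenizer path
        "device": "cuda:0",
    },
}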

View File

@@ -6,3 +6,6 @@ from .search_internet import search_internet, SearchInternetInput
from .wolfram import wolfram, WolframInput
from .search_youtube import search_youtube, YoutubeInput
from .arxiv import arxiv, ArxivInput
from .vision_factory import *
from .audio_factory import *

View File

@@ -0,0 +1 @@
from .aqa import aqa_processor, AQAInput

View File

@@ -0,0 +1,31 @@
import base64
import os

from pydantic import BaseModel, Field


def save_base64_audio(base64_audio, file_path):
    audio_data = base64.b64decode(base64_audio)
    with open(file_path, 'wb') as audio_file:
        audio_file.write(audio_data)


def aqa_run(model, tokenizer, query):
    query = tokenizer.from_list_format([query])
    response, history = model.chat(tokenizer, query=query, history=None)
    print(response)
    return response


def aqa_processor(query: str):
    from server.agent.container import container
    if container.metadata["audios"]:
        file_path = "temp_audio.mp3"
        save_base64_audio(container.metadata["audios"][0], file_path)
        query_input = {
            "audio": file_path,
            "text": query,
        }
        return aqa_run(tokenizer=container.audio_tokenizer, query=query_input, model=container.audio_model)
    else:
        return "No Audio, Please Try Again"


class AQAInput(BaseModel):
    query: str = Field(description="The question about the audio in English")

View File

@@ -55,5 +55,16 @@ all_tools = [
description=template_knowledge,
args_schema=SearchKnowledgeInput,
),
StructuredTool.from_function(
func=vqa_processor,
name="vqa_processor",
description="use this tool to get answer for image question",
args_schema=VQAInput,
),
StructuredTool.from_function(
func=aqa_processor,
name="aqa_processor",
description="use this tool to get answer for audio question",
args_schema=AQAInput,
)
]
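These tools only become available to the agent when their names appear in the tool_config dict sent to the chat endpoint, since server/chat/chat.py filters with tool.name in tool_config. A minimal sketch of that request field; the per-tool values are placeholders, as only key membership is checked at that point:

tool_config = {
    "vqa_processor": {"use": True},  # placeholder value
    "aqa_processor": {"use": True},  # placeholder value
}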

View File

@@ -0,0 +1 @@
from .vqa import vqa_processor,VQAInput

View File

@@ -0,0 +1,64 @@
"""
Method Use cogagent to generate response for a given image and query.
"""
import base64
from io import BytesIO
import torch
from PIL import Image
from pydantic import BaseModel, Field
from configs import TOOL_CONFIG
def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", max_length=2048, top_p=0.9,
temperature=1.0):
"""
Args:
image_path (str): path to the image
query (str): query
model (torch.nn.Module): model
history (list): history
image (torch.Tensor): image
max_length (int): max length
top_p (float): top p
temperature (float): temperature
top_k (int): top k
"""
image = Image.open(BytesIO(base64.b64decode(image_base_64)))
inputs = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])
inputs = {
'input_ids': inputs['input_ids'].unsqueeze(0).to(device),
'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to(device),
'attention_mask': inputs['attention_mask'].unsqueeze(0).to(device),
'images': [[inputs['images'][0].to(device).to(torch.bfloat16)]],
'cross_images': [[inputs['cross_images'][0].to(device).to(torch.bfloat16)]] if inputs[
'cross_images'] else None,
}
gen_kwargs = {"max_length": max_length,
"temperature": temperature,
"top_p": top_p,
"do_sample": False}
with torch.no_grad():
outputs = model.generate(**inputs, **gen_kwargs)
outputs = outputs[:, inputs['input_ids'].shape[1]:]
response = tokenizer.decode(outputs[0])
response = response.split("</s>")[0]
return response
def vqa_processor(query: str):
from server.agent.container import container
tool_config = TOOL_CONFIG["vqa_processor"]
# model, tokenizer = load_model(model_path=tool_config["model_path"],
# tokenizer_path=tool_config["tokenizer_path"],
# device=tool_config["device"])
if container.metadata["images"]:
image_base64 = container.metadata["images"][0]
return vqa_run(model=container.vision_model, tokenizer=container.vision_tokenizer, query=query, image_base_64=image_base64,
device=tool_config["device"])
else:
return "No Image, Please Try Again"
class VQAInput(BaseModel):
query: str = Field(description="The question of the image in English")
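A minimal sketch of exercising the new tool outside the agent loop, assuming the vision model was loaded by server/agent/container.py; the image file name is a placeholder, and in normal operation create_models_chains fills container.metadata from the request's metadata field:

import base64
from server.agent.container import container
from server.agent.tools.vision_factory import vqa_processor  # import path assumed from this commit's layout

with open("example.jpg", "rb") as f:  # placeholder image file
    image_b64 = base64.b64encode(f.read()).decode()
container.metadata = {"images": [image_b64], "videos": [], "audios": []}
print(vqa_processor("What is shown in this image?"))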

View File

@@ -19,7 +19,7 @@ from server.utils import get_prompt_template
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from server.db.repository import add_message_to_db
from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler
from server.agent.container import container
def create_models_from_config(configs, callbacks):
if configs is None:
@@ -46,6 +46,7 @@ def create_models_from_config(configs, callbacks):
def create_models_chains(history, history_len, prompts, models, tools, callbacks, conversation_id, metadata):
memory = None
chat_prompt = None
container.metadata = metadata
if history:
history = [History.from_data(h) for h in history]
input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
@@ -64,6 +65,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
| models["preprocess_model"]
| StrOutputParser()
)
if "chatglm3" in models["action_model"].model_name.lower():
agent_executor = initialize_glm3_agent(
llm=models["action_model"],
@@ -91,7 +93,8 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
)
full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
else:
full_chain = ({"input": lambda x: x["input"]} | chain)
# full_chain = ({"input": lambda x: x["input"]} | chain)
full_chain = ({"input": lambda x: x["input"]} | agent_executor)
return full_chain
@@ -111,11 +114,13 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
tool_config: Dict = Body({}, description="工具配置"),
):
async def chat_iterator() -> AsyncIterable[str]:
message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id) if conversation_id else None
message_id = add_message_to_db(chat_type="llm_chat", query=query,
conversation_id=conversation_id) if conversation_id else None
callback = CustomAsyncIteratorCallbackHandler()
callbacks = [callback]
models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)
tools = [tool for tool in all_tools if tool.name in tool_config]
# Build the complete chain
full_chain = create_models_chains(prompts=prompts,
models=models,

View File

@@ -363,8 +363,6 @@ def get_model_worker_config(model_name: str = None) -> dict:
config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy())
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy())
print(config, "*******")
breakpoint()
if model_name in ONLINE_LLM_MODEL:
config["online_api"] = True
if provider := config.get("provider"):
@@ -381,7 +379,6 @@ def get_model_worker_config(model_name: str = None) -> dict:
if path and os.path.isdir(path):
config["model_path_exists"] = True
config["device"] = llm_device(config.get("device"))
breakpoint()
return config

View File

@@ -143,7 +143,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
args.engine_use_ray = False
args.disable_log_requests = False
# Parameters that must be added after vllm 0.2.1, but not needed here
# Parameters that must be added after vllm 0.2.2, but not needed here
args.max_model_len = None
args.revision = None
args.quantization = None

View File

@@ -1,3 +1,5 @@
import base64
import streamlit as st
from webui_pages.dialogue.utils import process_files
@@ -12,6 +14,7 @@ from configs import (TOOL_CONFIG, LLM_MODEL_CONFIG)
from server.knowledge_base.utils import LOADER_DICT
import uuid
from typing import List, Dict
from PIL import Image
chat_box = ChatBox(
assistant_avatar=os.path.join(
@@ -202,11 +205,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
if llm_model is not None:
model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]
files = st.file_uploader("上传附件",
type=[i for ls in LOADER_DICT.values() for i in ls],
accept_multiple_files=True)
files_upload = process_files(files=files) if files else None
print(len(files_upload)) if files_upload else None
# files = st.file_uploader("上传附件",accept_multiple_files=False)
# type=[i for ls in LOADER_DICT.values() for i in ls],)
uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
# print(len(files_upload["audios"])) if files_upload else None
# if dialogue_mode == "文件对话":
# with st.expander("文件对话配置", True):
@@ -248,11 +252,26 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
history = get_messages_history(
model_config["llm_model"][next(iter(model_config["llm_model"]))]["history_len"])
chat_box.user_say(prompt)
if files_upload["images"]:
# Display the first uploaded image
st.markdown(f'<img src="data:image/jpeg;base64,{files_upload["images"][0]}" width="300">',
unsafe_allow_html=True)
elif files_upload["videos"]:
# Display the first uploaded video
st.markdown(
f'<video width="320" height="240" controls><source src="data:video/mp4;base64,{files_upload["videos"][0]}" type="video/mp4"></video>',
unsafe_allow_html=True)
elif files_upload["audios"]:
# Play the first uploaded audio clip
st.markdown(
f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>',
unsafe_allow_html=True)
chat_box.ai_say("正在思考...")
text = ""
message_id = ""
element_index = 0
for d in api.chat_chat(query=prompt,
metadata=files_upload,
history=history,
@@ -287,6 +306,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
element_index = 0
text = chunk
chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
chat_box.show_feedback(**feedback_kwargs,
key=message_id,
on_submit=on_feedback,

View File

@@ -1,8 +1,8 @@
import streamlit as st
import base64
import os
from io import BytesIO
def encode_file_to_base64(file):
# Convert the file content to Base64 encoding
buffer = BytesIO()
@@ -11,8 +11,7 @@ def encode_file_to_base64(file):
def process_files(files):
result = {"videos": [], "images": []}
result = {"videos": [], "images": [], "audios": []}
for file in files:
file_extension = os.path.splitext(file.name)[1].lower()
@@ -25,5 +24,9 @@ def process_files(files):
# Image file handling
image_base64 = encode_file_to_base64(file)
result["images"].append(image_base64)
elif file_extension in ['.mp3', '.wav', '.ogg', '.flac']:
# Audio file handling
audio_base64 = encode_file_to_base64(file)
result["audios"].append(audio_base64)
return result
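For orientation, the dict returned by process_files has the shape sketched below; it is passed as metadata to api.chat_chat in the dialogue page and ends up in container.metadata, where the new tools read images[0] and audios[0] (values are illustrative):

files_upload = {
    "videos": [],                          # base64-encoded video files
    "images": ["<base64-encoded image>"],  # vqa_processor reads images[0]
    "audios": ["<base64-encoded audio>"],  # aqa_processor reads audios[0]
}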