Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git (synced 2026-02-02 20:53:13 +08:00)
Update multimodal speech and vision content
1. Add local-model speech and vision multimodal features and register the corresponding tools
This commit is contained in:
parent bc225bf9f5
commit 03891cc27a
@@ -51,9 +51,8 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
         if "tool_call" in text:
             action_end = text.find("```")
             action = text[:action_end].strip()

             params_str_start = text.find("(") + 1
-            params_str_end = text.find(")")
+            params_str_end = text.rfind(")")
             params_str = text[params_str_start:params_str_end]

             params_pairs = [param.split("=") for param in params_str.split(",") if "=" in param]
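For illustration, a minimal sketch of why the closing parenthesis is now located with rfind rather than find: when a parameter value itself contains ")", find(")") cuts the parameter string short, while rfind(")") keeps it intact. The sample text below is an assumption, not taken from the repository.

# Hypothetical model output in the tool_call format parsed above.
text = 'tool_call(query=What is shown in the (left) panel?)\n```'

params_str_start = text.find("(") + 1

# Old behaviour: stops at the first ")" inside the value.
truncated = text[params_str_start:text.find(")")]
# -> 'query=What is shown in the (left'

# New behaviour: uses the last ")" so the full value survives.
full = text[params_str_start:text.rfind(")")]
# -> 'query=What is shown in the (left) panel?'

params_pairs = [param.split("=") for param in full.split(",") if "=" in param]
print(params_pairs)  # [['query', 'What is shown in the (left) panel?']]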
server/agent/container.py (new file, 40 lines)
@@ -0,0 +1,40 @@
from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer
from configs import TOOL_CONFIG
import torch


class ModelContainer:
    def __init__(self):
        self.model = None
        self.metadata = None
        self.metadata_response = None

        self.vision_model = None
        self.vision_tokenizer = None
        self.audio_tokenizer = None
        self.audio_model = None

        if TOOL_CONFIG["vqa_processor"]["use"]:
            self.vision_tokenizer = LlamaTokenizer.from_pretrained(
                TOOL_CONFIG["vqa_processor"]["tokenizer_path"],
                trust_remote_code=True)
            self.vision_model = AutoModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path=TOOL_CONFIG["vqa_processor"]["model_path"],
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True
            ).to(TOOL_CONFIG["vqa_processor"]["device"]).eval()

        if TOOL_CONFIG["aqa_processor"]["use"]:
            self.audio_tokenizer = AutoTokenizer.from_pretrained(
                TOOL_CONFIG["aqa_processor"]["tokenizer_path"],
                trust_remote_code=True)
            self.audio_model = AutoModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path=TOOL_CONFIG["aqa_processor"]["model_path"],
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True).to(TOOL_CONFIG["aqa_processor"]["device"]).eval()


container = ModelContainer()
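For context, a minimal sketch of the TOOL_CONFIG entries this container reads. Only the keys referenced above (use, tokenizer_path, model_path, device) come from the code; the example model names, paths, and device strings are placeholders, not values from the repository.

# Hypothetical configs entries; adjust paths and devices to your setup.
TOOL_CONFIG = {
    "vqa_processor": {
        "use": True,                                  # load the vision model at startup
        "model_path": "/models/cogagent-chat-hf",     # placeholder path
        "tokenizer_path": "/models/vicuna-7b-v1.5",   # placeholder path
        "device": "cuda:0",
    },
    "aqa_processor": {
        "use": True,                                  # load the audio model at startup
        "model_path": "/models/audio-chat-model",     # placeholder path
        "tokenizer_path": "/models/audio-chat-model", # placeholder path
        "device": "cuda:0",
    },
}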
@@ -6,3 +6,6 @@ from .search_internet import search_internet, SearchInternetInput
 from .wolfram import wolfram, WolframInput
 from .search_youtube import search_youtube, YoutubeInput
 from .arxiv import arxiv, ArxivInput
+
+from .vision_factory import *
+from .audio_factory import *
server/agent/tools_factory/audio_factory/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .aqa import aqa_processor, AQAInput
server/agent/tools_factory/audio_factory/aqa.py (new file, 31 lines)
@@ -0,0 +1,31 @@
import base64
import os
from pydantic import BaseModel, Field


def save_base64_audio(base64_audio, file_path):
    audio_data = base64.b64decode(base64_audio)
    with open(file_path, 'wb') as audio_file:
        audio_file.write(audio_data)


def aqa_run(model, tokenizer, query):
    query = tokenizer.from_list_format([query])
    response, history = model.chat(tokenizer, query=query, history=None)
    print(response)
    return response


def aqa_processor(query: str):
    from server.agent.container import container
    if container.metadata["audios"]:
        file_path = "temp_audio.mp3"
        save_base64_audio(container.metadata["audios"][0], file_path)
        query_input = {
            "audio": file_path,
            "text": query,
        }
        return aqa_run(tokenizer=container.audio_tokenizer, query=query_input, model=container.audio_model)
    else:
        return "No Audio, Please Try Again"


class AQAInput(BaseModel):
    query: str = Field(description="The question about the audio, in English")
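A hedged sketch of how this tool is exercised end to end: the chat endpoint stores the request's metadata on the shared container (see the chat.py change further down), and aqa_processor then reads the first base64 audio clip from it. The local file name below is an illustrative assumption.

# Illustrative only: drive aqa_processor directly, the way the agent would at runtime.
import base64

from server.agent.container import container
from server.agent.tools_factory.audio_factory import aqa_processor

with open("question.wav", "rb") as f:            # assumed local file
    audio_b64 = base64.b64encode(f.read()).decode()

container.metadata = {"images": [], "videos": [], "audios": [audio_b64]}
print(aqa_processor("What is being said in this recording?"))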
@@ -55,5 +55,16 @@ all_tools = [
         description=template_knowledge,
         args_schema=SearchKnowledgeInput,
     ),
+    StructuredTool.from_function(
+        func=vqa_processor,
+        name="vqa_processor",
+        description="use this tool to get answer for image question",
+        args_schema=VQAInput,
+    ),
+    StructuredTool.from_function(
+        func=aqa_processor,
+        name="aqa_processor",
+        description="use this tool to get answer for audio question",
+        args_schema=AQAInput,
+    )
 ]
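Since the chat endpoint later filters all_tools by name against the tool_config body parameter (tools = [tool for tool in all_tools if tool.name in tool_config]), enabling the two new tools is a matter of listing their names in that dict. A minimal sketch of such a payload follows; the nested per-tool settings are placeholders.

# Hypothetical tool_config payload for the chat endpoint; a tool is enabled simply
# by having its name appear as a key (the nested values here are placeholders).
tool_config = {
    "vqa_processor": {},
    "aqa_processor": {},
}

# Selection as done in the chat endpoint:
# tools = [tool for tool in all_tools if tool.name in tool_config]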
server/agent/tools_factory/vision_factory/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .vqa import vqa_processor, VQAInput
server/agent/tools_factory/vision_factory/vqa.py (new file, 64 lines)
@@ -0,0 +1,64 @@
"""
Use CogAgent to generate a response for a given image and query.
"""
import base64
from io import BytesIO
import torch
from PIL import Image
from pydantic import BaseModel, Field
from configs import TOOL_CONFIG


def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", max_length=2048, top_p=0.9,
            temperature=1.0):
    """
    Args:
        model (torch.nn.Module): vision-language model
        tokenizer: tokenizer matching the model
        image_base_64 (str): base64-encoded image
        query (str): question about the image
        history (list): conversation history
        device (str): device to run inference on
        max_length (int): max generation length
        top_p (float): top p
        temperature (float): temperature
    """
    image = Image.open(BytesIO(base64.b64decode(image_base_64)))

    inputs = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])
    inputs = {
        'input_ids': inputs['input_ids'].unsqueeze(0).to(device),
        'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to(device),
        'attention_mask': inputs['attention_mask'].unsqueeze(0).to(device),
        'images': [[inputs['images'][0].to(device).to(torch.bfloat16)]],
        'cross_images': [[inputs['cross_images'][0].to(device).to(torch.bfloat16)]] if inputs[
            'cross_images'] else None,
    }
    gen_kwargs = {"max_length": max_length,
                  "temperature": temperature,
                  "top_p": top_p,
                  "do_sample": False}
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        response = tokenizer.decode(outputs[0])
        response = response.split("</s>")[0]

    return response


def vqa_processor(query: str):
    from server.agent.container import container
    tool_config = TOOL_CONFIG["vqa_processor"]
    # model, tokenizer = load_model(model_path=tool_config["model_path"],
    #                               tokenizer_path=tool_config["tokenizer_path"],
    #                               device=tool_config["device"])
    if container.metadata["images"]:
        image_base64 = container.metadata["images"][0]
        return vqa_run(model=container.vision_model, tokenizer=container.vision_tokenizer, query=query,
                       image_base_64=image_base64, device=tool_config["device"])
    else:
        return "No Image, Please Try Again"


class VQAInput(BaseModel):
    query: str = Field(description="The question about the image, in English")
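For orientation, a minimal check of the image_base_64 contract vqa_run expects: a bare base64 string (no "data:image/...;base64," prefix) that decodes to bytes PIL can open, which matches how the web UI stores uploaded images. The local file name below is an assumption.

# Minimal check of the base64 contract vqa_run expects.
import base64
from io import BytesIO

from PIL import Image

with open("example.jpg", "rb") as f:                      # assumed local image
    image_base_64 = base64.b64encode(f.read()).decode()

image = Image.open(BytesIO(base64.b64decode(image_base_64)))
print(image.size)                                         # e.g. (1024, 768)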
@@ -19,7 +19,7 @@ from server.utils import get_prompt_template
 from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
 from server.db.repository import add_message_to_db
 from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler
+from server.agent.container import container


 def create_models_from_config(configs, callbacks):
     if configs is None:
@@ -46,6 +46,7 @@ def create_models_from_config(configs, callbacks):
 def create_models_chains(history, history_len, prompts, models, tools, callbacks, conversation_id, metadata):
     memory = None
     chat_prompt = None
+    container.metadata = metadata
     if history:
         history = [History.from_data(h) for h in history]
         input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
@@ -64,6 +65,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks, conversation_id, metadata):
             | models["preprocess_model"]
             | StrOutputParser()
         )

     if "chatglm3" in models["action_model"].model_name.lower():
         agent_executor = initialize_glm3_agent(
             llm=models["action_model"],
@@ -91,7 +93,8 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks, conversation_id, metadata):
             )
         full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
     else:
-        full_chain = ({"input": lambda x: x["input"]} | chain)
+        # full_chain = ({"input": lambda x: x["input"]} | chain)
+        full_chain = ({"input": lambda x: x["input"]} | agent_executor)
     return full_chain
@@ -111,11 +114,13 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                tool_config: Dict = Body({}, description="工具配置"),
                ):
     async def chat_iterator() -> AsyncIterable[str]:
-        message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id) if conversation_id else None
+        message_id = add_message_to_db(chat_type="llm_chat", query=query,
+                                       conversation_id=conversation_id) if conversation_id else None
         callback = CustomAsyncIteratorCallbackHandler()
         callbacks = [callback]
         models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)
         tools = [tool for tool in all_tools if tool.name in tool_config]

         # Build the complete chain
         full_chain = create_models_chains(prompts=prompts,
                                           models=models,
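To summarize the plumbing these chat.py changes add: the request's metadata dict (base64 images and audio clips from the web UI) is handed to create_models_chains, which stashes it on the process-wide container so vqa_processor and aqa_processor can read it when the agent chooses to call them. A schematic sketch; the payload value is a placeholder.

# Schematic of the metadata handoff added above (values are placeholders).
metadata = {
    "images": ["<base64 jpeg>"],   # filled by process_files() in the web UI
    "videos": [],
    "audios": [],
}

# Inside chat() -> create_models_chains():
#     container.metadata = metadata
# Later, when the agent picks a tool:
#     vqa_processor(query) reads container.metadata["images"][0]
#     aqa_processor(query) reads container.metadata["audios"][0]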
@@ -363,8 +363,6 @@ def get_model_worker_config(model_name: str = None) -> dict:
     config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
     config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy())
     config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy())
-    print(config, "*******")
-    breakpoint()
     if model_name in ONLINE_LLM_MODEL:
         config["online_api"] = True
         if provider := config.get("provider"):
@@ -381,7 +379,6 @@ def get_model_worker_config(model_name: str = None) -> dict:
     if path and os.path.isdir(path):
         config["model_path_exists"] = True
     config["device"] = llm_device(config.get("device"))
-    breakpoint()
     return config
@@ -143,7 +143,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
     args.engine_use_ray = False
     args.disable_log_requests = False

-    # Parameters required after vllm 0.2.1, but not needed here
+    # Parameters required after vllm 0.2.2, but not needed here
     args.max_model_len = None
     args.revision = None
     args.quantization = None
@@ -1,3 +1,5 @@
+import base64
+
 import streamlit as st

 from webui_pages.dialogue.utils import process_files
@@ -12,6 +14,7 @@ from configs import (TOOL_CONFIG, LLM_MODEL_CONFIG)
 from server.knowledge_base.utils import LOADER_DICT
 import uuid
 from typing import List, Dict
+from PIL import Image

 chat_box = ChatBox(
     assistant_avatar=os.path.join(
@@ -202,11 +205,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         if llm_model is not None:
             model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]

-        files = st.file_uploader("上传附件",
-                                 type=[i for ls in LOADER_DICT.values() for i in ls],
-                                 accept_multiple_files=True)
-        files_upload = process_files(files=files) if files else None
-        print(len(files_upload)) if files_upload else None
+        # files = st.file_uploader("上传附件",accept_multiple_files=False)
+        #                          type=[i for ls in LOADER_DICT.values() for i in ls],)
+        uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
+        files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
+
+        # print(len(files_upload["audios"])) if files_upload else None

         # if dialogue_mode == "文件对话":
         #     with st.expander("文件对话配置", True):
@@ -248,11 +252,26 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             history = get_messages_history(
                 model_config["llm_model"][next(iter(model_config["llm_model"]))]["history_len"])
         chat_box.user_say(prompt)
+        if files_upload["images"]:
+            # Display the first uploaded image
+            st.markdown(f'<img src="data:image/jpeg;base64,{files_upload["images"][0]}" width="300">',
+                        unsafe_allow_html=True)
+        elif files_upload["videos"]:
+            # Display the first uploaded video
+            st.markdown(
+                f'<video width="320" height="240" controls><source src="data:video/mp4;base64,{files_upload["videos"][0]}" type="video/mp4"></video>',
+                unsafe_allow_html=True)
+        elif files_upload["audios"]:
+            # Play the first uploaded audio clip
+            st.markdown(
+                f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>',
+                unsafe_allow_html=True)
+
         chat_box.ai_say("正在思考...")
         text = ""
         message_id = ""
         element_index = 0

         for d in api.chat_chat(query=prompt,
+                               metadata=files_upload,
                                history=history,
@@ -287,6 +306,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                     element_index = 0
                     text = chunk
                     chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)

         chat_box.show_feedback(**feedback_kwargs,
                                key=message_id,
                                on_submit=on_feedback,
@@ -1,8 +1,8 @@
 import streamlit as st
 import base64
 import os
 from io import BytesIO


 def encode_file_to_base64(file):
     # Convert the file contents to Base64
     buffer = BytesIO()
@@ -11,8 +11,7 @@ def encode_file_to_base64(file):


 def process_files(files):
-    result = {"videos": [], "images": []}
+    result = {"videos": [], "images": [], "audios": []}
     for file in files:
         file_extension = os.path.splitext(file.name)[1].lower()
@@ -25,5 +24,9 @@ def process_files(files):
             # Image file handling
             image_base64 = encode_file_to_base64(file)
             result["images"].append(image_base64)
+        elif file_extension in ['.mp3', '.wav', '.ogg', '.flac']:
+            # Audio file handling
+            audio_base64 = encode_file_to_base64(file)
+            result["audios"].append(audio_base64)

     return result
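A hedged sketch of the output contract of process_files after this change: a single dict with videos, images, and audios lists, each holding bare base64 strings, which dialogue.py then forwards unchanged as the chat request's metadata. The placeholder payloads below are assumptions.

# Illustrative only: expected shape of the dict returned by process_files().
example_result = {
    "videos": [],                  # base64-encoded .mp4 files
    "images": ["<base64 jpeg>"],   # bare base64 strings, no "data:" prefix
    "audios": ["<base64 mp3>"],    # .mp3/.wav/.ogg/.flac files
}

# dialogue.py then sends this dict as the chat request metadata:
# api.chat_chat(query=prompt, metadata=files_upload, history=history, ...)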