mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-08 16:10:18 +08:00

Update multimodal speech and vision content

1. Updated the local-model speech and vision multimodal features and set up the corresponding tools.

This commit is contained in:
parent bc225bf9f5
commit 03891cc27a
@@ -51,9 +51,8 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
         if "tool_call" in text:
             action_end = text.find("```")
             action = text[:action_end].strip()

             params_str_start = text.find("(") + 1
-            params_str_end = text.find(")")
+            params_str_end = text.rfind(")")
             params_str = text[params_str_start:params_str_end]

             params_pairs = [param.split("=") for param in params_str.split(",") if "=" in param]
server/agent/container.py (new file)
@@ -0,0 +1,40 @@
from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer
from configs import TOOL_CONFIG
import torch


class ModelContainer:
    def __init__(self):
        self.model = None
        self.metadata = None

        self.metadata_response = None

        self.vision_model = None
        self.vision_tokenizer = None
        self.audio_tokenizer = None
        self.audio_model = None

        if TOOL_CONFIG["vqa_processor"]["use"]:
            self.vision_tokenizer = LlamaTokenizer.from_pretrained(
                TOOL_CONFIG["vqa_processor"]["tokenizer_path"],
                trust_remote_code=True)
            self.vision_model = AutoModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path=TOOL_CONFIG["vqa_processor"]["model_path"],
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True
            ).to(TOOL_CONFIG["vqa_processor"]["device"]).eval()

        if TOOL_CONFIG["aqa_processor"]["use"]:
            self.audio_tokenizer = AutoTokenizer.from_pretrained(
                TOOL_CONFIG["aqa_processor"]["tokenizer_path"],
                trust_remote_code=True)
            self.audio_model = AutoModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path=TOOL_CONFIG["aqa_processor"]["model_path"],
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True).to(TOOL_CONFIG["aqa_processor"]["device"]).eval()


container = ModelContainer()
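Note: ModelContainer only reads two entries from TOOL_CONFIG. A minimal sketch of the shape it assumes is below; the paths and devices are illustrative placeholders, not values taken from this commit.

# Hypothetical TOOL_CONFIG entries in configs; the keys mirror what
# ModelContainer reads, the values are placeholders.
TOOL_CONFIG = {
    "vqa_processor": {
        "use": True,
        "model_path": "/path/to/cogagent-chat-model",      # placeholder
        "tokenizer_path": "/path/to/vision-tokenizer",      # placeholder
        "device": "cuda",
    },
    "aqa_processor": {
        "use": True,
        "model_path": "/path/to/audio-chat-model",           # placeholder
        "tokenizer_path": "/path/to/audio-chat-model",        # placeholder
        "device": "cuda",
    },
}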
@@ -6,3 +6,6 @@ from .search_internet import search_internet, SearchInternetInput
 from .wolfram import wolfram, WolframInput
 from .search_youtube import search_youtube, YoutubeInput
 from .arxiv import arxiv, ArxivInput
+
+from .vision_factory import *
+from .audio_factory import *
server/agent/tools_factory/audio_factory/__init__.py (new file)
@@ -0,0 +1 @@
from .aqa import aqa_processor, AQAInput
server/agent/tools_factory/audio_factory/aqa.py (new file)
@@ -0,0 +1,31 @@
import base64
import os
from pydantic import BaseModel, Field

def save_base64_audio(base64_audio, file_path):
    audio_data = base64.b64decode(base64_audio)
    with open(file_path, 'wb') as audio_file:
        audio_file.write(audio_data)

def aqa_run(model, tokenizer, query):
    query = tokenizer.from_list_format([query])
    response, history = model.chat(tokenizer, query=query, history=None)
    print(response)
    return response


def aqa_processor(query: str):
    from server.agent.container import container
    if container.metadata["audios"]:
        file_path = "temp_audio.mp3"
        save_base64_audio(container.metadata["audios"][0], file_path)
        query_input = {
            "audio": file_path,
            "text": query,
        }
        return aqa_run(tokenizer=container.audio_tokenizer, query=query_input, model=container.audio_model)
    else:
        return "No Audio, Please Try Again"

class AQAInput(BaseModel):
    query: str = Field(description="The question of the audio in English")
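Note: the tool pulls its audio from the shared container.metadata rather than from its arguments; that metadata is set in the chat-endpoint changes further down. A minimal sketch of a direct call, assuming the audio model has been loaded by the container (the file name is hypothetical):

# Illustrative only: stash a base64-encoded clip on the shared container,
# then call the tool function directly.
import base64
from server.agent.container import container
from server.agent.tools_factory.audio_factory import aqa_processor

with open("sample.wav", "rb") as f:  # hypothetical local file
    container.metadata = {"audios": [base64.b64encode(f.read()).decode()]}

print(aqa_processor("What is being said in this recording?"))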
@@ -55,5 +55,16 @@ all_tools = [
         description=template_knowledge,
         args_schema=SearchKnowledgeInput,
     ),
+    StructuredTool.from_function(
+        func=vqa_processor,
+        name="vqa_processor",
+        description="use this tool to get answer for image question",
+        args_schema=VQAInput,
+    ),
+    StructuredTool.from_function(
+        func=aqa_processor,
+        name="aqa_processor",
+        description="use this tool to get answer for audio question",
+        args_schema=AQAInput,
+    )
 ]
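Note: as the chat-endpoint hunks below show, the per-request tool_config acts as a whitelist over all_tools (tools = [tool for tool in all_tools if tool.name in tool_config]), so enabling the new tools only requires listing their names. A hypothetical request fragment; only the names are checked in the code shown here:

# Hypothetical tool_config passed to the chat endpoint.
tool_config = {
    "vqa_processor": {},
    "aqa_processor": {},
}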
server/agent/tools_factory/vision_factory/__init__.py (new file)
@@ -0,0 +1 @@
from .vqa import vqa_processor, VQAInput
server/agent/tools_factory/vision_factory/vqa.py (new file)
@@ -0,0 +1,64 @@
"""
Use cogagent to generate a response for a given image and query.
"""
import base64
from io import BytesIO
import torch
from PIL import Image
from pydantic import BaseModel, Field
from configs import TOOL_CONFIG


def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", max_length=2048, top_p=0.9,
            temperature=1.0):
    """
    Args:
        image_path (str): path to the image
        query (str): query
        model (torch.nn.Module): model
        history (list): history
        image (torch.Tensor): image
        max_length (int): max length
        top_p (float): top p
        temperature (float): temperature
        top_k (int): top k
    """
    image = Image.open(BytesIO(base64.b64decode(image_base_64)))

    inputs = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])
    inputs = {
        'input_ids': inputs['input_ids'].unsqueeze(0).to(device),
        'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to(device),
        'attention_mask': inputs['attention_mask'].unsqueeze(0).to(device),
        'images': [[inputs['images'][0].to(device).to(torch.bfloat16)]],
        'cross_images': [[inputs['cross_images'][0].to(device).to(torch.bfloat16)]] if inputs[
            'cross_images'] else None,
    }
    gen_kwargs = {"max_length": max_length,
                  "temperature": temperature,
                  "top_p": top_p,
                  "do_sample": False}
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        response = tokenizer.decode(outputs[0])
        response = response.split("</s>")[0]

    return response


def vqa_processor(query: str):
    from server.agent.container import container
    tool_config = TOOL_CONFIG["vqa_processor"]
    # model, tokenizer = load_model(model_path=tool_config["model_path"],
    #                               tokenizer_path=tool_config["tokenizer_path"],
    #                               device=tool_config["device"])
    if container.metadata["images"]:
        image_base64 = container.metadata["images"][0]
        return vqa_run(model=container.vision_model, tokenizer=container.vision_tokenizer, query=query, image_base_64=image_base64,
                       device=tool_config["device"])
    else:
        return "No Image, Please Try Again"


class VQAInput(BaseModel):
    query: str = Field(description="The question of the image in English")
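Note: build_conversation_input_ids is the input-packing helper exposed by CogAgent/CogVLM checkpoints loaded with trust_remote_code; the rest is standard generate-and-decode. A minimal sketch of calling vqa_run directly, assuming the vision model has been loaded by the container (the file name is hypothetical):

# Illustrative only: run VQA outside the agent loop.
import base64
from server.agent.container import container
from server.agent.tools_factory.vision_factory.vqa import vqa_run

with open("example.jpg", "rb") as f:  # hypothetical local image
    image_b64 = base64.b64encode(f.read()).decode()

answer = vqa_run(model=container.vision_model,
                 tokenizer=container.vision_tokenizer,
                 image_base_64=image_b64,
                 query="What is shown in this picture?",
                 device="cuda")
print(answer)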
@@ -19,7 +19,7 @@ from server.utils import get_prompt_template
 from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
 from server.db.repository import add_message_to_db
 from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler
+from server.agent.container import container

 def create_models_from_config(configs, callbacks):
     if configs is None:
@@ -46,6 +46,7 @@ def create_models_from_config(configs, callbacks):
 def create_models_chains(history, history_len, prompts, models, tools, callbacks, conversation_id, metadata):
     memory = None
     chat_prompt = None
+    container.metadata = metadata
     if history:
         history = [History.from_data(h) for h in history]
         input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
@@ -64,6 +65,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
             | models["preprocess_model"]
             | StrOutputParser()
     )
+
     if "chatglm3" in models["action_model"].model_name.lower():
         agent_executor = initialize_glm3_agent(
             llm=models["action_model"],
@@ -91,7 +93,8 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
         )
         full_chain = ({"topic": classifier_chain, "input": lambda x: x["input"]} | branch)
     else:
-        full_chain = ({"input": lambda x: x["input"]} | chain)
+        # full_chain = ({"input": lambda x: x["input"]} | chain)
+        full_chain = ({"input": lambda x: x["input"]} | agent_executor)
     return full_chain


@@ -111,11 +114,13 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                tool_config: Dict = Body({}, description="工具配置"),
                ):
     async def chat_iterator() -> AsyncIterable[str]:
-        message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id) if conversation_id else None
+        message_id = add_message_to_db(chat_type="llm_chat", query=query,
+                                       conversation_id=conversation_id) if conversation_id else None
         callback = CustomAsyncIteratorCallbackHandler()
         callbacks = [callback]
         models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)
         tools = [tool for tool in all_tools if tool.name in tool_config]
+
         # Build the complete chain
         full_chain = create_models_chains(prompts=prompts,
                                           models=models,
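Note: the container.metadata assignment is what connects the web UI to the new tools. The dialogue page sends the uploaded files as metadata via api.chat_chat, chat() hands that dict to create_models_chains, and the VQA/AQA tools read it back off the shared container. A sketch of the dict the tools expect, matching what webui_pages.dialogue.utils.process_files returns (the values here are placeholders):

# Shape of the metadata stored on container.metadata.
metadata = {
    "images": ["<base64-encoded image>"],
    "videos": [],
    "audios": ["<base64-encoded audio>"],
}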
@@ -363,8 +363,6 @@ def get_model_worker_config(model_name: str = None) -> dict:
     config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
     config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy())
     config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy())
-    print(config, "*******")
-    breakpoint()
     if model_name in ONLINE_LLM_MODEL:
         config["online_api"] = True
         if provider := config.get("provider"):
@@ -381,7 +379,6 @@ def get_model_worker_config(model_name: str = None) -> dict:
         if path and os.path.isdir(path):
             config["model_path_exists"] = True
         config["device"] = llm_device(config.get("device"))
-    breakpoint()
     return config
@@ -143,7 +143,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
         args.engine_use_ray = False
         args.disable_log_requests = False

-        # Parameters required after vllm 0.2.1, but not needed here
+        # Parameters required after vllm 0.2.2, but not needed here
         args.max_model_len = None
         args.revision = None
         args.quantization = None
@@ -1,3 +1,5 @@
+import base64
+
 import streamlit as st

 from webui_pages.dialogue.utils import process_files
@@ -12,6 +14,7 @@ from configs import (TOOL_CONFIG, LLM_MODEL_CONFIG)
 from server.knowledge_base.utils import LOADER_DICT
 import uuid
 from typing import List, Dict
+from PIL import Image

 chat_box = ChatBox(
     assistant_avatar=os.path.join(
@@ -202,11 +205,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         if llm_model is not None:
             model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]

-        files = st.file_uploader("上传附件",
-                                 type=[i for ls in LOADER_DICT.values() for i in ls],
-                                 accept_multiple_files=True)
-        files_upload = process_files(files=files) if files else None
-        print(len(files_upload)) if files_upload else None
+        # files = st.file_uploader("上传附件",accept_multiple_files=False)
+        # type=[i for ls in LOADER_DICT.values() for i in ls],)
+        uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
+        files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
+        # print(len(files_upload["audios"])) if files_upload else None

         # if dialogue_mode == "文件对话":
         #     with st.expander("文件对话配置", True):
@@ -248,11 +252,26 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             history = get_messages_history(
                 model_config["llm_model"][next(iter(model_config["llm_model"]))]["history_len"])
             chat_box.user_say(prompt)
+            if files_upload["images"]:
+                # Display the first uploaded image
+                st.markdown(f'<img src="data:image/jpeg;base64,{files_upload["images"][0]}" width="300">',
+                            unsafe_allow_html=True)
+            elif files_upload["videos"]:
+                # Display the first uploaded video
+                st.markdown(
+                    f'<video width="320" height="240" controls><source src="data:video/mp4;base64,{files_upload["videos"][0]}" type="video/mp4"></video>',
+                    unsafe_allow_html=True)
+            elif files_upload["audios"]:
+                # Play back the first uploaded audio
+                st.markdown(
+                    f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>',
+                    unsafe_allow_html=True)
+
             chat_box.ai_say("正在思考...")
             text = ""
             message_id = ""
             element_index = 0
+
             for d in api.chat_chat(query=prompt,
                                    metadata=files_upload,
                                    history=history,
@@ -287,6 +306,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                     element_index = 0
                     text = chunk
                     chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
+
         chat_box.show_feedback(**feedback_kwargs,
                                key=message_id,
                                on_submit=on_feedback,
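Note on the preview block added above: files_upload is None when nothing has been uploaded, so the if files_upload["images"] branch assumes an upload happened. A hypothetical equivalent using Streamlit's native widgets, with that guard made explicit (not part of this commit):

# Hypothetical alternative to the raw-HTML previews above.
if files_upload:
    if files_upload["images"]:
        st.image(base64.b64decode(files_upload["images"][0]), width=300)
    elif files_upload["videos"]:
        st.video(base64.b64decode(files_upload["videos"][0]))
    elif files_upload["audios"]:
        st.audio(base64.b64decode(files_upload["audios"][0]))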
@@ -1,8 +1,8 @@
+import streamlit as st
 import base64
 import os
 from io import BytesIO


 def encode_file_to_base64(file):
     # Convert the file content to Base64
     buffer = BytesIO()
@@ -11,8 +11,7 @@ def encode_file_to_base64(file):


 def process_files(files):
-    result = {"videos": [], "images": []}
+    result = {"videos": [], "images": [], "audios": []}

     for file in files:
         file_extension = os.path.splitext(file.name)[1].lower()
@@ -25,5 +24,9 @@ def process_files(files):
             # Image file handling
             image_base64 = encode_file_to_base64(file)
             result["images"].append(image_base64)
+        elif file_extension in ['.mp3', '.wav', '.ogg', '.flac']:
+            # Audio file handling
+            audio_base64 = encode_file_to_base64(file)
+            result["audios"].append(audio_base64)

     return result