Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git (synced 2026-01-19 21:37:20 +08:00)
Support multimodal Grounding
1. Cleaned up the chat code  2. Added support for the vision tool to output Grounding results  3. Improved the tool-calling workflow
This commit is contained in:
parent 03891cc27a
commit d44ce6ce21
@@ -5,4 +5,4 @@ from .server_config import *
from .prompt_config import *


VERSION = "v0.2.10"
VERSION = "v0.3.0-preview"
@@ -1,4 +1,6 @@
# 用于批量将configs下的.example文件复制并命名为.py文件
"""
用于批量将configs下的.example文件复制并命名为.py文件
"""
import os
import shutil
@@ -1,6 +1,6 @@
# API requirements

langchain>=0.0.348
langchain>=0.0.350
langchain-experimental>=0.0.42
pydantic==1.10.13
fschat==0.2.35
@@ -1,6 +1,6 @@
# API requirements

langchain>=0.0.346
langchain>=0.0.350
langchain-experimental>=0.0.42
pydantic==1.10.13
fschat==0.2.35
@@ -8,8 +8,6 @@ class ModelContainer:
        self.model = None
        self.metadata = None

        self.metadata_response = None

        self.vision_model = None
        self.vision_tokenizer = None
        self.audio_tokenizer = None
@@ -29,12 +27,15 @@ class ModelContainer:
        if TOOL_CONFIG["aqa_processor"]["use"]:
            self.audio_tokenizer = AutoTokenizer.from_pretrained(
                TOOL_CONFIG["aqa_processor"]["tokenizer_path"],
                trust_remote_code=True)
                trust_remote_code=True
            )
            self.audio_model = AutoModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path=TOOL_CONFIG["aqa_processor"]["model_path"],
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True).to(TOOL_CONFIG["aqa_processor"]["device"]).eval()
                trust_remote_code=True).to(
                TOOL_CONFIG["aqa_processor"]["device"]
            ).eval()


container = ModelContainer()
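For orientation, a minimal sketch of what the "aqa_processor" entry of TOOL_CONFIG might look like, inferred only from the keys the constructor above reads ("use", "tokenizer_path", "model_path", "device"); the concrete values are placeholders, not taken from the repository:

# Hypothetical TOOL_CONFIG fragment; only the key names are grounded in the code above.
TOOL_CONFIG = {
    "aqa_processor": {
        "use": True,                                      # load the audio QA model at startup
        "tokenizer_path": "/path/to/audio-qa-tokenizer",  # placeholder path
        "model_path": "/path/to/audio-qa-model",          # placeholder path
        "device": "cuda:0",                               # placeholder device string
    },
}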
@@ -4,9 +4,59 @@ Method Use cogagent to generate response for a given image and query.
import base64
from io import BytesIO
import torch
from PIL import Image
from PIL import Image, ImageDraw
from pydantic import BaseModel, Field
from configs import TOOL_CONFIG
import re
from server.agent.container import container


def extract_between_markers(text, start_marker, end_marker):
    """
    Extracts and returns the portion of the text that is between 'start_marker' and 'end_marker'.
    """
    start = text.find(start_marker)
    end = text.find(end_marker, start)

    if start != -1 and end != -1:
        # Extract and return the text between the markers, without including the markers themselves
        return text[start + len(start_marker):end].strip()
    else:
        return "Text not found between the specified markers"


def draw_box_on_existing_image(base64_image, text):
    """
    在已有的Base64编码的图片上根据“Grounded Operation”中的坐标信息绘制矩形框。
    假设坐标是经过缩放的比例坐标。
    """
    # 解码并打开Base64编码的图片
    img = Image.open(BytesIO(base64.b64decode(base64_image)))
    draw = ImageDraw.Draw(img)

    # 提取“Grounded Operation”后的坐标
    pattern = r"\[\[(\d+),(\d+),(\d+),(\d+)\]\]"
    match = re.search(pattern, text)
    if not match:
        return None

    coords = tuple(map(int, match.groups()))
    scaled_coords = (
        int(coords[0] * 0.001 * img.width),
        int(coords[1] * 0.001 * img.height),
        int(coords[2] * 0.001 * img.width),
        int(coords[3] * 0.001 * img.height)
    )
    draw.rectangle(scaled_coords, outline="red", width=3)

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img.save("tmp/image.jpg")
    img_base64 = base64.b64encode(buffered.getvalue()).decode()

    return img_base64


def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", max_length=2048, top_p=0.9,
            temperature=1.0):
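As a side note, the 0.001 factor in draw_box_on_existing_image implies the box coordinates arrive on a roughly 0–1000 (per-mille) scale and are mapped back to pixels. A small, self-contained worked example (the grounding string and image size are made up):

# Worked example of the proportional-coordinate scaling used above; all values are illustrative.
import re

text = "Grounded Operation: tap the submit button at [[125,250,875,750]]"  # hypothetical model output
width, height = 640, 480                                                    # hypothetical image size

x1, y1, x2, y2 = map(int, re.search(r"\[\[(\d+),(\d+),(\d+),(\d+)\]\]", text).groups())
box = (int(x1 * 0.001 * width), int(y1 * 0.001 * height),
       int(x2 * 0.001 * width), int(y2 * 0.001 * height))
print(box)  # -> (80, 120, 560, 360), the pixel rectangle that would be drawn in red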
@@ -33,8 +83,9 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m
        'cross_images': [[inputs['cross_images'][0].to(device).to(torch.bfloat16)]] if inputs[
            'cross_images'] else None,
    }

    gen_kwargs = {"max_length": max_length,
                  "temperature": temperature,
                  # "temperature": temperature,
                  "top_p": top_p,
                  "do_sample": False}
    with torch.no_grad():
@@ -47,15 +98,23 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m


def vqa_processor(query: str):
    from server.agent.container import container
    tool_config = TOOL_CONFIG["vqa_processor"]
    # model, tokenizer = load_model(model_path=tool_config["model_path"],
    #                               tokenizer_path=tool_config["tokenizer_path"],
    #                               device=tool_config["device"])
    if container.metadata["images"]:
        image_base64 = container.metadata["images"][0]
        return vqa_run(model=container.vision_model, tokenizer=container.vision_tokenizer, query=query, image_base_64=image_base64,
                       device=tool_config["device"])
        ans = vqa_run(model=container.vision_model,
                      tokenizer=container.vision_tokenizer,
                      query=query + "(with grounding)",
                      image_base_64=image_base64,
                      device=tool_config["device"])
        print(ans)
        image_new_base64 = draw_box_on_existing_image(container.metadata["images"][0], ans)

        # Markers
        # start_marker = "Next Action:draw_box_on_existing_image
        # end_marker = "Grounded Operation:"
        # ans = extract_between_markers(ans, start_marker, end_marker)

        return ans
    else:
        return "No Image, Please Try Again"
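To make the control flow easier to follow, here is a rough, non-authoritative sketch of how the pieces above interact when an image is attached; the import path of the tool module is assumed, the base64 payload is a placeholder, and it presupposes that container.vision_model and container.vision_tokenizer are already loaded:

# Hypothetical usage sketch, not part of the diff.
from server.agent.container import container
from server.agent.tools_factory.vqa_processor import vqa_processor  # assumed module path

container.metadata = {"images": ["<base64-encoded JPEG>"], "videos": [], "audios": []}
answer = vqa_processor("图里的提交按钮在哪里?")  # the tool appends "(with grounding)" to the query
# `answer` is the raw grounding text; if a [[x1,y1,x2,y2]] box was found,
# draw_box_on_existing_image() has also written the boxed image to tmp/image.jpg,
# which the web UI later picks up and displays.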
@@ -23,7 +23,7 @@ class ConversationCallbackHandler(BaseCallbackHandler):
    def on_llm_start(
            self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        # 如果想存更多信息,则prompts 也需要持久化
        # TODO 如果想存更多信息,则 prompts 也需要持久化,不用的提示词需要特殊支持
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
@@ -1,25 +1,28 @@
from fastapi import Body
from fastapi.responses import StreamingResponse
from langchain.agents import initialize_agent, AgentType
from langchain_core.callbacks import AsyncCallbackManager, BaseCallbackManager
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableBranch
from server.agent.agent_factory import initialize_glm3_agent
from server.agent.tools_factory.tools_registry import all_tools
from server.utils import wrap_done, get_ChatOpenAI
from langchain.chains import LLMChain
from typing import AsyncIterable, Dict
import asyncio
import json
from typing import List, Union, AsyncIterable, Dict

from fastapi import Body
from fastapi.responses import StreamingResponse

from langchain.agents import initialize_agent, AgentType
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableBranch
from langchain.chains import LLMChain
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Union
from server.chat.utils import History
from langchain.prompts import PromptTemplate
from server.utils import get_prompt_template

from server.agent.agent_factory import initialize_glm3_agent
from server.agent.tools_factory.tools_registry import all_tools
from server.agent.container import container

from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from server.chat.utils import History
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from server.db.repository import add_message_to_db
from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler
from server.agent.container import container


def create_models_from_config(configs, callbacks):
    if configs is None:
@@ -47,6 +50,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
    memory = None
    chat_prompt = None
    container.metadata = metadata

    if history:
        history = [History.from_data(h) for h in history]
        input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
@@ -102,23 +106,35 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
               metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]),
               conversation_id: str = Body("", description="对话框ID"),
               history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
               history: Union[int, List[History]] = Body([],
                                                         description="历史对话,设为一个整数可以从数据库中读取历史消息",
                                                         examples=[[
                                                             {"role": "user",
                                                              "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                                             {"role": "assistant", "content": "虎头虎脑"}]]
                                                         ),
               history: Union[int, List[History]] = Body(
                   [],
                   description="历史对话,设为一个整数可以从数据库中读取历史消息",
                   examples=[
                       [
                           {"role": "user",
                            "content": "我们来玩成语接龙,我先来,生龙活虎"},
                           {"role": "assistant", "content": "虎头虎脑"}
                       ]
                   ]
               ),
               stream: bool = Body(False, description="流式输出"),
               model_config: Dict = Body({}, description="LLM 模型配置。"),
               model_config: Dict = Body({}, description="LLM 模型配置"),
               tool_config: Dict = Body({}, description="工具配置"),
               ):
    async def chat_iterator() -> AsyncIterable[str]:
        message_id = add_message_to_db(chat_type="llm_chat", query=query,
                                       conversation_id=conversation_id) if conversation_id else None
        message_id = add_message_to_db(
            chat_type="llm_chat",
            query=query,
            conversation_id=conversation_id
        ) if conversation_id else None

        callback = CustomAsyncIteratorCallbackHandler()
        callbacks = [callback]

        # 从配置中选择模型
        models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)

        # 从配置中选择工具
        tools = [tool for tool in all_tools if tool.name in tool_config]

        # 构建完整的Chain
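For reference, a hedged example of what a request body to this chat endpoint could look like, using only the parameter names visible in the signature above; the model name, tool name, and attachment content are placeholders rather than values confirmed by the diff:

# Illustrative request payload; field names come from the Body(...) parameters above.
request_body = {
    "query": "图里的红色按钮在哪里?",
    "metadata": {"images": ["<base64-encoded JPEG>"]},  # later exposed to tools via container.metadata
    "conversation_id": "",
    "history": [],
    "stream": True,
    "model_config": {"llm_model": {"chatglm3-6b": {"history_len": 5}}},  # placeholder model entry
    "tool_config": {"vqa_processor": {}},  # keys select tools from all_tools by name
}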
@@ -131,7 +147,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
            history_len=history_len,
            metadata=metadata)

        # 执行完整的Chain
        # Execute Chain

        task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
        if stream:
            async for chunk in callback.aiter():
@@ -166,12 +183,22 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                    text += data["llm_token"]
            if tool_info:
                yield json.dumps(
                    {"text": text, "agent_action": tool_info, "agent_finish": agent_finish, "message_id": message_id},
                    ensure_ascii=False)
                    {
                        "text": text,
                        "agent_action": tool_info,
                        "agent_finish": agent_finish,
                        "message_id": message_id
                    },
                    ensure_ascii=False
                )
            else:
                yield json.dumps(
                    {"text": text, "message_id": message_id},
                    ensure_ascii=False)
                    {
                        "text": text,
                        "message_id": message_id
                    },
                    ensure_ascii=False
                )
        await task

    return EventSourceResponse(chat_iterator())
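A rough client-side sketch (not part of the diff) of how the streamed events could be consumed; the URL is hypothetical, while the payload fields ("text", "agent_action", "agent_finish", "message_id") come from the yields above. EventSourceResponse frames each chunk as an SSE "data:" line:

# Hypothetical SSE consumer for the streaming chat endpoint.
import json
import requests

resp = requests.post(
    "http://127.0.0.1:7861/chat/chat",  # assumed host/port/path
    json={"query": "你好", "stream": True},
    stream=True,
)
for raw in resp.iter_lines(decode_unicode=True):
    if raw and raw.startswith("data: "):
        event = json.loads(raw[len("data: "):])
        print(event.get("text", ""), event.get("agent_action"))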
@@ -22,7 +22,7 @@ async def completion(query: str = Body(..., description="用户输入", examples
                     description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                     ):

    #todo 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理
    #TODO: 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理
    async def completion_iterator(query: str,
                                  model_name: str = None,
                                  prompt_name: str = prompt_name,
@@ -49,10 +49,6 @@ for model_category in LLM_MODEL_CONFIG.values():

all_model_names_list = list(all_model_names)

@deprecated(
    since="0.3.0",
    message="模型启动功能将于 Langchain-Chatchat 0.3.x重写,支持更多模式和加速启动,0.2.x中相关功能将废弃",
    removal="0.3.0")
def create_controller_app(
        dispatch_method: str,
        log_level: str = "INFO",
@@ -111,6 +107,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
            worker_addr=args.worker_address)
        # sys.modules["fastchat.serve.base_model_worker"].worker = worker
        sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level)

    # 本地模型
    else:
        from configs.model_config import VLLM_MODEL_DICT
@@ -876,7 +873,9 @@ async def start_main_server():


if __name__ == "__main__":

    create_tables()

    if sys.version_info < (3, 10):
        loop = asyncio.get_event_loop()
    else:
@@ -887,4 +886,3 @@ if __name__ == "__main__":

    asyncio.set_event_loop(loop)
    loop.run_until_complete(start_main_server())
@@ -14,7 +14,6 @@ from configs import (TOOL_CONFIG, LLM_MODEL_CONFIG)
from server.knowledge_base.utils import LOADER_DICT
import uuid
from typing import List, Dict
from PIL import Image

chat_box = ChatBox(
    assistant_avatar=os.path.join(
@@ -206,7 +205,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
        model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]

    # files = st.file_uploader("上传附件",accept_multiple_files=False)
    #                          type=[i for ls in LOADER_DICT.values() for i in ls],)
    #                          type=[i for ls in LOADER_DICT.values() for i in ls],)
    uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
    files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
@@ -252,20 +251,18 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
            history = get_messages_history(
                model_config["llm_model"][next(iter(model_config["llm_model"]))]["history_len"])
        chat_box.user_say(prompt)
        if files_upload["images"]:
            # 显示第一个上传的图像
            st.markdown(f'<img src="data:image/jpeg;base64,{files_upload["images"][0]}" width="300">',
                        unsafe_allow_html=True)
        elif files_upload["videos"]:
            # 显示第一个上传的视频
            st.markdown(
                f'<video width="320" height="240" controls><source src="data:video/mp4;base64,{files_upload["videos"][0]}" type="video/mp4"></video>',
                unsafe_allow_html=True)
        elif files_upload["audios"]:
            # 播放第一个上传的音频
            st.markdown(
                f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>',
                unsafe_allow_html=True)
        if files_upload:
            if files_upload["images"]:
                st.markdown(f'<img src="data:image/jpeg;base64,{files_upload["images"][0]}" width="300">',
                            unsafe_allow_html=True)
            elif files_upload["videos"]:
                st.markdown(
                    f'<video width="400" height="300" controls><source src="data:video/mp4;base64,{files_upload["videos"][0]}" type="video/mp4"></video>',
                    unsafe_allow_html=True)
            elif files_upload["audios"]:
                st.markdown(
                    f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>',
                    unsafe_allow_html=True)

        chat_box.ai_say("正在思考...")
        text = ""
@@ -307,10 +304,16 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                text = chunk
            chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)

            chat_box.show_feedback(**feedback_kwargs,
                                   key=message_id,
                                   on_submit=on_feedback,
                                   kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
            if os.path.exists("tmp/image.jpg"):
                with open("tmp/image.jpg", "rb") as image_file:
                    encoded_string = base64.b64encode(image_file.read()).decode()
                    img_tag = f'<img src="data:image/jpeg;base64,{encoded_string}" width="300">'
                    st.markdown(img_tag, unsafe_allow_html=True)
                os.remove("tmp/image.jpg")
            # chat_box.show_feedback(**feedback_kwargs,
            #                        key=message_id,
            #                        on_submit=on_feedback,
            #                        kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})

            # elif dialogue_mode == "文件对话":
            #     if st.session_state["file_chat_id"] is None: