From d44ce6ce217b20511653111f978ca7bf70e8798e Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Sun, 10 Dec 2023 21:27:20 +0800
Subject: [PATCH] Support multimodal Grounding
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

1. Cleaned up the chat code.
2. Added support for Grounding-task output from the vision tool.
3. Improved the tool-calling workflow.
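
For reference, the new draw_box_on_existing_image helper in
server/agent/tools_factory/vision_factory/vqa.py expects the model to report
boxes as "[[x1,y1,x2,y2]]" on a 0-1000 normalized scale and multiplies each
value by 0.001 of the image width/height before drawing. Below is a minimal,
self-contained sketch of that coordinate mapping; the function name and file
paths are illustrative only and are not part of this patch:

    import re
    from PIL import Image, ImageDraw

    def draw_grounding_box(image_path: str, model_output: str, out_path: str) -> bool:
        """Draw the first [[x1,y1,x2,y2]] grounding box found in model_output."""
        match = re.search(r"\[\[(\d+),(\d+),(\d+),(\d+)\]\]", model_output)
        if not match:
            return False  # the answer contains no grounding box
        img = Image.open(image_path)
        x1, y1, x2, y2 = (int(v) for v in match.groups())
        # Coordinates are per-mille of the image size, so scale by size / 1000.
        box = (x1 * img.width // 1000, y1 * img.height // 1000,
               x2 * img.width // 1000, y2 * img.height // 1000)
        ImageDraw.Draw(img).rectangle(box, outline="red", width=3)
        img.convert("RGB").save(out_path, format="JPEG")
        return True

At runtime the image reaches vqa_processor through container.metadata["images"]
(set from the request's metadata field in server/chat/chat.py), and the tool is
selected by listing its name in tool_config.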
---
common/__init__.py | 0
configs/__init__.py | 2 +-
copy_config_example.py | 4 +-
requirements.txt | 2 +-
requirements_api.txt | 2 +-
server/agent/container.py | 9 +-
.../agent/tools_factory/vision_factory/vqa.py | 75 ++++++++++++++--
.../conversation_callback_handler.py | 2 +-
server/chat/chat.py | 87 ++++++++++++-------
server/chat/completion.py | 2 +-
startup.py | 8 +-
webui_pages/dialogue/dialogue.py | 43 ++++-----
12 files changed, 163 insertions(+), 73 deletions(-)
delete mode 100644 common/__init__.py
diff --git a/common/__init__.py b/common/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/configs/__init__.py b/configs/__init__.py
index e4f5f1d0..86f8dcf7 100644
--- a/configs/__init__.py
+++ b/configs/__init__.py
@@ -5,4 +5,4 @@ from .server_config import *
from .prompt_config import *
-VERSION = "v0.2.10"
+VERSION = "v0.3.0-preview"
diff --git a/copy_config_example.py b/copy_config_example.py
index 91f34542..ad3bb2c9 100644
--- a/copy_config_example.py
+++ b/copy_config_example.py
@@ -1,4 +1,6 @@
-# 用于批量将configs下的.example文件复制并命名为.py文件
+"""
+用于批量将configs下的.example文件复制并命名为.py文件
+"""
import os
import shutil
diff --git a/requirements.txt b/requirements.txt
index 7a99abc4..e55f9ff6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
# API requirements
-langchain>=0.0.348
+langchain>=0.0.350
langchain-experimental>=0.0.42
pydantic==1.10.13
fschat==0.2.35
diff --git a/requirements_api.txt b/requirements_api.txt
index de38487b..801cea1d 100644
--- a/requirements_api.txt
+++ b/requirements_api.txt
@@ -1,6 +1,6 @@
# API requirements
-langchain>=0.0.346
+langchain>=0.0.350
langchain-experimental>=0.0.42
pydantic==1.10.13
fschat==0.2.35
diff --git a/server/agent/container.py b/server/agent/container.py
index 59aaa69a..7623f1a2 100644
--- a/server/agent/container.py
+++ b/server/agent/container.py
@@ -8,8 +8,6 @@ class ModelContainer:
self.model = None
self.metadata = None
- self.metadata_response = None
-
self.vision_model = None
self.vision_tokenizer = None
self.audio_tokenizer = None
@@ -29,12 +27,15 @@ class ModelContainer:
if TOOL_CONFIG["aqa_processor"]["use"]:
self.audio_tokenizer = AutoTokenizer.from_pretrained(
TOOL_CONFIG["aqa_processor"]["tokenizer_path"],
- trust_remote_code=True)
+ trust_remote_code=True
+ )
self.audio_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=TOOL_CONFIG["aqa_processor"]["model_path"],
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
- trust_remote_code=True).to(TOOL_CONFIG["aqa_processor"]["device"]).eval()
+ trust_remote_code=True).to(
+ TOOL_CONFIG["aqa_processor"]["device"]
+ ).eval()
container = ModelContainer()
diff --git a/server/agent/tools_factory/vision_factory/vqa.py b/server/agent/tools_factory/vision_factory/vqa.py
index fd7aac13..af6691b1 100644
--- a/server/agent/tools_factory/vision_factory/vqa.py
+++ b/server/agent/tools_factory/vision_factory/vqa.py
@@ -4,9 +4,59 @@ Method Use cogagent to generate response for a given image and query.
import base64
from io import BytesIO
import torch
-from PIL import Image
+from PIL import Image, ImageDraw
from pydantic import BaseModel, Field
from configs import TOOL_CONFIG
+import re
+from server.agent.container import container
+
+
+def extract_between_markers(text, start_marker, end_marker):
+ """
+ Extracts and returns the portion of the text that is between 'start_marker' and 'end_marker'.
+ """
+ start = text.find(start_marker)
+ end = text.find(end_marker, start)
+
+ if start != -1 and end != -1:
+ # Extract and return the text between the markers, without including the markers themselves
+ return text[start + len(start_marker):end].strip()
+ else:
+ return "Text not found between the specified markers"
+
+
+
+def draw_box_on_existing_image(base64_image, text):
+ """
+ 在已有的Base64编码的图片上根据“Grounded Operation”中的坐标信息绘制矩形框。
+ 假设坐标是经过缩放的比例坐标。
+ """
+ # 解码并打开Base64编码的图片
+ img = Image.open(BytesIO(base64.b64decode(base64_image)))
+ draw = ImageDraw.Draw(img)
+
+ # 提取“Grounded Operation”后的坐标
+ pattern = r"\[\[(\d+),(\d+),(\d+),(\d+)\]\]"
+ match = re.search(pattern, text)
+ if not match:
+ return None
+
+ coords = tuple(map(int, match.groups()))
+ scaled_coords = (
+ int(coords[0] * 0.001 * img.width),
+ int(coords[1] * 0.001 * img.height),
+ int(coords[2] * 0.001 * img.width),
+ int(coords[3] * 0.001 * img.height)
+ )
+ draw.rectangle(scaled_coords, outline="red", width=3)
+
+ buffered = BytesIO()
+ img.save(buffered, format="JPEG")
+ img.save("tmp/image.jpg")
+ img_base64 = base64.b64encode(buffered.getvalue()).decode()
+
+ return img_base64
+
def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", max_length=2048, top_p=0.9,
temperature=1.0):
@@ -33,8 +83,9 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m
'cross_images': [[inputs['cross_images'][0].to(device).to(torch.bfloat16)]] if inputs[
'cross_images'] else None,
}
+
gen_kwargs = {"max_length": max_length,
- "temperature": temperature,
+ # "temperature": temperature,
"top_p": top_p,
"do_sample": False}
with torch.no_grad():
@@ -47,15 +98,23 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m
def vqa_processor(query: str):
- from server.agent.container import container
tool_config = TOOL_CONFIG["vqa_processor"]
- # model, tokenizer = load_model(model_path=tool_config["model_path"],
- # tokenizer_path=tool_config["tokenizer_path"],
- # device=tool_config["device"])
if container.metadata["images"]:
image_base64 = container.metadata["images"][0]
- return vqa_run(model=container.vision_model, tokenizer=container.vision_tokenizer, query=query, image_base_64=image_base64,
- device=tool_config["device"])
+ ans = vqa_run(model=container.vision_model,
+ tokenizer=container.vision_tokenizer,
+ query=query + "(with grounding)",
+ image_base_64=image_base64,
+ device=tool_config["device"])
+ print(ans)
+ image_new_base64 = draw_box_on_existing_image(container.metadata["images"][0], ans)
+
+ # Markers
+ # start_marker = "Next Action:draw_box_on_existing_image"
+ # end_marker = "Grounded Operation:"
+ # ans = extract_between_markers(ans, start_marker, end_marker)
+
+ return ans
else:
return "No Image, Please Try Again"
diff --git a/server/callback_handler/conversation_callback_handler.py b/server/callback_handler/conversation_callback_handler.py
index 8f09b40d..ab926a55 100644
--- a/server/callback_handler/conversation_callback_handler.py
+++ b/server/callback_handler/conversation_callback_handler.py
@@ -23,7 +23,7 @@ class ConversationCallbackHandler(BaseCallbackHandler):
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
- # 如果想存更多信息,则prompts 也需要持久化
+ # TODO 如果想存更多信息,则 prompts 也需要持久化,不用的提示词需要特殊支持
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
diff --git a/server/chat/chat.py b/server/chat/chat.py
index 33eaa6cf..7afdeacb 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -1,25 +1,28 @@
-from fastapi import Body
-from fastapi.responses import StreamingResponse
-from langchain.agents import initialize_agent, AgentType
-from langchain_core.callbacks import AsyncCallbackManager, BaseCallbackManager
-from langchain_core.output_parsers import StrOutputParser
-from langchain_core.runnables import RunnableBranch
-from server.agent.agent_factory import initialize_glm3_agent
-from server.agent.tools_factory.tools_registry import all_tools
-from server.utils import wrap_done, get_ChatOpenAI
-from langchain.chains import LLMChain
-from typing import AsyncIterable, Dict
import asyncio
import json
+from typing import List, Union, AsyncIterable, Dict
+
+from fastapi import Body
+from fastapi.responses import StreamingResponse
+
+from langchain.agents import initialize_agent, AgentType
+from langchain_core.callbacks import BaseCallbackManager
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnableBranch
+from langchain.chains import LLMChain
from langchain.prompts.chat import ChatPromptTemplate
-from typing import List, Union
-from server.chat.utils import History
from langchain.prompts import PromptTemplate
-from server.utils import get_prompt_template
+
+from server.agent.agent_factory import initialize_glm3_agent
+from server.agent.tools_factory.tools_registry import all_tools
+from server.agent.container import container
+
+from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
+from server.chat.utils import History
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from server.db.repository import add_message_to_db
from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler
-from server.agent.container import container
+
def create_models_from_config(configs, callbacks):
if configs is None:
@@ -47,6 +50,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
memory = None
chat_prompt = None
container.metadata = metadata
+
if history:
history = [History.from_data(h) for h in history]
input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
@@ -102,23 +106,35 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]),
conversation_id: str = Body("", description="对话框ID"),
history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
- history: Union[int, List[History]] = Body([],
- description="历史对话,设为一个整数可以从数据库中读取历史消息",
- examples=[[
- {"role": "user",
- "content": "我们来玩成语接龙,我先来,生龙活虎"},
- {"role": "assistant", "content": "虎头虎脑"}]]
- ),
+ history: Union[int, List[History]] = Body(
+ [],
+ description="历史对话,设为一个整数可以从数据库中读取历史消息",
+ examples=[
+ [
+ {"role": "user",
+ "content": "我们来玩成语接龙,我先来,生龙活虎"},
+ {"role": "assistant", "content": "虎头虎脑"}
+ ]
+ ]
+ ),
stream: bool = Body(False, description="流式输出"),
- model_config: Dict = Body({}, description="LLM 模型配置。"),
+ model_config: Dict = Body({}, description="LLM 模型配置"),
tool_config: Dict = Body({}, description="工具配置"),
):
async def chat_iterator() -> AsyncIterable[str]:
- message_id = add_message_to_db(chat_type="llm_chat", query=query,
- conversation_id=conversation_id) if conversation_id else None
+ message_id = add_message_to_db(
+ chat_type="llm_chat",
+ query=query,
+ conversation_id=conversation_id
+ ) if conversation_id else None
+
callback = CustomAsyncIteratorCallbackHandler()
callbacks = [callback]
+
+ # 从配置中选择模型
models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)
+
+ # 从配置中选择工具
tools = [tool for tool in all_tools if tool.name in tool_config]
# 构建完整的Chain
@@ -131,7 +147,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
history_len=history_len,
metadata=metadata)
- # 执行完整的Chain
+ # Execute Chain
+
task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
if stream:
async for chunk in callback.aiter():
@@ -166,12 +183,22 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
text += data["llm_token"]
if tool_info:
yield json.dumps(
- {"text": text, "agent_action": tool_info, "agent_finish": agent_finish, "message_id": message_id},
- ensure_ascii=False)
+ {
+ "text": text,
+ "agent_action": tool_info,
+ "agent_finish": agent_finish,
+ "message_id": message_id
+ },
+ ensure_ascii=False
+ )
else:
yield json.dumps(
- {"text": text, "message_id": message_id},
- ensure_ascii=False)
+ {
+ "text": text,
+ "message_id": message_id
+ },
+ ensure_ascii=False
+ )
await task
return EventSourceResponse(chat_iterator())
diff --git a/server/chat/completion.py b/server/chat/completion.py
index f93abce8..bddf07e3 100644
--- a/server/chat/completion.py
+++ b/server/chat/completion.py
@@ -22,7 +22,7 @@ async def completion(query: str = Body(..., description="用户输入", examples
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
- #todo 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理
+ #TODO: 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理
async def completion_iterator(query: str,
model_name: str = None,
prompt_name: str = prompt_name,
diff --git a/startup.py b/startup.py
index e3696a80..6a8856ec 100644
--- a/startup.py
+++ b/startup.py
@@ -49,10 +49,6 @@ for model_category in LLM_MODEL_CONFIG.values():
all_model_names_list = list(all_model_names)
-@deprecated(
- since="0.3.0",
- message="模型启动功能将于 Langchain-Chatchat 0.3.x重写,支持更多模式和加速启动,0.2.x中相关功能将废弃",
- removal="0.3.0")
def create_controller_app(
dispatch_method: str,
log_level: str = "INFO",
@@ -111,6 +107,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
worker_addr=args.worker_address)
# sys.modules["fastchat.serve.base_model_worker"].worker = worker
sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level)
+
# 本地模型
else:
from configs.model_config import VLLM_MODEL_DICT
@@ -876,7 +873,9 @@ async def start_main_server():
if __name__ == "__main__":
+
create_tables()
+
if sys.version_info < (3, 10):
loop = asyncio.get_event_loop()
else:
@@ -887,4 +886,3 @@ if __name__ == "__main__":
asyncio.set_event_loop(loop)
loop.run_until_complete(start_main_server())
-
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index b98ff24c..34178e66 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -14,7 +14,6 @@ from configs import (TOOL_CONFIG, LLM_MODEL_CONFIG)
from server.knowledge_base.utils import LOADER_DICT
import uuid
from typing import List, Dict
-from PIL import Image
chat_box = ChatBox(
assistant_avatar=os.path.join(
@@ -206,7 +205,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]
# files = st.file_uploader("上传附件",accept_multiple_files=False)
- # type=[i for ls in LOADER_DICT.values() for i in ls],)
+ # type=[i for ls in LOADER_DICT.values() for i in ls],)
uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
@@ -252,20 +251,18 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
history = get_messages_history(
model_config["llm_model"][next(iter(model_config["llm_model"]))]["history_len"])
chat_box.user_say(prompt)
- if files_upload["images"]:
- # 显示第一个上传的图像
- st.markdown(f'<img src="data:image/jpeg;base64,{files_upload["images"][0]}">',
- unsafe_allow_html=True)
- elif files_upload["videos"]:
- # 显示第一个上传的视频
- st.markdown(
- f'<video controls><source src="data:video/mp4;base64,{files_upload["videos"][0]}" type="video/mp4"></video>',
- unsafe_allow_html=True)
- elif files_upload["audios"]:
- # 播放第一个上传的音频
- st.markdown(
- f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>',
- unsafe_allow_html=True)
+ if files_upload:
+ if files_upload["images"]:
+ st.markdown(f'<img src="data:image/jpeg;base64,{files_upload["images"][0]}">',
+ unsafe_allow_html=True)
+ elif files_upload["videos"]:
+ st.markdown(
+ f'<video controls><source src="data:video/mp4;base64,{files_upload["videos"][0]}" type="video/mp4"></video>',
+ unsafe_allow_html=True)
+ elif files_upload["audios"]:
+ st.markdown(
+ f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>',
+ unsafe_allow_html=True)
chat_box.ai_say("正在思考...")
text = ""
@@ -307,10 +304,16 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
text = chunk
chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
- chat_box.show_feedback(**feedback_kwargs,
- key=message_id,
- on_submit=on_feedback,
- kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
+ if os.path.exists("tmp/image.jpg"):
+ with open("tmp/image.jpg", "rb") as image_file:
+ encoded_string = base64.b64encode(image_file.read()).decode()
+ img_tag = f'<img src="data:image/jpeg;base64,{encoded_string}">'
+ st.markdown(img_tag, unsafe_allow_html=True)
+ os.remove("tmp/image.jpg")
+ # chat_box.show_feedback(**feedback_kwargs,
+ # key=message_id,
+ # on_submit=on_feedback,
+ # kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
# elif dialogue_mode == "文件对话":
# if st.session_state["file_chat_id"] is None: