From d44ce6ce217b20511653111f978ca7bf70e8798e Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Sun, 10 Dec 2023 21:27:20 +0800
Subject: [PATCH] Support multimodal Grounding
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

1. Tidied up the chat code
2. Added support for Grounding task output from the vision tool
3. Improved the tool-calling workflow
---
 common/__init__.py                                 |  0
 configs/__init__.py                                |  2 +-
 copy_config_example.py                             |  4 +-
 requirements.txt                                   |  2 +-
 requirements_api.txt                               |  2 +-
 server/agent/container.py                          |  9 +-
 .../agent/tools_factory/vision_factory/vqa.py      | 75 ++++++++++++++--
 .../conversation_callback_handler.py               |  2 +-
 server/chat/chat.py                                | 87 ++++++++++++-------
 server/chat/completion.py                          |  2 +-
 startup.py                                         |  8 +-
 webui_pages/dialogue/dialogue.py                   | 43 ++++-----
 12 files changed, 163 insertions(+), 73 deletions(-)
 delete mode 100644 common/__init__.py

diff --git a/common/__init__.py b/common/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/configs/__init__.py b/configs/__init__.py
index e4f5f1d0..86f8dcf7 100644
--- a/configs/__init__.py
+++ b/configs/__init__.py
@@ -5,4 +5,4 @@ from .server_config import *
 from .prompt_config import *
 
-VERSION = "v0.2.10"
+VERSION = "v0.3.0-preview"
diff --git a/copy_config_example.py b/copy_config_example.py
index 91f34542..ad3bb2c9 100644
--- a/copy_config_example.py
+++ b/copy_config_example.py
@@ -1,4 +1,6 @@
-# 用于批量将configs下的.example文件复制并命名为.py文件
+"""
+用于批量将configs下的.example文件复制并命名为.py文件
+"""
 import os
 import shutil
diff --git a/requirements.txt b/requirements.txt
index 7a99abc4..e55f9ff6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 # API requirements
 
-langchain>=0.0.348
+langchain>=0.0.350
 langchain-experimental>=0.0.42
 pydantic==1.10.13
 fschat==0.2.35
diff --git a/requirements_api.txt b/requirements_api.txt
index de38487b..801cea1d 100644
--- a/requirements_api.txt
+++ b/requirements_api.txt
@@ -1,6 +1,6 @@
 # API requirements
 
-langchain>=0.0.346
+langchain>=0.0.350
 langchain-experimental>=0.0.42
 pydantic==1.10.13
 fschat==0.2.35
diff --git a/server/agent/container.py b/server/agent/container.py
index 59aaa69a..7623f1a2 100644
--- a/server/agent/container.py
+++ b/server/agent/container.py
@@ -8,8 +8,6 @@ class ModelContainer:
         self.model = None
         self.metadata = None
-        self.metadata_response = None
-
         self.vision_model = None
         self.vision_tokenizer = None
         self.audio_tokenizer = None
@@ -29,12 +27,15 @@ class ModelContainer:
         if TOOL_CONFIG["aqa_processor"]["use"]:
             self.audio_tokenizer = AutoTokenizer.from_pretrained(
                 TOOL_CONFIG["aqa_processor"]["tokenizer_path"],
-                trust_remote_code=True)
+                trust_remote_code=True
+            )
             self.audio_model = AutoModelForCausalLM.from_pretrained(
                 pretrained_model_name_or_path=TOOL_CONFIG["aqa_processor"]["model_path"],
                 torch_dtype=torch.bfloat16,
                 low_cpu_mem_usage=True,
-                trust_remote_code=True).to(TOOL_CONFIG["aqa_processor"]["device"]).eval()
+                trust_remote_code=True).to(
+                TOOL_CONFIG["aqa_processor"]["device"]
+            ).eval()
 
 
 container = ModelContainer()
diff --git a/server/agent/tools_factory/vision_factory/vqa.py b/server/agent/tools_factory/vision_factory/vqa.py
index fd7aac13..af6691b1 100644
--- a/server/agent/tools_factory/vision_factory/vqa.py
+++ b/server/agent/tools_factory/vision_factory/vqa.py
@@ -4,9 +4,59 @@ Method Use cogagent to generate response for a given image and query. 
import base64 from io import BytesIO import torch -from PIL import Image +from PIL import Image, ImageDraw from pydantic import BaseModel, Field from configs import TOOL_CONFIG +import re +from server.agent.container import container + + +def extract_between_markers(text, start_marker, end_marker): + """ + Extracts and returns the portion of the text that is between 'start_marker' and 'end_marker'. + """ + start = text.find(start_marker) + end = text.find(end_marker, start) + + if start != -1 and end != -1: + # Extract and return the text between the markers, without including the markers themselves + return text[start + len(start_marker):end].strip() + else: + return "Text not found between the specified markers" + + + +def draw_box_on_existing_image(base64_image, text): + """ + 在已有的Base64编码的图片上根据“Grounded Operation”中的坐标信息绘制矩形框。 + 假设坐标是经过缩放的比例坐标。 + """ + # 解码并打开Base64编码的图片 + img = Image.open(BytesIO(base64.b64decode(base64_image))) + draw = ImageDraw.Draw(img) + + # 提取“Grounded Operation”后的坐标 + pattern = r"\[\[(\d+),(\d+),(\d+),(\d+)\]\]" + match = re.search(pattern, text) + if not match: + return None + + coords = tuple(map(int, match.groups())) + scaled_coords = ( + int(coords[0] * 0.001 * img.width), + int(coords[1] * 0.001 * img.height), + int(coords[2] * 0.001 * img.width), + int(coords[3] * 0.001 * img.height) + ) + draw.rectangle(scaled_coords, outline="red", width=3) + + buffered = BytesIO() + img.save(buffered, format="JPEG") + img.save("tmp/image.jpg") + img_base64 = base64.b64encode(buffered.getvalue()).decode() + + return img_base64 + def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", max_length=2048, top_p=0.9, temperature=1.0): @@ -33,8 +83,9 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m 'cross_images': [[inputs['cross_images'][0].to(device).to(torch.bfloat16)]] if inputs[ 'cross_images'] else None, } + gen_kwargs = {"max_length": max_length, - "temperature": temperature, + # "temperature": temperature, "top_p": top_p, "do_sample": False} with torch.no_grad(): @@ -47,15 +98,23 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m def vqa_processor(query: str): - from server.agent.container import container tool_config = TOOL_CONFIG["vqa_processor"] - # model, tokenizer = load_model(model_path=tool_config["model_path"], - # tokenizer_path=tool_config["tokenizer_path"], - # device=tool_config["device"]) if container.metadata["images"]: image_base64 = container.metadata["images"][0] - return vqa_run(model=container.vision_model, tokenizer=container.vision_tokenizer, query=query, image_base_64=image_base64, - device=tool_config["device"]) + ans = vqa_run(model=container.vision_model, + tokenizer=container.vision_tokenizer, + query=query + "(with grounding)", + image_base_64=image_base64, + device=tool_config["device"]) + print(ans) + image_new_base64 = draw_box_on_existing_image(container.metadata["images"][0], ans) + + # Markers + # start_marker = "Next Action:draw_box_on_existing_image + # end_marker = "Grounded Operation:" + # ans = extract_between_markers(ans, start_marker, end_marker) + + return ans else: return "No Image, Please Try Again" diff --git a/server/callback_handler/conversation_callback_handler.py b/server/callback_handler/conversation_callback_handler.py index 8f09b40d..ab926a55 100644 --- a/server/callback_handler/conversation_callback_handler.py +++ b/server/callback_handler/conversation_callback_handler.py @@ -23,7 +23,7 @@ class 
ConversationCallbackHandler(BaseCallbackHandler): def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: - # 如果想存更多信息,则prompts 也需要持久化 + # TODO 如果想存更多信息,则 prompts 也需要持久化,不用的提示词需要特殊支持 pass def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: diff --git a/server/chat/chat.py b/server/chat/chat.py index 33eaa6cf..7afdeacb 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -1,25 +1,28 @@ -from fastapi import Body -from fastapi.responses import StreamingResponse -from langchain.agents import initialize_agent, AgentType -from langchain_core.callbacks import AsyncCallbackManager, BaseCallbackManager -from langchain_core.output_parsers import StrOutputParser -from langchain_core.runnables import RunnableBranch -from server.agent.agent_factory import initialize_glm3_agent -from server.agent.tools_factory.tools_registry import all_tools -from server.utils import wrap_done, get_ChatOpenAI -from langchain.chains import LLMChain -from typing import AsyncIterable, Dict import asyncio import json +from typing import List, Union, AsyncIterable, Dict + +from fastapi import Body +from fastapi.responses import StreamingResponse + +from langchain.agents import initialize_agent, AgentType +from langchain_core.callbacks import BaseCallbackManager +from langchain_core.output_parsers import StrOutputParser +from langchain_core.runnables import RunnableBranch +from langchain.chains import LLMChain from langchain.prompts.chat import ChatPromptTemplate -from typing import List, Union -from server.chat.utils import History from langchain.prompts import PromptTemplate -from server.utils import get_prompt_template + +from server.agent.agent_factory import initialize_glm3_agent +from server.agent.tools_factory.tools_registry import all_tools +from server.agent.container import container + +from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template +from server.chat.utils import History from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory from server.db.repository import add_message_to_db from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler -from server.agent.container import container + def create_models_from_config(configs, callbacks): if configs is None: @@ -47,6 +50,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks memory = None chat_prompt = None container.metadata = metadata + if history: history = [History.from_data(h) for h in history] input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False) @@ -102,23 +106,35 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]), conversation_id: str = Body("", description="对话框ID"), history_len: int = Body(-1, description="从数据库中取历史消息的数量"), - history: Union[int, List[History]] = Body([], - description="历史对话,设为一个整数可以从数据库中读取历史消息", - examples=[[ - {"role": "user", - "content": "我们来玩成语接龙,我先来,生龙活虎"}, - {"role": "assistant", "content": "虎头虎脑"}]] - ), + history: Union[int, List[History]] = Body( + [], + description="历史对话,设为一个整数可以从数据库中读取历史消息", + examples=[ + [ + {"role": "user", + "content": "我们来玩成语接龙,我先来,生龙活虎"}, + {"role": "assistant", "content": "虎头虎脑"} + ] + ] + ), stream: bool = Body(False, description="流式输出"), - model_config: Dict = Body({}, description="LLM 模型配置。"), + model_config: Dict = Body({}, description="LLM 模型配置"), tool_config: Dict = Body({}, description="工具配置"), ): 
async def chat_iterator() -> AsyncIterable[str]: - message_id = add_message_to_db(chat_type="llm_chat", query=query, - conversation_id=conversation_id) if conversation_id else None + message_id = add_message_to_db( + chat_type="llm_chat", + query=query, + conversation_id=conversation_id + ) if conversation_id else None + callback = CustomAsyncIteratorCallbackHandler() callbacks = [callback] + + # 从配置中选择模型 models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config) + + # 从配置中选择工具 tools = [tool for tool in all_tools if tool.name in tool_config] # 构建完整的Chain @@ -131,7 +147,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 history_len=history_len, metadata=metadata) - # 执行完整的Chain + # Execute Chain + task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done)) if stream: async for chunk in callback.aiter(): @@ -166,12 +183,22 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 text += data["llm_token"] if tool_info: yield json.dumps( - {"text": text, "agent_action": tool_info, "agent_finish": agent_finish, "message_id": message_id}, - ensure_ascii=False) + { + "text": text, + "agent_action": tool_info, + "agent_finish": agent_finish, + "message_id": message_id + }, + ensure_ascii=False + ) else: yield json.dumps( - {"text": text, "message_id": message_id}, - ensure_ascii=False) + { + "text": text, + "message_id": message_id + }, + ensure_ascii=False + ) await task return EventSourceResponse(chat_iterator()) diff --git a/server/chat/completion.py b/server/chat/completion.py index f93abce8..bddf07e3 100644 --- a/server/chat/completion.py +++ b/server/chat/completion.py @@ -22,7 +22,7 @@ async def completion(query: str = Body(..., description="用户输入", examples description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), ): - #todo 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理 + #TODO: 因ApiModelWorker 默认是按chat处理的,会对params["prompt"] 解析为messages,因此ApiModelWorker 使用时需要有相应处理 async def completion_iterator(query: str, model_name: str = None, prompt_name: str = prompt_name, diff --git a/startup.py b/startup.py index e3696a80..6a8856ec 100644 --- a/startup.py +++ b/startup.py @@ -49,10 +49,6 @@ for model_category in LLM_MODEL_CONFIG.values(): all_model_names_list = list(all_model_names) -@deprecated( - since="0.3.0", - message="模型启动功能将于 Langchain-Chatchat 0.3.x重写,支持更多模式和加速启动,0.2.x中相关功能将废弃", - removal="0.3.0") def create_controller_app( dispatch_method: str, log_level: str = "INFO", @@ -111,6 +107,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI: worker_addr=args.worker_address) # sys.modules["fastchat.serve.base_model_worker"].worker = worker sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level) + # 本地模型 else: from configs.model_config import VLLM_MODEL_DICT @@ -876,7 +873,9 @@ async def start_main_server(): if __name__ == "__main__": + create_tables() + if sys.version_info < (3, 10): loop = asyncio.get_event_loop() else: @@ -887,4 +886,3 @@ if __name__ == "__main__": asyncio.set_event_loop(loop) loop.run_until_complete(start_main_server()) - diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index b98ff24c..34178e66 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -14,7 +14,6 @@ from configs import (TOOL_CONFIG, LLM_MODEL_CONFIG) from server.knowledge_base.utils import LOADER_DICT import uuid from typing import List, Dict 
-from PIL import Image chat_box = ChatBox( assistant_avatar=os.path.join( @@ -206,7 +205,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model] # files = st.file_uploader("上传附件",accept_multiple_files=False) - # type=[i for ls in LOADER_DICT.values() for i in ls],) + # type=[i for ls in LOADER_DICT.values() for i in ls],) uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False) files_upload = process_files(files=[uploaded_file]) if uploaded_file else None @@ -252,20 +251,18 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): history = get_messages_history( model_config["llm_model"][next(iter(model_config["llm_model"]))]["history_len"]) chat_box.user_say(prompt) - if files_upload["images"]: - # 显示第一个上传的图像 - st.markdown(f'', - unsafe_allow_html=True) - elif files_upload["videos"]: - # 显示第一个上传的视频 - st.markdown( - f'', - unsafe_allow_html=True) - elif files_upload["audios"]: - # 播放第一个上传的音频 - st.markdown( - f'', - unsafe_allow_html=True) + if files_upload: + if files_upload["images"]: + st.markdown(f'', + unsafe_allow_html=True) + elif files_upload["videos"]: + st.markdown( + f'', + unsafe_allow_html=True) + elif files_upload["audios"]: + st.markdown( + f'', + unsafe_allow_html=True) chat_box.ai_say("正在思考...") text = "" @@ -307,10 +304,16 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): text = chunk chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata) - chat_box.show_feedback(**feedback_kwargs, - key=message_id, - on_submit=on_feedback, - kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1}) + if os.path.exists("tmp/image.jpg"): + with open("tmp/image.jpg", "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()).decode() + img_tag = f'' + st.markdown(img_tag, unsafe_allow_html=True) + os.remove("tmp/image.jpg") + # chat_box.show_feedback(**feedback_kwargs, + # key=message_id, + # on_submit=on_feedback, + # kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1}) # elif dialogue_mode == "文件对话": # if st.session_state["file_chat_id"] is None:
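Note on the grounding output consumed by draw_box_on_existing_image in vqa.py: the helper
extracts the first "[[x0,y0,x1,y1]]" box from the model response and multiplies each value
by 0.001 of the image width/height, i.e. it assumes the vision model returns box coordinates
normalized to a 0-1000 range. Below is a minimal, self-contained sketch of that parsing and
scaling step; the sample response text, the 640x480 blank image, and the output file name are
illustrative assumptions, not actual model output.

import re
from PIL import Image, ImageDraw


def parse_and_scale_box(text: str, width: int, height: int):
    """Return pixel coordinates of the first [[x0,y0,x1,y1]] box found in text, or None."""
    match = re.search(r"\[\[(\d+),(\d+),(\d+),(\d+)\]\]", text)
    if not match:
        return None
    x0, y0, x1, y1 = map(int, match.groups())
    # Coordinates are treated as thousandths of the image size, mirroring the 0.001
    # scaling used in draw_box_on_existing_image.
    return (int(x0 * 0.001 * width), int(y0 * 0.001 * height),
            int(x1 * 0.001 * width), int(y1 * 0.001 * height))


if __name__ == "__main__":
    # Illustrative values only; a real response would come from the vision model.
    sample = "Grounded Operation: tap the search icon at [[100,200,300,400]]"
    img = Image.new("RGB", (640, 480), "white")
    box = parse_and_scale_box(sample, img.width, img.height)
    if box:
        ImageDraw.Draw(img).rectangle(box, outline="red", width=3)
        img.save("boxed_example.jpg")  # written to the current working directory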