Support multimodal grounding

1. Cleaned up the chat code
2. Support grounding-task output from the vision tool
3. Improved the tool-calling workflow
zR 2023-12-10 21:27:20 +08:00 committed by liunux4odoo
parent 03891cc27a
commit d44ce6ce21
12 changed files with 163 additions and 73 deletions

View File

View File

@@ -5,4 +5,4 @@ from .server_config import *
from .prompt_config import *
VERSION = "v0.2.10"
VERSION = "v0.3.0-preview"

View File

@@ -1,4 +1,6 @@
# 用于批量将configs下的.example文件复制并命名为.py文件
"""
用于批量将configs下的.example文件复制并命名为.py文件
"""
import os
import shutil

View File

@@ -1,6 +1,6 @@
# API requirements
langchain>=0.0.348
langchain>=0.0.350
langchain-experimental>=0.0.42
pydantic==1.10.13
fschat==0.2.35

View File

@@ -1,6 +1,6 @@
# API requirements
langchain>=0.0.346
langchain>=0.0.350
langchain-experimental>=0.0.42
pydantic==1.10.13
fschat==0.2.35

View File

@@ -8,8 +8,6 @@ class ModelContainer:
self.model = None
self.metadata = None
self.metadata_response = None
self.vision_model = None
self.vision_tokenizer = None
self.audio_tokenizer = None
@@ -29,12 +27,15 @@ class ModelContainer:
if TOOL_CONFIG["aqa_processor"]["use"]:
self.audio_tokenizer = AutoTokenizer.from_pretrained(
TOOL_CONFIG["aqa_processor"]["tokenizer_path"],
trust_remote_code=True)
trust_remote_code=True
)
self.audio_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=TOOL_CONFIG["aqa_processor"]["model_path"],
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True).to(TOOL_CONFIG["aqa_processor"]["device"]).eval()
trust_remote_code=True).to(
TOOL_CONFIG["aqa_processor"]["device"]
).eval()
container = ModelContainer()
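For reference, the aqa_processor entry that this container reads would look roughly like the sketch below. The keys ("use", "tokenizer_path", "model_path", "device") are the ones accessed above; the values are placeholders, not the project's actual defaults:

TOOL_CONFIG = {
    "aqa_processor": {
        "use": True,                                # load the audio QA model at startup
        "tokenizer_path": "path/to/aqa-tokenizer",  # placeholder
        "model_path": "path/to/aqa-model",          # placeholder
        "device": "cuda:0",                         # target of the .to(...).eval() call above
    },
    # other tool entries, e.g. "vqa_processor", follow the same pattern
}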

View File

@@ -4,9 +4,59 @@ Method Use cogagent to generate response for a given image and query.
import base64
from io import BytesIO
import torch
from PIL import Image
from PIL import Image, ImageDraw
from pydantic import BaseModel, Field
from configs import TOOL_CONFIG
import re
from server.agent.container import container
def extract_between_markers(text, start_marker, end_marker):
"""
Extracts and returns the portion of the text that is between 'start_marker' and 'end_marker'.
"""
start = text.find(start_marker)
end = text.find(end_marker, start)
if start != -1 and end != -1:
# Extract and return the text between the markers, without including the markers themselves
return text[start + len(start_marker):end].strip()
else:
return "Text not found between the specified markers"
def draw_box_on_existing_image(base64_image, text):
"""
在已有的Base64编码的图片上根据Grounded Operation中的坐标信息绘制矩形框
假设坐标是经过缩放的比例坐标
"""
# 解码并打开Base64编码的图片
img = Image.open(BytesIO(base64.b64decode(base64_image)))
draw = ImageDraw.Draw(img)
# 提取“Grounded Operation”后的坐标
pattern = r"\[\[(\d+),(\d+),(\d+),(\d+)\]\]"
match = re.search(pattern, text)
if not match:
return None
coords = tuple(map(int, match.groups()))
scaled_coords = (
int(coords[0] * 0.001 * img.width),
int(coords[1] * 0.001 * img.height),
int(coords[2] * 0.001 * img.width),
int(coords[3] * 0.001 * img.height)
)
draw.rectangle(scaled_coords, outline="red", width=3)
buffered = BytesIO()
img.save(buffered, format="JPEG")
img.save("tmp/image.jpg")
img_base64 = base64.b64encode(buffered.getvalue()).decode()
return img_base64
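# Worked example (synthetic image and a made-up answer string, not real model output):
#   import os
#   os.makedirs("tmp", exist_ok=True)        # the function saves tmp/image.jpg as a side effect
#   blank = Image.new("RGB", (1280, 720), "white")
#   buf = BytesIO()
#   blank.save(buf, format="JPEG")
#   answer = "Next Action: tap the icon. Grounded Operation: tap [[100,200,300,400]]"
#   boxed_b64 = draw_box_on_existing_image(base64.b64encode(buf.getvalue()).decode(), answer)
# The captured box (100, 200, 300, 400) is per-mille: on a 1280x720 image it scales to pixel
# coordinates (128, 144, 384, 288) before the red rectangle is drawn. The tmp/image.jpg written
# here is what the webui later picks up, renders, and deletes.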
def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", max_length=2048, top_p=0.9,
temperature=1.0):
@@ -33,8 +83,9 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m
'cross_images': [[inputs['cross_images'][0].to(device).to(torch.bfloat16)]] if inputs[
'cross_images'] else None,
}
gen_kwargs = {"max_length": max_length,
"temperature": temperature,
# "temperature": temperature,
"top_p": top_p,
"do_sample": False}
with torch.no_grad():
@@ -47,15 +98,23 @@ def vqa_run(model, tokenizer, image_base_64, query, history=[], device="cuda", m
def vqa_processor(query: str):
from server.agent.container import container
tool_config = TOOL_CONFIG["vqa_processor"]
# model, tokenizer = load_model(model_path=tool_config["model_path"],
# tokenizer_path=tool_config["tokenizer_path"],
# device=tool_config["device"])
if container.metadata["images"]:
image_base64 = container.metadata["images"][0]
return vqa_run(model=container.vision_model, tokenizer=container.vision_tokenizer, query=query, image_base_64=image_base64,
device=tool_config["device"])
ans = vqa_run(model=container.vision_model,
tokenizer=container.vision_tokenizer,
query=query + "(with grounding)",
image_base_64=image_base64,
device=tool_config["device"])
print(ans)
image_new_base64 = draw_box_on_existing_image(container.metadata["images"][0], ans)
# Markers
# start_marker = "Next Action:draw_box_on_existing_image
# end_marker = "Grounded Operation:"
# ans = extract_between_markers(ans, start_marker, end_marker)
return ans
else:
return "No Image, Please Try Again"

View File

@@ -23,7 +23,7 @@ class ConversationCallbackHandler(BaseCallbackHandler):
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
# 如果想存更多信息则prompts 也需要持久化
# TODO 如果想存更多信息,则 prompts 也需要持久化,不用的提示词需要特殊支持
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:

View File

@@ -1,25 +1,28 @@
from fastapi import Body
from fastapi.responses import StreamingResponse
from langchain.agents import initialize_agent, AgentType
from langchain_core.callbacks import AsyncCallbackManager, BaseCallbackManager
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableBranch
from server.agent.agent_factory import initialize_glm3_agent
from server.agent.tools_factory.tools_registry import all_tools
from server.utils import wrap_done, get_ChatOpenAI
from langchain.chains import LLMChain
from typing import AsyncIterable, Dict
import asyncio
import json
from typing import List, Union, AsyncIterable, Dict
from fastapi import Body
from fastapi.responses import StreamingResponse
from langchain.agents import initialize_agent, AgentType
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableBranch
from langchain.chains import LLMChain
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Union
from server.chat.utils import History
from langchain.prompts import PromptTemplate
from server.utils import get_prompt_template
from server.agent.agent_factory import initialize_glm3_agent
from server.agent.tools_factory.tools_registry import all_tools
from server.agent.container import container
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from server.chat.utils import History
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from server.db.repository import add_message_to_db
from server.callback_handler.agent_callback_handler import Status, CustomAsyncIteratorCallbackHandler
from server.agent.container import container
def create_models_from_config(configs, callbacks):
if configs is None:
@@ -47,6 +50,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
memory = None
chat_prompt = None
container.metadata = metadata
if history:
history = [History.from_data(h) for h in history]
input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
@@ -102,23 +106,35 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
metadata: dict = Body({}, description="附件,可能是图像或者其他功能", examples=[]),
conversation_id: str = Body("", description="对话框ID"),
history_len: int = Body(-1, description="从数据库中取历史消息的数量"),
history: Union[int, List[History]] = Body([],
description="历史对话,设为一个整数可以从数据库中读取历史消息",
examples=[[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant", "content": "虎头虎脑"}]]
),
history: Union[int, List[History]] = Body(
[],
description="历史对话,设为一个整数可以从数据库中读取历史消息",
examples=[
[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant", "content": "虎头虎脑"}
]
]
),
stream: bool = Body(False, description="流式输出"),
model_config: Dict = Body({}, description="LLM 模型配置。"),
model_config: Dict = Body({}, description="LLM 模型配置"),
tool_config: Dict = Body({}, description="工具配置"),
):
async def chat_iterator() -> AsyncIterable[str]:
message_id = add_message_to_db(chat_type="llm_chat", query=query,
conversation_id=conversation_id) if conversation_id else None
message_id = add_message_to_db(
chat_type="llm_chat",
query=query,
conversation_id=conversation_id
) if conversation_id else None
callback = CustomAsyncIteratorCallbackHandler()
callbacks = [callback]
# 从配置中选择模型
models, prompts = create_models_from_config(callbacks=callbacks, configs=model_config)
# 从配置中选择工具
tools = [tool for tool in all_tools if tool.name in tool_config]
# 构建完整的Chain
@@ -131,7 +147,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
history_len=history_len,
metadata=metadata)
# 执行完整的Chain
# Execute Chain
task = asyncio.create_task(wrap_done(full_chain.ainvoke({"input": query}, callbacks=callbacks), callback.done))
if stream:
async for chunk in callback.aiter():
@@ -166,12 +183,22 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
text += data["llm_token"]
if tool_info:
yield json.dumps(
{"text": text, "agent_action": tool_info, "agent_finish": agent_finish, "message_id": message_id},
ensure_ascii=False)
{
"text": text,
"agent_action": tool_info,
"agent_finish": agent_finish,
"message_id": message_id
},
ensure_ascii=False
)
else:
yield json.dumps(
{"text": text, "message_id": message_id},
ensure_ascii=False)
{
"text": text,
"message_id": message_id
},
ensure_ascii=False
)
await task
return EventSourceResponse(chat_iterator())
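A hedged client-side sketch of exercising this endpoint. The JSON fields mirror the Body(...) parameters above; the URL (route /chat/chat and port 7861 are assumptions), the tool name, and the image path are placeholders:

import base64
import json
import requests

with open("screenshot.jpg", "rb") as f:                  # placeholder image
    img_b64 = base64.b64encode(f.read()).decode()

payload = {
    "query": "Where is the search button in this screenshot?",
    "metadata": {"images": [img_b64]},     # stored into container.metadata for the vision tool
    "stream": True,
    "tool_config": {"vqa_processor": {}},  # only tools whose names appear here are enabled
    "model_config": {},                    # placeholder; in practice mirrors LLM_MODEL_CONFIG
}
resp = requests.post("http://127.0.0.1:7861/chat/chat", json=payload, stream=True)
for line in resp.iter_lines(decode_unicode=True):
    if line and line.startswith("data:"):
        event = json.loads(line[len("data:"):])
        # each event carries "text" and "message_id"; when a tool runs it also
        # includes "agent_action" / "agent_finish"
        print(event.get("text", ""), end="", flush=True)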

View File

@@ -22,7 +22,7 @@ async def completion(query: str = Body(..., description="用户输入", examples
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
#todo 因ApiModelWorker 默认是按chat处理的会对params["prompt"] 解析为messages因此ApiModelWorker 使用时需要有相应处理
#TODO: 因ApiModelWorker 默认是按chat处理的会对params["prompt"] 解析为messages因此ApiModelWorker 使用时需要有相应处理
async def completion_iterator(query: str,
model_name: str = None,
prompt_name: str = prompt_name,

View File

@@ -49,10 +49,6 @@ for model_category in LLM_MODEL_CONFIG.values():
all_model_names_list = list(all_model_names)
@deprecated(
since="0.3.0",
message="模型启动功能将于 Langchain-Chatchat 0.3.x重写,支持更多模式和加速启动0.2.x中相关功能将废弃",
removal="0.3.0")
def create_controller_app(
dispatch_method: str,
log_level: str = "INFO",
@@ -111,6 +107,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
worker_addr=args.worker_address)
# sys.modules["fastchat.serve.base_model_worker"].worker = worker
sys.modules["fastchat.serve.base_model_worker"].logger.setLevel(log_level)
# 本地模型
else:
from configs.model_config import VLLM_MODEL_DICT
@@ -876,7 +873,9 @@ async def start_main_server():
if __name__ == "__main__":
create_tables()
if sys.version_info < (3, 10):
loop = asyncio.get_event_loop()
else:
@@ -887,4 +886,3 @@ if __name__ == "__main__":
asyncio.set_event_loop(loop)
loop.run_until_complete(start_main_server())

View File

@@ -14,7 +14,6 @@ from configs import (TOOL_CONFIG, LLM_MODEL_CONFIG)
from server.knowledge_base.utils import LOADER_DICT
import uuid
from typing import List, Dict
from PIL import Image
chat_box = ChatBox(
assistant_avatar=os.path.join(
@@ -206,7 +205,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'][llm_model]
# files = st.file_uploader("上传附件",accept_multiple_files=False)
# type=[i for ls in LOADER_DICT.values() for i in ls],)
# type=[i for ls in LOADER_DICT.values() for i in ls],)
uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
@@ -252,20 +251,18 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
history = get_messages_history(
model_config["llm_model"][next(iter(model_config["llm_model"]))]["history_len"])
chat_box.user_say(prompt)
if files_upload["images"]:
# 显示第一个上传的图像
st.markdown(f'<img src="data:image/jpeg;base64,{files_upload["images"][0]}" width="300">',
unsafe_allow_html=True)
elif files_upload["videos"]:
# 显示第一个上传的视频
st.markdown(
f'<video width="320" height="240" controls><source src="data:video/mp4;base64,{files_upload["videos"][0]}" type="video/mp4"></video>',
unsafe_allow_html=True)
elif files_upload["audios"]:
# 播放第一个上传的音频
st.markdown(
f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>',
unsafe_allow_html=True)
if files_upload:
if files_upload["images"]:
st.markdown(f'<img src="data:image/jpeg;base64,{files_upload["images"][0]}" width="300">',
unsafe_allow_html=True)
elif files_upload["videos"]:
st.markdown(
f'<video width="400" height="300" controls><source src="data:video/mp4;base64,{files_upload["videos"][0]}" type="video/mp4"></video>',
unsafe_allow_html=True)
elif files_upload["audios"]:
st.markdown(
f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>',
unsafe_allow_html=True)
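# The branches above assume process_files returns base64 payloads keyed by media type;
# an assumed shape (not the actual helper's contract) would be:
#   files_upload = {
#       "images": ["<base64 jpeg>"],   # rendered inline via the <img> tag
#       "videos": ["<base64 mp4>"],
#       "audios": ["<base64 wav>"],
#   }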
chat_box.ai_say("正在思考...")
text = ""
@@ -307,10 +304,16 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
text = chunk
chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
chat_box.show_feedback(**feedback_kwargs,
key=message_id,
on_submit=on_feedback,
kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
if os.path.exists("tmp/image.jpg"):
with open("tmp/image.jpg", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read()).decode()
img_tag = f'<img src="data:image/jpeg;base64,{encoded_string}" width="300">'
st.markdown(img_tag, unsafe_allow_html=True)
os.remove("tmp/image.jpg")
# chat_box.show_feedback(**feedback_kwargs,
# key=message_id,
# on_submit=on_feedback,
# kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
# elif dialogue_mode == "文件对话":
# if st.session_state["file_chat_id"] is None: