From 8063aab7a1c72b905b9ab35acd1dfa2d43a316f3 Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Sat, 13 Jan 2024 13:54:03 +0800 Subject: [PATCH] =?UTF-8?q?webui=20=E6=94=AF=E6=8C=81=E6=96=87=E7=94=9F?= =?UTF-8?q?=E5=9B=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + configs/basic_config.py.example | 8 +++ server/agent/tools_factory/text2image.py | 32 ++++++++++-- server/api.py | 6 ++- server/chat/chat.py | 26 ++++++++-- server/chat/utils.py | 1 + server/utils.py | 7 +++ tests/test_qwen_agent.py | 62 ++++++++++++++++++++++-- webui_pages/dialogue/dialogue.py | 16 +++--- 9 files changed, 139 insertions(+), 20 deletions(-) diff --git a/.gitignore b/.gitignore index 848d0818..6a2ccef5 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.log.* *.bak logs +/media/ /knowledge_base/* !/knowledge_base/samples /knowledge_base/samples/vector_store diff --git a/configs/basic_config.py.example b/configs/basic_config.py.example index a22fb977..1cb33fe5 100644 --- a/configs/basic_config.py.example +++ b/configs/basic_config.py.example @@ -23,6 +23,14 @@ LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs") if not os.path.exists(LOG_PATH): os.mkdir(LOG_PATH) +# 模型生成内容(图片、视频、音频等)保存位置 +MEDIA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "media") +if not os.path.exists(MEDIA_PATH): + os.mkdir(MEDIA_PATH) + os.mkdir(os.path.join(MEDIA_PATH, "image")) + os.mkdir(os.path.join(MEDIA_PATH, "audio")) + os.mkdir(os.path.join(MEDIA_PATH, "video")) + # 临时文件目录,主要用于文件对话 BASE_TEMP_DIR = os.path.join(tempfile.gettempdir(), "chatchat") if os.path.isdir(BASE_TEMP_DIR): diff --git a/server/agent/tools_factory/text2image.py b/server/agent/tools_factory/text2image.py index 02d7dd3a..c6983666 100644 --- a/server/agent/tools_factory/text2image.py +++ b/server/agent/tools_factory/text2image.py @@ -1,9 +1,17 @@ +import base64 +import json +import os from PIL import Image from typing import List +import uuid from langchain.agents import tool from langchain.pydantic_v1 import Field import openai +from pydantic.fields import FieldInfo + +from configs.basic_config import MEDIA_PATH +from server.utils import MsgType def get_image_model_config() -> dict: @@ -18,7 +26,7 @@ def get_image_model_config() -> dict: return config -@tool +@tool(return_direct=True) def text2images( prompt: str, n: int = Field(1, description="需生成图片的数量"), @@ -26,6 +34,14 @@ def text2images( height: int = Field(512, description="生成图片的高度"), ) -> List[str]: '''根据用户的描述生成图片''' + # workaround before langchain uprading + if isinstance(n, FieldInfo): + n = n.default + if isinstance(width, FieldInfo): + width = width.default + if isinstance(height, FieldInfo): + height = height.default + model_config = get_image_model_config() assert model_config is not None, "请正确配置文生图模型" @@ -37,13 +53,21 @@ def text2images( resp = client.images.generate(prompt=prompt, n=n, size=f"{width}*{height}", - response_format="url", + response_format="b64_json", model=model_config["model_name"], ) - return [x.url for x in resp.data] + images = [] + for x in resp.data: + uid = uuid.uuid4().hex + filename = f"image/{uid}.png" + with open(os.path.join(MEDIA_PATH, filename), "wb") as fp: + fp.write(base64.b64decode(x.b64_json)) + images.append(filename) + return json.dumps({"message_type": MsgType.IMAGE, "images": images}) if __name__ == "__main__": + from io import BytesIO from matplotlib import pyplot as plt from pathlib import Path import sys @@ -54,5 +78,7 @@ if __name__ == "__main__": params = text2images.args_schema.parse_obj({"prompt": prompt}).dict() print(params) image = text2images.invoke(params)[0] + buffer = BytesIO(base64.b64decode(image)) + image = Image.open(buffer) plt.imshow(image) plt.show() diff --git a/server/api.py b/server/api.py index 33954e85..8ecf9503 100644 --- a/server/api.py +++ b/server/api.py @@ -4,13 +4,14 @@ import os sys.path.append(os.path.dirname(os.path.dirname(__file__))) -from configs import VERSION +from configs import VERSION, MEDIA_PATH from configs.model_config import NLTK_DATA_PATH from configs.server_config import OPEN_CROSS_DOMAIN import argparse import uvicorn from fastapi import Body from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles from starlette.responses import RedirectResponse from server.chat.chat import chat from server.chat.completion import completion @@ -125,6 +126,9 @@ def mount_app_routes(app: FastAPI, run_mode: str = None): summary="将文本向量化,支持本地模型和在线模型", )(embed_texts_endpoint) + # 媒体文件 + app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media") + def mount_knowledge_routes(app: FastAPI): from server.chat.file_chat import upload_temp_docs, file_chat diff --git a/server/chat/chat.py b/server/chat/chat.py index bedc988b..670cd770 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -14,11 +14,11 @@ from server.agent.agent_factory.agents_registry import agents_registry from server.agent.tools_factory.tools_registry import all_tools from server.agent.container import container -from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template +from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template, MsgType from server.chat.utils import History from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory from server.db.repository import add_message_to_db -from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler +from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler, AgentStatus def create_models_from_config(configs, callbacks, stream): @@ -63,10 +63,11 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False) chat_prompt = ChatPromptTemplate.from_messages([input_msg]) + llm=models["llm_model"] + llm.callbacks = callbacks chain = LLMChain( prompt=chat_prompt, - llm=models["llm_model"], - callbacks=callbacks, + llm=llm, memory=memory ) classifier_chain = ( @@ -77,7 +78,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks if "action_model" in models and tools is not None: agent_executor = agents_registry( - llm=models["action_model"], + llm=llm, callbacks=callbacks, tools=tools, prompt=None, @@ -144,6 +145,21 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 async for chunk in callback.aiter(): data = json.loads(chunk) + if data["status"] == AgentStatus.tool_end: + try: + tool_output = json.loads(data["tool_output"]) + if message_type := tool_output.get("message_type"): + data["message_type"] = message_type + except: + ... + elif data["status"] == AgentStatus.agent_finish: + try: + tool_output = json.loads(data["text"]) + if message_type := tool_output.get("message_type"): + data["message_type"] = message_type + except: + ... + data.setdefault("message_type", MsgType.TEXT) data["message_id"] = message_id yield json.dumps(data, ensure_ascii=False) diff --git a/server/chat/utils.py b/server/chat/utils.py index dd3c3332..2fd82674 100644 --- a/server/chat/utils.py +++ b/server/chat/utils.py @@ -1,3 +1,4 @@ +from functools import lru_cache from pydantic import BaseModel, Field from langchain.prompts.chat import ChatMessagePromptTemplate from configs import logger, log_verbose diff --git a/server/utils.py b/server/utils.py index 0962b017..bba6e5eb 100644 --- a/server/utils.py +++ b/server/utils.py @@ -101,6 +101,13 @@ def get_OpenAI( return model +class MsgType: + TEXT = 1 + IMAGE = 2 + AUDIO = 3 + VIDEO = 4 + + class BaseResponse(BaseModel): code: int = pydantic.Field(200, description="API status code") msg: str = pydantic.Field("success", description="API status message") diff --git a/tests/test_qwen_agent.py b/tests/test_qwen_agent.py index 2763e103..bac8a4d3 100644 --- a/tests/test_qwen_agent.py +++ b/tests/test_qwen_agent.py @@ -3,6 +3,8 @@ from pathlib import Path sys.path.append(str(Path(__file__).parent.parent)) import asyncio +import json +from pprint import pprint from langchain.agents import AgentExecutor from langchain_openai.chat_models import ChatOpenAI # from langchain.chat_models.openai import ChatOpenAI @@ -12,8 +14,8 @@ from server.agent.agent_factory.qwen_agent import create_structured_qwen_chat_ag from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler from langchain import globals -globals.set_debug(True) -globals.set_verbose(True) +# globals.set_debug(True) +# globals.set_verbose(True) async def test1(): @@ -39,7 +41,7 @@ async def test1(): await ret -async def test2(): +async def test_server_chat(): from server.chat.chat import chat mc={'preprocess_model': { @@ -83,7 +85,57 @@ async def test2(): history_len=-1, history=[], stream=True)).body_iterator: - print(x) + pprint(x) -asyncio.run(test2()) +async def test_text2image(): + from server.chat.chat import chat + + mc={'preprocess_model': { + 'qwen-api': { + 'temperature': 0.4, + 'max_tokens': 2048, + 'history_len': 100, + 'prompt_name': 'default', + 'callbacks': False} + }, + 'llm_model': { + 'qwen-api': { + 'temperature': 0.9, + 'max_tokens': 4096, + 'history_len': 3, + 'prompt_name': 'default', + 'callbacks': True} + }, + 'action_model': { + 'qwen-api': { + 'temperature': 0.01, + 'max_tokens': 4096, + 'prompt_name': 'qwen', + 'callbacks': True} + }, + 'postprocess_model': { + 'qwen-api': { + 'temperature': 0.01, + 'max_tokens': 4096, + 'prompt_name': 'default', + 'callbacks': True} + }, + 'image_model': { + 'sd-turbo': {} + } + } + + tc={'text2images': {'use': True}} + + async for x in (await chat("draw a house",{}, + model_config=mc, + tool_config=tc, + conversation_id=None, + history_len=-1, + history=[], + stream=False)).body_iterator: + x = json.loads(x) + pprint(x) + +asyncio.run(test_text2image()) diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index 9c2f575e..f70490c3 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -12,6 +12,7 @@ import re import time from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, TOOL_CONFIG) from server.callback_handler.agent_callback_handler import AgentStatus +from server.utils import MsgType import uuid from typing import List, Dict @@ -168,7 +169,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): for k, v in config_models.get("online", {}).items(): if not v.get("provider") and k not in running_models and k in LLM_MODELS: available_models.append(k) - llm_models = running_models + available_models # + ["openai-api"] + llm_models = running_models + available_models cur_llm_model = st.session_state.get("cur_llm_model", default_model) if cur_llm_model in llm_models: index = llm_models.index(cur_llm_model) @@ -276,7 +277,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): text = "" text_action = "" element_index = 0 - for d in api.chat_chat(query=prompt, metadata=files_upload, history=history, @@ -288,6 +288,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): metadata = { "message_id": message_id, } + print(d) if d["status"] == AgentStatus.error: st.error(d["text"]) elif d["status"] == AgentStatus.agent_action: @@ -310,11 +311,14 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): elif d["status"] == AgentStatus.llm_end: chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata) elif d["status"] == AgentStatus.agent_finish: - element_index += 1 + if d["message_type"] == MsgType.IMAGE: + for url in json.loads(d["text"]).get("images", []): + url = f"{api.base_url}/media/{url}" + chat_box.insert_msg(Image(url)) + chat_box.update_msg(element_index=element_index, expanded=False, state="complete") + else: + chat_box.insert_msg(Markdown(d["text"], expanded=True)) - # print(d["text"]) - chat_box.insert_msg(Markdown(d["text"], expanded=True)) - chat_box.update_msg(Markdown(d["text"]), element_index=element_index) if os.path.exists("tmp/image.jpg"): with open("tmp/image.jpg", "rb") as image_file: encoded_string = base64.b64encode(image_file.read()).decode()