webui: support text-to-image generation

liunux4odoo 2024-01-13 13:54:03 +08:00
parent 17ba487074
commit 8063aab7a1
9 changed files with 139 additions and 20 deletions

.gitignore

@@ -2,6 +2,7 @@
 *.log.*
 *.bak
 logs
+/media/
 /knowledge_base/*
 !/knowledge_base/samples
 /knowledge_base/samples/vector_store

configs/basic_config.py
@@ -23,6 +23,14 @@ LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
 if not os.path.exists(LOG_PATH):
     os.mkdir(LOG_PATH)
 
+# Where model-generated content (images, video, audio, etc.) is saved
+MEDIA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "media")
+if not os.path.exists(MEDIA_PATH):
+    os.mkdir(MEDIA_PATH)
+    os.mkdir(os.path.join(MEDIA_PATH, "image"))
+    os.mkdir(os.path.join(MEDIA_PATH, "audio"))
+    os.mkdir(os.path.join(MEDIA_PATH, "video"))
+
 # Temporary file directory, mainly used for file-based chat
 BASE_TEMP_DIR = os.path.join(tempfile.gettempdir(), "chatchat")
 if os.path.isdir(BASE_TEMP_DIR):
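
For orientation, the tree this creates under the project root is sketched below. Note that the three subfolders are only created together with media/ itself (the mkdir calls sit inside the if branch), so a subfolder deleted by hand is not recreated on restart.

media/
├── image/   (PNG files written by the text2images tool)
├── audio/
└── video/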

server/agent/tools_factory/text2images.py
@@ -1,9 +1,17 @@
+import base64
+import json
+import os
 from PIL import Image
 from typing import List
+import uuid
 
 from langchain.agents import tool
 from langchain.pydantic_v1 import Field
 import openai
+from pydantic.fields import FieldInfo
+
+from configs.basic_config import MEDIA_PATH
+from server.utils import MsgType
 
 
 def get_image_model_config() -> dict:
@@ -18,7 +26,7 @@ def get_image_model_config() -> dict:
     return config
 
 
-@tool
+@tool(return_direct=True)
 def text2images(
     prompt: str,
     n: int = Field(1, description="number of images to generate"),
@@ -26,6 +34,14 @@ def text2images(
     height: int = Field(512, description="height of the generated images"),
 ) -> List[str]:
     '''Generate images from the user's description'''
+    # workaround before langchain upgrading
+    if isinstance(n, FieldInfo):
+        n = n.default
+    if isinstance(width, FieldInfo):
+        width = width.default
+    if isinstance(height, FieldInfo):
+        height = height.default
+
     model_config = get_image_model_config()
     assert model_config is not None, "please configure the text-to-image model correctly"
@@ -37,13 +53,21 @@ def text2images(
     resp = client.images.generate(prompt=prompt,
                                   n=n,
                                   size=f"{width}*{height}",
-                                  response_format="url",
+                                  response_format="b64_json",
                                   model=model_config["model_name"],
                                   )
-    return [x.url for x in resp.data]
+    images = []
+    for x in resp.data:
+        uid = uuid.uuid4().hex
+        filename = f"image/{uid}.png"
+        with open(os.path.join(MEDIA_PATH, filename), "wb") as fp:
+            fp.write(base64.b64decode(x.b64_json))
+        images.append(filename)
+    return json.dumps({"message_type": MsgType.IMAGE, "images": images})
 
 
 if __name__ == "__main__":
+    from io import BytesIO
     from matplotlib import pyplot as plt
     from pathlib import Path
     import sys
@@ -54,5 +78,7 @@ if __name__ == "__main__":
     params = text2images.args_schema.parse_obj({"prompt": prompt}).dict()
     print(params)
     image = text2images.invoke(params)[0]
+    buffer = BytesIO(base64.b64decode(image))
+    image = Image.open(buffer)
     plt.imshow(image)
     plt.show()
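
With this change, text2images no longer returns image URLs: it writes PNG files under MEDIA_PATH and returns a JSON string naming the message type and the saved files. A minimal sketch of a caller consuming that contract; the import path follows the file heading above, and the host and port are an assumed default local deployment, not part of this commit:

import json

from server.agent.tools_factory.text2images import text2images
from server.utils import MsgType

result = json.loads(text2images.invoke({"prompt": "a house by a lake"}))
if result["message_type"] == MsgType.IMAGE:
    for rel_path in result["images"]:  # e.g. "image/<uuid>.png", relative to MEDIA_PATH
        # served over HTTP once the /media static mount (added below) is live
        print(f"http://127.0.0.1:7861/media/{rel_path}")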

server/api.py
@@ -4,13 +4,14 @@ import os
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 
-from configs import VERSION
+from configs import VERSION, MEDIA_PATH
 from configs.model_config import NLTK_DATA_PATH
 from configs.server_config import OPEN_CROSS_DOMAIN
 
 import argparse
 import uvicorn
 from fastapi import Body
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.staticfiles import StaticFiles
 from starlette.responses import RedirectResponse
 from server.chat.chat import chat
 from server.chat.completion import completion
@@ -125,6 +126,9 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
              summary="Vectorize text, supporting local and online models",
              )(embed_texts_endpoint)
 
+    # media files
+    app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media")
+
 
 def mount_knowledge_routes(app: FastAPI):
     from server.chat.file_chat import upload_temp_docs, file_chat
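
Once the static mount is registered, anything tools save under MEDIA_PATH becomes reachable over plain HTTP. A quick sketch of downloading a generated image; the address and file name are placeholders assuming the default local API port:

import requests

rel_path = "image/<uuid>.png"  # a relative path from the tool's JSON payload
resp = requests.get(f"http://127.0.0.1:7861/media/{rel_path}")
resp.raise_for_status()
with open("downloaded.png", "wb") as fp:
    fp.write(resp.content)  # raw PNG bytes served by StaticFiles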

server/chat/chat.py
@@ -14,11 +14,11 @@ from server.agent.agent_factory.agents_registry import agents_registry
 from server.agent.tools_factory.tools_registry import all_tools
 from server.agent.container import container
 
-from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
+from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template, MsgType
 from server.chat.utils import History
 from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
 from server.db.repository import add_message_to_db
-from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
+from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler, AgentStatus
 
 
 def create_models_from_config(configs, callbacks, stream):
@@ -63,10 +63,11 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
     input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
     chat_prompt = ChatPromptTemplate.from_messages([input_msg])
 
+    llm=models["llm_model"]
+    llm.callbacks = callbacks
     chain = LLMChain(
         prompt=chat_prompt,
-        llm=models["llm_model"],
-        callbacks=callbacks,
+        llm=llm,
         memory=memory
     )
     classifier_chain = (
@@ -77,7 +78,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
     if "action_model" in models and tools is not None:
         agent_executor = agents_registry(
-            llm=models["action_model"],
+            llm=llm,
             callbacks=callbacks,
             tools=tools,
             prompt=None,
@@ -144,6 +145,21 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
             async for chunk in callback.aiter():
                 data = json.loads(chunk)
+                if data["status"] == AgentStatus.tool_end:
+                    try:
+                        tool_output = json.loads(data["tool_output"])
+                        if message_type := tool_output.get("message_type"):
+                            data["message_type"] = message_type
+                    except:
+                        ...
+                elif data["status"] == AgentStatus.agent_finish:
+                    try:
+                        tool_output = json.loads(data["text"])
+                        if message_type := tool_output.get("message_type"):
+                            data["message_type"] = message_type
+                    except:
+                        ...
+                data.setdefault("message_type", MsgType.TEXT)
                 data["message_id"] = message_id
                 yield json.dumps(data, ensure_ascii=False)
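
Every streamed chunk now carries a message_type (defaulting to MsgType.TEXT), so downstream consumers can branch on the tag instead of re-parsing tool output. A minimal consumer sketch; the transport around body_iterator is assumed, only the payload shape comes from this commit:

import json

from server.utils import MsgType

async def handle_stream(body_iterator):
    async for chunk in body_iterator:
        data = json.loads(chunk)
        if data["message_type"] == MsgType.IMAGE:
            # on agent_finish, "text" holds the tool's JSON payload
            for rel_path in json.loads(data["text"]).get("images", []):
                print(f"image available at /media/{rel_path}")
        else:
            print(data.get("text", ""))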

server/chat/utils.py
@@ -1,3 +1,4 @@
+from functools import lru_cache
 from pydantic import BaseModel, Field
 from langchain.prompts.chat import ChatMessagePromptTemplate
 from configs import logger, log_verbose

server/utils.py
@@ -101,6 +101,13 @@ def get_OpenAI(
     return model
 
 
+class MsgType:
+    TEXT = 1
+    IMAGE = 2
+    AUDIO = 3
+    VIDEO = 4
+
+
 class BaseResponse(BaseModel):
     code: int = pydantic.Field(200, description="API status code")
     msg: str = pydantic.Field("success", description="API status message")

(agent chat test script; file path not shown)
@@ -3,6 +3,8 @@ from pathlib import Path
 sys.path.append(str(Path(__file__).parent.parent))
 
 import asyncio
+import json
+from pprint import pprint
 from langchain.agents import AgentExecutor
 from langchain_openai.chat_models import ChatOpenAI
 # from langchain.chat_models.openai import ChatOpenAI
@@ -12,8 +14,8 @@ from server.agent.agent_factory.qwen_agent import create_structured_qwen_chat_ag
 from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
 from langchain import globals
 
-globals.set_debug(True)
-globals.set_verbose(True)
+# globals.set_debug(True)
+# globals.set_verbose(True)
 
 
 async def test1():
@@ -39,7 +41,7 @@ async def test1():
     await ret
 
 
-async def test2():
+async def test_server_chat():
     from server.chat.chat import chat
 
     mc={'preprocess_model': {
@@ -83,7 +85,57 @@ async def test_server_chat():
                 history_len=-1,
                 history=[],
                 stream=True)).body_iterator:
-        print(x)
+        pprint(x)
 
 
-asyncio.run(test2())
+async def test_text2image():
+    from server.chat.chat import chat
+
+    mc={'preprocess_model': {
+            'qwen-api': {
+                'temperature': 0.4,
+                'max_tokens': 2048,
+                'history_len': 100,
+                'prompt_name': 'default',
+                'callbacks': False}
+        },
+        'llm_model': {
+            'qwen-api': {
+                'temperature': 0.9,
+                'max_tokens': 4096,
+                'history_len': 3,
+                'prompt_name': 'default',
+                'callbacks': True}
+        },
+        'action_model': {
+            'qwen-api': {
+                'temperature': 0.01,
+                'max_tokens': 4096,
+                'prompt_name': 'qwen',
+                'callbacks': True}
+        },
+        'postprocess_model': {
+            'qwen-api': {
+                'temperature': 0.01,
+                'max_tokens': 4096,
+                'prompt_name': 'default',
+                'callbacks': True}
+        },
+        'image_model': {
+            'sd-turbo': {}
+        }
+    }
+    tc={'text2images': {'use': True}}
+
+    async for x in (await chat("draw a house", {},
+                               model_config=mc,
+                               tool_config=tc,
+                               conversation_id=None,
+                               history_len=-1,
+                               history=[],
+                               stream=False)).body_iterator:
+        x = json.loads(x)
+        pprint(x)
+
+
+asyncio.run(test_text2image())

webui_pages/dialogue/dialogue.py
@@ -12,6 +12,7 @@ import re
 import time
 
 from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, TOOL_CONFIG)
 from server.callback_handler.agent_callback_handler import AgentStatus
+from server.utils import MsgType
 import uuid
 from typing import List, Dict
@@ -168,7 +169,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         for k, v in config_models.get("online", {}).items():
             if not v.get("provider") and k not in running_models and k in LLM_MODELS:
                 available_models.append(k)
-        llm_models = running_models + available_models # + ["openai-api"]
+        llm_models = running_models + available_models
         cur_llm_model = st.session_state.get("cur_llm_model", default_model)
         if cur_llm_model in llm_models:
             index = llm_models.index(cur_llm_model)
@@ -276,7 +277,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
         text = ""
         text_action = ""
         element_index = 0
-
         for d in api.chat_chat(query=prompt,
                                metadata=files_upload,
                                history=history,
@@ -288,6 +288,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                 metadata = {
                     "message_id": message_id,
                 }
+                print(d)
                 if d["status"] == AgentStatus.error:
                     st.error(d["text"])
                 elif d["status"] == AgentStatus.agent_action:
@@ -310,11 +311,14 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
             elif d["status"] == AgentStatus.llm_end:
                 chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
             elif d["status"] == AgentStatus.agent_finish:
-                element_index += 1
-                # print(d["text"])
-                chat_box.insert_msg(Markdown(d["text"], expanded=True))
-                chat_box.update_msg(Markdown(d["text"]), element_index=element_index)
+                if d["message_type"] == MsgType.IMAGE:
+                    for url in json.loads(d["text"]).get("images", []):
+                        url = f"{api.base_url}/media/{url}"
+                        chat_box.insert_msg(Image(url))
+                    chat_box.update_msg(element_index=element_index, expanded=False, state="complete")
+                else:
+                    chat_box.insert_msg(Markdown(d["text"], expanded=True))
 
         if os.path.exists("tmp/image.jpg"):
             with open("tmp/image.jpg", "rb") as image_file:
                 encoded_string = base64.b64encode(image_file.read()).decode()