webui 支持文生图

This commit is contained in:
liunux4odoo 2024-01-13 13:54:03 +08:00
parent 17ba487074
commit 8063aab7a1
9 changed files with 139 additions and 20 deletions

1
.gitignore vendored
View File

@ -2,6 +2,7 @@
*.log.*
*.bak
logs
/media/
/knowledge_base/*
!/knowledge_base/samples
/knowledge_base/samples/vector_store

View File

@ -23,6 +23,14 @@ LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
if not os.path.exists(LOG_PATH):
os.mkdir(LOG_PATH)
# Storage root for model-generated content (images, audio, video, ...),
# laid out as MEDIA_PATH/{image,audio,video}.
MEDIA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "media")
# Use makedirs(exist_ok=True) rather than an exists()/mkdir pair: the old
# check skipped creating the sub-directories whenever MEDIA_PATH itself
# already existed (leaving image/audio/video missing), and check-then-create
# is race-prone if two processes start at once.
for _media_kind in ("image", "audio", "video"):
    os.makedirs(os.path.join(MEDIA_PATH, _media_kind), exist_ok=True)
# 临时文件目录,主要用于文件对话
BASE_TEMP_DIR = os.path.join(tempfile.gettempdir(), "chatchat")
if os.path.isdir(BASE_TEMP_DIR):

View File

@ -1,9 +1,17 @@
import base64
import json
import os
from PIL import Image
from typing import List
import uuid
from langchain.agents import tool
from langchain.pydantic_v1 import Field
import openai
from pydantic.fields import FieldInfo
from configs.basic_config import MEDIA_PATH
from server.utils import MsgType
def get_image_model_config() -> dict:
@ -18,7 +26,7 @@ def get_image_model_config() -> dict:
return config
@tool
@tool(return_direct=True)
def text2images(
prompt: str,
n: int = Field(1, description="需生成图片的数量"),
@ -26,6 +34,14 @@ def text2images(
height: int = Field(512, description="生成图片的高度"),
) -> List[str]:
'''根据用户的描述生成图片'''
# workaround before langchain uprading
if isinstance(n, FieldInfo):
n = n.default
if isinstance(width, FieldInfo):
width = width.default
if isinstance(height, FieldInfo):
height = height.default
model_config = get_image_model_config()
assert model_config is not None, "请正确配置文生图模型"
@ -37,13 +53,21 @@ def text2images(
resp = client.images.generate(prompt=prompt,
n=n,
size=f"{width}*{height}",
response_format="url",
response_format="b64_json",
model=model_config["model_name"],
)
return [x.url for x in resp.data]
images = []
for x in resp.data:
uid = uuid.uuid4().hex
filename = f"image/{uid}.png"
with open(os.path.join(MEDIA_PATH, filename), "wb") as fp:
fp.write(base64.b64decode(x.b64_json))
images.append(filename)
return json.dumps({"message_type": MsgType.IMAGE, "images": images})
if __name__ == "__main__":
from io import BytesIO
from matplotlib import pyplot as plt
from pathlib import Path
import sys
@ -54,5 +78,7 @@ if __name__ == "__main__":
params = text2images.args_schema.parse_obj({"prompt": prompt}).dict()
print(params)
image = text2images.invoke(params)[0]
buffer = BytesIO(base64.b64decode(image))
image = Image.open(buffer)
plt.imshow(image)
plt.show()

View File

@ -4,13 +4,14 @@ import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs import VERSION
from configs import VERSION, MEDIA_PATH
from configs.model_config import NLTK_DATA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN
import argparse
import uvicorn
from fastapi import Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from starlette.responses import RedirectResponse
from server.chat.chat import chat
from server.chat.completion import completion
@ -125,6 +126,9 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
summary="将文本向量化,支持本地模型和在线模型",
)(embed_texts_endpoint)
# 媒体文件
app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media")
def mount_knowledge_routes(app: FastAPI):
from server.chat.file_chat import upload_temp_docs, file_chat

View File

@ -14,11 +14,11 @@ from server.agent.agent_factory.agents_registry import agents_registry
from server.agent.tools_factory.tools_registry import all_tools
from server.agent.container import container
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template, MsgType
from server.chat.utils import History
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
from server.db.repository import add_message_to_db
from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler, AgentStatus
def create_models_from_config(configs, callbacks, stream):
@ -63,10 +63,11 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages([input_msg])
llm=models["llm_model"]
llm.callbacks = callbacks
chain = LLMChain(
prompt=chat_prompt,
llm=models["llm_model"],
callbacks=callbacks,
llm=llm,
memory=memory
)
classifier_chain = (
@ -77,7 +78,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
if "action_model" in models and tools is not None:
agent_executor = agents_registry(
llm=models["action_model"],
llm=llm,
callbacks=callbacks,
tools=tools,
prompt=None,
@ -144,6 +145,21 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
async for chunk in callback.aiter():
data = json.loads(chunk)
if data["status"] == AgentStatus.tool_end:
try:
tool_output = json.loads(data["tool_output"])
if message_type := tool_output.get("message_type"):
data["message_type"] = message_type
except:
...
elif data["status"] == AgentStatus.agent_finish:
try:
tool_output = json.loads(data["text"])
if message_type := tool_output.get("message_type"):
data["message_type"] = message_type
except:
...
data.setdefault("message_type", MsgType.TEXT)
data["message_id"] = message_id
yield json.dumps(data, ensure_ascii=False)

View File

@ -1,3 +1,4 @@
from functools import lru_cache
from pydantic import BaseModel, Field
from langchain.prompts.chat import ChatMessagePromptTemplate
from configs import logger, log_verbose

View File

@ -101,6 +101,13 @@ def get_OpenAI(
return model
class MsgType:
    """Discriminator for the kind of content a chat message carries.

    Values are plain ints so they pass through ``json.dumps``/``json.loads``
    round-trips unchanged.
    """

    # 1 = text, 2 = image, 3 = audio, 4 = video
    TEXT, IMAGE, AUDIO, VIDEO = 1, 2, 3, 4
class BaseResponse(BaseModel):
code: int = pydantic.Field(200, description="API status code")
msg: str = pydantic.Field("success", description="API status message")

View File

@ -3,6 +3,8 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
import asyncio
import json
from pprint import pprint
from langchain.agents import AgentExecutor
from langchain_openai.chat_models import ChatOpenAI
# from langchain.chat_models.openai import ChatOpenAI
@ -12,8 +14,8 @@ from server.agent.agent_factory.qwen_agent import create_structured_qwen_chat_ag
from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
from langchain import globals
globals.set_debug(True)
globals.set_verbose(True)
# globals.set_debug(True)
# globals.set_verbose(True)
async def test1():
@ -39,7 +41,7 @@ async def test1():
await ret
async def test2():
async def test_server_chat():
from server.chat.chat import chat
mc={'preprocess_model': {
@ -83,7 +85,57 @@ async def test2():
history_len=-1,
history=[],
stream=True)).body_iterator:
print(x)
pprint(x)
asyncio.run(test2())
async def test_text2image():
    """Exercise the chat endpoint end-to-end with the text2images tool enabled.

    Sends one drawing request with streaming disabled, then pretty-prints each
    JSON chunk read from the response body iterator.
    """
    from server.chat.chat import chat

    model_configs = {
        'preprocess_model': {
            'qwen-api': {
                'temperature': 0.4,
                'max_tokens': 2048,
                'history_len': 100,
                'prompt_name': 'default',
                'callbacks': False,
            },
        },
        'llm_model': {
            'qwen-api': {
                'temperature': 0.9,
                'max_tokens': 4096,
                'history_len': 3,
                'prompt_name': 'default',
                'callbacks': True,
            },
        },
        'action_model': {
            'qwen-api': {
                'temperature': 0.01,
                'max_tokens': 4096,
                'prompt_name': 'qwen',
                'callbacks': True,
            },
        },
        'postprocess_model': {
            'qwen-api': {
                'temperature': 0.01,
                'max_tokens': 4096,
                'prompt_name': 'default',
                'callbacks': True,
            },
        },
        'image_model': {'sd-turbo': {}},
    }
    tool_configs = {'text2images': {'use': True}}

    response = await chat(
        "draw a house",
        {},
        model_config=model_configs,
        tool_config=tool_configs,
        conversation_id=None,
        history_len=-1,
        history=[],
        stream=False,
    )
    async for chunk in response.body_iterator:
        pprint(json.loads(chunk))


asyncio.run(test_text2image())

View File

@ -12,6 +12,7 @@ import re
import time
from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, TOOL_CONFIG)
from server.callback_handler.agent_callback_handler import AgentStatus
from server.utils import MsgType
import uuid
from typing import List, Dict
@ -168,7 +169,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
for k, v in config_models.get("online", {}).items():
if not v.get("provider") and k not in running_models and k in LLM_MODELS:
available_models.append(k)
llm_models = running_models + available_models # + ["openai-api"]
llm_models = running_models + available_models
cur_llm_model = st.session_state.get("cur_llm_model", default_model)
if cur_llm_model in llm_models:
index = llm_models.index(cur_llm_model)
@ -276,7 +277,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
text = ""
text_action = ""
element_index = 0
for d in api.chat_chat(query=prompt,
metadata=files_upload,
history=history,
@ -288,6 +288,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
metadata = {
"message_id": message_id,
}
print(d)
if d["status"] == AgentStatus.error:
st.error(d["text"])
elif d["status"] == AgentStatus.agent_action:
@ -310,11 +311,14 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
elif d["status"] == AgentStatus.llm_end:
chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
elif d["status"] == AgentStatus.agent_finish:
element_index += 1
if d["message_type"] == MsgType.IMAGE:
for url in json.loads(d["text"]).get("images", []):
url = f"{api.base_url}/media/{url}"
chat_box.insert_msg(Image(url))
chat_box.update_msg(element_index=element_index, expanded=False, state="complete")
else:
chat_box.insert_msg(Markdown(d["text"], expanded=True))
# print(d["text"])
chat_box.insert_msg(Markdown(d["text"], expanded=True))
chat_box.update_msg(Markdown(d["text"]), element_index=element_index)
if os.path.exists("tmp/image.jpg"):
with open("tmp/image.jpg", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read()).decode()