mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-30 18:56:23 +08:00
webui 支持文生图
This commit is contained in:
parent
17ba487074
commit
8063aab7a1
1
.gitignore
vendored
1
.gitignore
vendored
@ -2,6 +2,7 @@
|
||||
*.log.*
|
||||
*.bak
|
||||
logs
|
||||
/media/
|
||||
/knowledge_base/*
|
||||
!/knowledge_base/samples
|
||||
/knowledge_base/samples/vector_store
|
||||
|
||||
@ -23,6 +23,14 @@ LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
|
||||
if not os.path.exists(LOG_PATH):
|
||||
os.mkdir(LOG_PATH)
|
||||
|
||||
# 模型生成内容(图片、视频、音频等)保存位置
|
||||
MEDIA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "media")
|
||||
if not os.path.exists(MEDIA_PATH):
|
||||
os.mkdir(MEDIA_PATH)
|
||||
os.mkdir(os.path.join(MEDIA_PATH, "image"))
|
||||
os.mkdir(os.path.join(MEDIA_PATH, "audio"))
|
||||
os.mkdir(os.path.join(MEDIA_PATH, "video"))
|
||||
|
||||
# 临时文件目录,主要用于文件对话
|
||||
BASE_TEMP_DIR = os.path.join(tempfile.gettempdir(), "chatchat")
|
||||
if os.path.isdir(BASE_TEMP_DIR):
|
||||
|
||||
@ -1,9 +1,17 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from PIL import Image
|
||||
from typing import List
|
||||
import uuid
|
||||
|
||||
from langchain.agents import tool
|
||||
from langchain.pydantic_v1 import Field
|
||||
import openai
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
from configs.basic_config import MEDIA_PATH
|
||||
from server.utils import MsgType
|
||||
|
||||
|
||||
def get_image_model_config() -> dict:
|
||||
@ -18,7 +26,7 @@ def get_image_model_config() -> dict:
|
||||
return config
|
||||
|
||||
|
||||
@tool
|
||||
@tool(return_direct=True)
|
||||
def text2images(
|
||||
prompt: str,
|
||||
n: int = Field(1, description="需生成图片的数量"),
|
||||
@ -26,6 +34,14 @@ def text2images(
|
||||
height: int = Field(512, description="生成图片的高度"),
|
||||
) -> List[str]:
|
||||
'''根据用户的描述生成图片'''
|
||||
# workaround before langchain uprading
|
||||
if isinstance(n, FieldInfo):
|
||||
n = n.default
|
||||
if isinstance(width, FieldInfo):
|
||||
width = width.default
|
||||
if isinstance(height, FieldInfo):
|
||||
height = height.default
|
||||
|
||||
model_config = get_image_model_config()
|
||||
assert model_config is not None, "请正确配置文生图模型"
|
||||
|
||||
@ -37,13 +53,21 @@ def text2images(
|
||||
resp = client.images.generate(prompt=prompt,
|
||||
n=n,
|
||||
size=f"{width}*{height}",
|
||||
response_format="url",
|
||||
response_format="b64_json",
|
||||
model=model_config["model_name"],
|
||||
)
|
||||
return [x.url for x in resp.data]
|
||||
images = []
|
||||
for x in resp.data:
|
||||
uid = uuid.uuid4().hex
|
||||
filename = f"image/{uid}.png"
|
||||
with open(os.path.join(MEDIA_PATH, filename), "wb") as fp:
|
||||
fp.write(base64.b64decode(x.b64_json))
|
||||
images.append(filename)
|
||||
return json.dumps({"message_type": MsgType.IMAGE, "images": images})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from io import BytesIO
|
||||
from matplotlib import pyplot as plt
|
||||
from pathlib import Path
|
||||
import sys
|
||||
@ -54,5 +78,7 @@ if __name__ == "__main__":
|
||||
params = text2images.args_schema.parse_obj({"prompt": prompt}).dict()
|
||||
print(params)
|
||||
image = text2images.invoke(params)[0]
|
||||
buffer = BytesIO(base64.b64decode(image))
|
||||
image = Image.open(buffer)
|
||||
plt.imshow(image)
|
||||
plt.show()
|
||||
|
||||
@ -4,13 +4,14 @@ import os
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from configs import VERSION
|
||||
from configs import VERSION, MEDIA_PATH
|
||||
from configs.model_config import NLTK_DATA_PATH
|
||||
from configs.server_config import OPEN_CROSS_DOMAIN
|
||||
import argparse
|
||||
import uvicorn
|
||||
from fastapi import Body
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from starlette.responses import RedirectResponse
|
||||
from server.chat.chat import chat
|
||||
from server.chat.completion import completion
|
||||
@ -125,6 +126,9 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
|
||||
summary="将文本向量化,支持本地模型和在线模型",
|
||||
)(embed_texts_endpoint)
|
||||
|
||||
# 媒体文件
|
||||
app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media")
|
||||
|
||||
|
||||
def mount_knowledge_routes(app: FastAPI):
|
||||
from server.chat.file_chat import upload_temp_docs, file_chat
|
||||
|
||||
@ -14,11 +14,11 @@ from server.agent.agent_factory.agents_registry import agents_registry
|
||||
from server.agent.tools_factory.tools_registry import all_tools
|
||||
from server.agent.container import container
|
||||
|
||||
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
|
||||
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template, MsgType
|
||||
from server.chat.utils import History
|
||||
from server.memory.conversation_db_buffer_memory import ConversationBufferDBMemory
|
||||
from server.db.repository import add_message_to_db
|
||||
from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
|
||||
from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler, AgentStatus
|
||||
|
||||
|
||||
def create_models_from_config(configs, callbacks, stream):
|
||||
@ -63,10 +63,11 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
|
||||
input_msg = History(role="user", content=prompts["llm_model"]).to_msg_template(False)
|
||||
chat_prompt = ChatPromptTemplate.from_messages([input_msg])
|
||||
|
||||
llm=models["llm_model"]
|
||||
llm.callbacks = callbacks
|
||||
chain = LLMChain(
|
||||
prompt=chat_prompt,
|
||||
llm=models["llm_model"],
|
||||
callbacks=callbacks,
|
||||
llm=llm,
|
||||
memory=memory
|
||||
)
|
||||
classifier_chain = (
|
||||
@ -77,7 +78,7 @@ def create_models_chains(history, history_len, prompts, models, tools, callbacks
|
||||
|
||||
if "action_model" in models and tools is not None:
|
||||
agent_executor = agents_registry(
|
||||
llm=models["action_model"],
|
||||
llm=llm,
|
||||
callbacks=callbacks,
|
||||
tools=tools,
|
||||
prompt=None,
|
||||
@ -144,6 +145,21 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
|
||||
|
||||
async for chunk in callback.aiter():
|
||||
data = json.loads(chunk)
|
||||
if data["status"] == AgentStatus.tool_end:
|
||||
try:
|
||||
tool_output = json.loads(data["tool_output"])
|
||||
if message_type := tool_output.get("message_type"):
|
||||
data["message_type"] = message_type
|
||||
except:
|
||||
...
|
||||
elif data["status"] == AgentStatus.agent_finish:
|
||||
try:
|
||||
tool_output = json.loads(data["text"])
|
||||
if message_type := tool_output.get("message_type"):
|
||||
data["message_type"] = message_type
|
||||
except:
|
||||
...
|
||||
data.setdefault("message_type", MsgType.TEXT)
|
||||
data["message_id"] = message_id
|
||||
yield json.dumps(data, ensure_ascii=False)
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from functools import lru_cache
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.prompts.chat import ChatMessagePromptTemplate
|
||||
from configs import logger, log_verbose
|
||||
|
||||
@ -101,6 +101,13 @@ def get_OpenAI(
|
||||
return model
|
||||
|
||||
|
||||
class MsgType:
|
||||
TEXT = 1
|
||||
IMAGE = 2
|
||||
AUDIO = 3
|
||||
VIDEO = 4
|
||||
|
||||
|
||||
class BaseResponse(BaseModel):
|
||||
code: int = pydantic.Field(200, description="API status code")
|
||||
msg: str = pydantic.Field("success", description="API status message")
|
||||
|
||||
@ -3,6 +3,8 @@ from pathlib import Path
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pprint import pprint
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain_openai.chat_models import ChatOpenAI
|
||||
# from langchain.chat_models.openai import ChatOpenAI
|
||||
@ -12,8 +14,8 @@ from server.agent.agent_factory.qwen_agent import create_structured_qwen_chat_ag
|
||||
from server.callback_handler.agent_callback_handler import AgentExecutorAsyncIteratorCallbackHandler
|
||||
from langchain import globals
|
||||
|
||||
globals.set_debug(True)
|
||||
globals.set_verbose(True)
|
||||
# globals.set_debug(True)
|
||||
# globals.set_verbose(True)
|
||||
|
||||
|
||||
async def test1():
|
||||
@ -39,7 +41,7 @@ async def test1():
|
||||
await ret
|
||||
|
||||
|
||||
async def test2():
|
||||
async def test_server_chat():
|
||||
from server.chat.chat import chat
|
||||
|
||||
mc={'preprocess_model': {
|
||||
@ -83,7 +85,57 @@ async def test2():
|
||||
history_len=-1,
|
||||
history=[],
|
||||
stream=True)).body_iterator:
|
||||
print(x)
|
||||
pprint(x)
|
||||
|
||||
|
||||
asyncio.run(test2())
|
||||
async def test_text2image():
|
||||
from server.chat.chat import chat
|
||||
|
||||
mc={'preprocess_model': {
|
||||
'qwen-api': {
|
||||
'temperature': 0.4,
|
||||
'max_tokens': 2048,
|
||||
'history_len': 100,
|
||||
'prompt_name': 'default',
|
||||
'callbacks': False}
|
||||
},
|
||||
'llm_model': {
|
||||
'qwen-api': {
|
||||
'temperature': 0.9,
|
||||
'max_tokens': 4096,
|
||||
'history_len': 3,
|
||||
'prompt_name': 'default',
|
||||
'callbacks': True}
|
||||
},
|
||||
'action_model': {
|
||||
'qwen-api': {
|
||||
'temperature': 0.01,
|
||||
'max_tokens': 4096,
|
||||
'prompt_name': 'qwen',
|
||||
'callbacks': True}
|
||||
},
|
||||
'postprocess_model': {
|
||||
'qwen-api': {
|
||||
'temperature': 0.01,
|
||||
'max_tokens': 4096,
|
||||
'prompt_name': 'default',
|
||||
'callbacks': True}
|
||||
},
|
||||
'image_model': {
|
||||
'sd-turbo': {}
|
||||
}
|
||||
}
|
||||
|
||||
tc={'text2images': {'use': True}}
|
||||
|
||||
async for x in (await chat("draw a house",{},
|
||||
model_config=mc,
|
||||
tool_config=tc,
|
||||
conversation_id=None,
|
||||
history_len=-1,
|
||||
history=[],
|
||||
stream=False)).body_iterator:
|
||||
x = json.loads(x)
|
||||
pprint(x)
|
||||
|
||||
asyncio.run(test_text2image())
|
||||
|
||||
@ -12,6 +12,7 @@ import re
|
||||
import time
|
||||
from configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, TOOL_CONFIG)
|
||||
from server.callback_handler.agent_callback_handler import AgentStatus
|
||||
from server.utils import MsgType
|
||||
import uuid
|
||||
from typing import List, Dict
|
||||
|
||||
@ -168,7 +169,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
for k, v in config_models.get("online", {}).items():
|
||||
if not v.get("provider") and k not in running_models and k in LLM_MODELS:
|
||||
available_models.append(k)
|
||||
llm_models = running_models + available_models # + ["openai-api"]
|
||||
llm_models = running_models + available_models
|
||||
cur_llm_model = st.session_state.get("cur_llm_model", default_model)
|
||||
if cur_llm_model in llm_models:
|
||||
index = llm_models.index(cur_llm_model)
|
||||
@ -276,7 +277,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
text = ""
|
||||
text_action = ""
|
||||
element_index = 0
|
||||
|
||||
for d in api.chat_chat(query=prompt,
|
||||
metadata=files_upload,
|
||||
history=history,
|
||||
@ -288,6 +288,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
metadata = {
|
||||
"message_id": message_id,
|
||||
}
|
||||
print(d)
|
||||
if d["status"] == AgentStatus.error:
|
||||
st.error(d["text"])
|
||||
elif d["status"] == AgentStatus.agent_action:
|
||||
@ -310,11 +311,14 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
elif d["status"] == AgentStatus.llm_end:
|
||||
chat_box.update_msg(text, streaming=False, element_index=element_index, metadata=metadata)
|
||||
elif d["status"] == AgentStatus.agent_finish:
|
||||
element_index += 1
|
||||
if d["message_type"] == MsgType.IMAGE:
|
||||
for url in json.loads(d["text"]).get("images", []):
|
||||
url = f"{api.base_url}/media/{url}"
|
||||
chat_box.insert_msg(Image(url))
|
||||
chat_box.update_msg(element_index=element_index, expanded=False, state="complete")
|
||||
else:
|
||||
chat_box.insert_msg(Markdown(d["text"], expanded=True))
|
||||
|
||||
# print(d["text"])
|
||||
chat_box.insert_msg(Markdown(d["text"], expanded=True))
|
||||
chat_box.update_msg(Markdown(d["text"]), element_index=element_index)
|
||||
if os.path.exists("tmp/image.jpg"):
|
||||
with open("tmp/image.jpg", "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read()).decode()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user