- 新功能 (#3944)

- streamlit 更新到 1.34,webui 支持 Dialog 操作
    - streamlit-chatbox 更新到 1.1.12,更好的多会话支持
- 开发者
    - 在 API 中增加项目图片路由(/img/{file_name}),方便前端使用
This commit is contained in:
liunux4odoo 2024-05-06 09:09:56 +08:00 committed by GitHub
parent 98cb1aaf79
commit 9fddd07afd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 343 additions and 309 deletions

View File

@ -11,8 +11,11 @@ langchain.verbose = False
# 通常情况下不需要更改以下内容 # 通常情况下不需要更改以下内容
# chatchat 项目根目录
CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent)
# 用户数据根目录 # 用户数据根目录
DATA_PATH = str(Path(__file__).absolute().parent.parent / "data") DATA_PATH = os.path.join(CHATCHAT_ROOT, "data")
if not os.path.exists(DATA_PATH): if not os.path.exists(DATA_PATH):
os.mkdir(DATA_PATH) os.mkdir(DATA_PATH)

View File

@ -19,6 +19,10 @@ PROMPT_TEMPLATES = {
'{history}\n' '{history}\n'
'Human: {input}\n' 'Human: {input}\n'
'AI:', 'AI:',
"rag":
'【指令】根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。\n\n'
'【已知信息】{context}\n\n'
'【问题】{question}\n',
}, },
"action_model": { "action_model": {
"GPT-4": "GPT-4":

View File

@ -16,7 +16,7 @@ from openai.types.file_object import FileObject
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from .api_schemas import * from .api_schemas import *
from chatchat.configs import logger, BASE_TEMP_DIR from chatchat.configs import logger, BASE_TEMP_DIR, log_verbose
from chatchat.server.utils import get_model_info, get_config_platforms, get_OpenAIClient from chatchat.server.utils import get_model_info, get_config_platforms, get_OpenAIClient
@ -126,6 +126,8 @@ async def create_chat_completions(
request: Request, request: Request,
body: OpenAIChatInput, body: OpenAIChatInput,
): ):
if log_verbose:
print(body)
async with get_model_client(body.model) as client: async with get_model_client(body.model) as client:
result = await openai_request(client.chat.completions.create, body) result = await openai_request(client.chat.completions.create, body)
return result return result

View File

@ -1,4 +1,5 @@
import argparse import argparse
import os
from typing import Literal from typing import Literal
from fastapi import FastAPI, Body from fastapi import FastAPI, Body
@ -7,7 +8,7 @@ from fastapi.staticfiles import StaticFiles
from starlette.responses import RedirectResponse from starlette.responses import RedirectResponse
import uvicorn import uvicorn
from chatchat.configs import VERSION, MEDIA_PATH from chatchat.configs import VERSION, MEDIA_PATH, CHATCHAT_ROOT
from chatchat.configs.server_config import OPEN_CROSS_DOMAIN from chatchat.configs.server_config import OPEN_CROSS_DOMAIN
from chatchat.server.api_server.chat_routes import chat_router from chatchat.server.api_server.chat_routes import chat_router
from chatchat.server.api_server.kb_routes import kb_router from chatchat.server.api_server.kb_routes import kb_router
@ -55,6 +56,10 @@ def create_app(run_mode: str=None):
# 媒体文件 # 媒体文件
app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media") app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media")
# 项目相关图片
img_dir = os.path.join(CHATCHAT_ROOT, "img")
app.mount("/img", StaticFiles(directory=img_dir), name="img")
return app return app

View File

@ -202,6 +202,16 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
message_id=message_id, message_id=message_id,
) )
yield ret.model_dump_json() yield ret.model_dump_json()
yield OpenAIChatOutput( # return blank text lastly
id=f"chat{uuid.uuid4()}",
object="chat.completion.chunk",
content="",
role="assistant",
model=models["llm_model"].model_name,
status = data["status"],
message_type = data["message_type"],
message_id=message_id,
)
await task await task
if stream: if stream:

View File

@ -664,9 +664,6 @@ def get_httpx_client(
# construct Client # construct Client
kwargs.update(timeout=timeout, proxies=default_proxies) kwargs.update(timeout=timeout, proxies=default_proxies)
if log_verbose:
logger.info(f'{get_httpx_client.__class__.__name__}:kwargs: {kwargs}')
if use_async: if use_async:
return httpx.AsyncClient(**kwargs) return httpx.AsyncClient(**kwargs)
else: else:

View File

@ -1,30 +1,23 @@
import streamlit as st
# from chatchat.webui_pages.loom_view_client import update_store
# from chatchat.webui_pages.openai_plugins import openai_plugins_page
from chatchat.webui_pages.utils import *
from streamlit_option_menu import option_menu
from chatchat.webui_pages.dialogue.dialogue import dialogue_page, chat_box
from chatchat.webui_pages.knowledge_base.knowledge_base import knowledge_base_page
import os
import sys import sys
import streamlit as st
import streamlit_antd_components as sac
from chatchat.configs import VERSION from chatchat.configs import VERSION
from chatchat.server.utils import api_address from chatchat.server.utils import api_address
from chatchat.webui_pages.utils import *
from chatchat.webui_pages.dialogue.dialogue import dialogue_page, chat_box
from chatchat.webui_pages.knowledge_base.knowledge_base import knowledge_base_page
# def on_change(key):
# if key:
# update_store()
img_dir = os.path.dirname(os.path.abspath(__file__))
api = ApiRequest(base_url=api_address()) api = ApiRequest(base_url=api_address())
if __name__ == "__main__": if __name__ == "__main__":
is_lite = "lite" in sys.argv is_lite = "lite" in sys.argv # TODO: remove lite mode
st.set_page_config( st.set_page_config(
"Langchain-Chatchat WebUI", "Langchain-Chatchat WebUI",
os.path.join(img_dir, "img", "chatchat_icon_blue_square_v2.png"), get_img_url("chatchat_icon_blue_square_v2.png"),
initial_sidebar_state="expanded", initial_sidebar_state="expanded",
menu_items={ menu_items={
'Get Help': 'https://github.com/chatchat-space/Langchain-Chatchat', 'Get Help': 'https://github.com/chatchat-space/Langchain-Chatchat',
@ -32,66 +25,47 @@ if __name__ == "__main__":
'About': f"""欢迎使用 Langchain-Chatchat WebUI {VERSION}""" 'About': f"""欢迎使用 Langchain-Chatchat WebUI {VERSION}"""
}, },
layout="wide" layout="wide"
) )
# use the following code to set the app to wide mode and the html markdown to increase the sidebar width # use the following code to set the app to wide mode and the html markdown to increase the sidebar width
st.markdown( st.markdown(
""" """
<style> <style>
[data-testid="stSidebar"][aria-expanded="true"] > div:first-child{ [data-testid="stSidebarUserContent"] {
width: 350px; padding-top: 20px;
} }
[data-testid="stSidebar"][aria-expanded="false"] > div:first-child{ .block-container {
width: 600px; padding-top: 25px;
margin-left: -600px; }
[data-testid="stBottomBlockContainer"] {
padding-bottom: 20px;
} }
""", """,
unsafe_allow_html=True, unsafe_allow_html=True,
) )
pages = {
"对话": {
"icon": "chat",
"func": dialogue_page,
},
"知识库管理": {
"icon": "hdd-stack",
"func": knowledge_base_page,
},
# "模型服务": {
# "icon": "hdd-stack",
# "func": openai_plugins_page,
# },
}
# 更新状态
# if "status" not in st.session_state \
# or "run_plugins_list" not in st.session_state \
# or "launch_subscribe_info" not in st.session_state \
# or "list_running_models" not in st.session_state \
# or "model_config" not in st.session_state:
# update_store()
with st.sidebar: with st.sidebar:
st.image( st.image(
os.path.join(img_dir, "img", 'logo-long-chatchat-trans-v2.png'), get_img_url('logo-long-chatchat-trans-v2.png'),
use_column_width=True use_column_width=True
) )
st.caption( st.caption(
f"""<p align="right">当前版本:{VERSION}</p>""", f"""<p align="right">当前版本:{VERSION}</p>""",
unsafe_allow_html=True, unsafe_allow_html=True,
) )
options = list(pages)
icons = [x["icon"] for x in pages.values()]
default_index = 0 selected_page = sac.menu(
selected_page = option_menu( [
menu_title="", sac.MenuItem("对话", icon="chat"),
sac.MenuItem("知识库管理", icon="hdd-stack"),
],
key="selected_page", key="selected_page",
options=options, open_index=0
icons=icons,
# menu_icon="chat-quote",
default_index=default_index,
) )
if selected_page in pages: sac.divider()
pages[selected_page]["func"](api=api, is_lite=is_lite)
if selected_page == "知识库管理":
knowledge_base_page(api=api, is_lite=is_lite)
else:
dialogue_page(api=api, is_lite=is_lite)

View File

@ -1,37 +1,46 @@
import base64 import base64
import uuid from datetime import datetime
import os import os
import re import uuid
import time
from typing import List, Dict from typing import List, Dict
import streamlit as st # from audio_recorder_streamlit import audio_recorder
from streamlit_antd_components.utils import ParseItems
import openai import openai
import streamlit as st
import streamlit_antd_components as sac
from streamlit_chatbox import * from streamlit_chatbox import *
from streamlit_modal import Modal from streamlit_extras.bottom_container import bottom
from datetime import datetime
from chatchat.configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, MODEL_PLATFORMS) from chatchat.configs import (LLM_MODEL_CONFIG, TEMPERATURE, MODEL_PLATFORMS, DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL)
from chatchat.server.callback_handler.agent_callback_handler import AgentStatus from chatchat.server.callback_handler.agent_callback_handler import AgentStatus
from chatchat.server.utils import MsgType, get_config_models from chatchat.server.utils import MsgType, get_config_models
from chatchat.server.utils import get_tool_config
from chatchat.webui_pages.utils import * from chatchat.webui_pages.utils import *
from chatchat.webui_pages.dialogue.utils import process_files from chatchat.webui_pages.dialogue.utils import process_files
img_dir = (Path(__file__).absolute().parent.parent.parent)
chat_box = ChatBox( chat_box = ChatBox(
assistant_avatar=os.path.join( assistant_avatar=get_img_url("chatchat_icon_blue_square_v2.png")
img_dir,
"img",
"chatchat_icon_blue_square_v2.png"
)
) )
def save_session():
'''save session state to chat context'''
chat_box.context_from_session(exclude=["selected_page", "prompt", "cur_conv_name"])
def restore_session():
'''restore sesstion state from chat context'''
chat_box.context_to_session(exclude=["selected_page", "prompt", "cur_conv_name"])
def rerun():
'''
save chat context before rerun
'''
save_session()
st.rerun()
def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]: def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
''' '''
返回消息历史 返回消息历史
@ -49,7 +58,10 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) ->
"content": "\n\n".join(content), "content": "\n\n".join(content),
} }
return chat_box.filter_history(history_len=history_len, filter=filter) messages = chat_box.filter_history(history_len=history_len, filter=filter)
if sys_msg := st.session_state.get("system_message"):
messages = [{"role": "system", "content": sys_msg}] + messages
return messages
@st.cache_data @st.cache_data
@ -61,80 +73,100 @@ def upload_temp_docs(files, _api: ApiRequest) -> str:
return _api.upload_temp_docs(files).get("data", {}).get("id") return _api.upload_temp_docs(files).get("data", {}).get("id")
def parse_command(text: str, modal: Modal) -> bool: def add_conv(name: str=""):
'''
检查用户是否输入了自定义命令当前支持
/new {session_name}如果未提供名称默认为会话X
/del {session_name}如果未提供名称在会话数量>1的情况下删除当前会话
/clear {session_name}如果未提供名称默认清除当前会话
/stop {session_name}如果未提供名称默认停止当前会话
/help查看命令帮助
返回值输入的是命令返回True否则返回False
'''
if m := re.match(r"/([^\s]+)\s*(.*)", text):
cmd, name = m.groups()
name = name.strip()
conv_names = chat_box.get_chat_names() conv_names = chat_box.get_chat_names()
if cmd == "help":
modal.open()
elif cmd == "new":
if not name: if not name:
i = 1 i = len(conv_names) + 1
while True: while True:
name = f"会话{i}" name = f"会话{i}"
if name not in conv_names: if name not in conv_names:
break break
i += 1 i += 1
if name in st.session_state["conversation_ids"]: if name in conv_names:
st.error(f"该会话名称 “{name}” 已存在") sac.alert("创建新会话出错", f"该会话名称 “{name}” 已存在", color="error", closable=True)
time.sleep(1)
else: else:
st.session_state["conversation_ids"][name] = uuid.uuid4().hex chat_box.use_chat_name(name)
st.session_state["cur_conv_name"] = name st.session_state["cur_conv_name"] = name
elif cmd == "del":
name = name or st.session_state.get("cur_conv_name")
def del_conv(name: str=None):
conv_names = chat_box.get_chat_names()
name = name or chat_box.cur_chat_name
if len(conv_names) == 1: if len(conv_names) == 1:
st.error("这是最后一个会话,无法删除") sac.alert("删除会话出错", f"这是最后一个会话,无法删除", color="error", closable=True)
time.sleep(1) elif not name or name not in conv_names:
elif not name or name not in st.session_state["conversation_ids"]: sac.alert("删除会话出错", f"无效的会话名称:“{name}", color="error", closable=True)
st.error(f"无效的会话名称:“{name}")
time.sleep(1)
else: else:
st.session_state["conversation_ids"].pop(name, None)
chat_box.del_chat_name(name) chat_box.del_chat_name(name)
st.session_state["cur_conv_name"] = "" restore_session()
elif cmd == "clear": st.session_state["cur_conv_name"] = chat_box.cur_chat_name
def clear_conv(name: str=None):
chat_box.reset_history(name=name or None) chat_box.reset_history(name=name or None)
return True
return False
def dialogue_page(api: ApiRequest, is_lite: bool = False): @st.cache_data
st.session_state.setdefault("conversation_ids", {}) def list_tools(_api: ApiRequest):
st.session_state["conversation_ids"].setdefault(chat_box.cur_chat_name, uuid.uuid4().hex) return _api.list_tools()
st.session_state.setdefault("file_chat_id", None)
# 弹出自定义命令帮助信息
modal = Modal("自定义命令", key="cmd_help", max_width="500") def dialogue_page(
if modal.is_open(): api: ApiRequest,
with modal.container(): is_lite: bool = False,
cmds = [x for x in parse_command.__doc__.split("\n") if x.strip().startswith("/")] ):
st.write("\n\n".join(cmds)) ctx = chat_box.context
ctx.setdefault("uid", uuid.uuid4().hex)
ctx.setdefault("file_chat_id", None)
ctx.setdefault("llm_model", DEFAULT_LLM_MODEL)
ctx.setdefault("temperature", TEMPERATURE)
st.session_state.setdefault("cur_conv_name", chat_box.cur_chat_name)
restore_session()
# st.write(chat_box.cur_chat_name)
# st.write(st.session_state)
@st.experimental_dialog("模型配置", width="large")
def llm_model_setting():
# 模型
cols = st.columns(3)
platforms = ["所有"] + [x["platform_name"] for x in MODEL_PLATFORMS]
platform = cols[0].selectbox("选择模型平台", platforms, key="platform")
llm_models = list(get_config_models(model_type="llm", platform_name=None if platform=="所有" else platform))
llm_model = cols[1].selectbox("选择LLM模型", llm_models, key="llm_model")
temperature = cols[2].slider("Temperature", 0.0, 1.0, key="temperature")
system_message = st.text_area("System Message:", key="system_message")
if st.button("OK"):
rerun()
@st.experimental_dialog("重命名会话")
def rename_conversation():
name = st.text_input("会话名称")
if st.button("OK"):
chat_box.change_chat_name(name)
restore_session()
st.session_state["cur_conv_name"] = name
rerun()
with st.sidebar: with st.sidebar:
tab1, tab2 = st.tabs(["对话设置", "模型设置"]) tab1, tab2 = st.tabs(["工具设置", "会话设置"])
with tab1: with tab1:
use_agent = st.checkbox("启用Agent", True, help="请确保选择的模型具备Agent能力") use_agent = st.checkbox("启用Agent", help="请确保选择的模型具备Agent能力", key="use_agent")
# 选择工具 # 选择工具
tools = api.list_tools() tools = list_tools(api)
tool_names = ["None"] + list(tools)
if use_agent: if use_agent:
selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"]) # selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", check_all=True, key="selected_tools")
selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"], key="selected_tools")
else: else:
selected_tool = st.selectbox("选择工具", list(tools), format_func=lambda x: tools[x]["title"]) # selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", key="selected_tool")
selected_tool = st.selectbox("选择工具", tool_names, format_func=lambda x: tools.get(x,{"title": "None"})["title"], key="selected_tool")
selected_tools = [selected_tool] selected_tools = [selected_tool]
selected_tool_configs = {name: tool["config"] for name, tool in tools.items() if name in selected_tools} selected_tool_configs = {name: tool["config"] for name, tool in tools.items() if name in selected_tools}
if "None" in selected_tools:
selected_tools.remove("None")
# 当不启用Agent时手动生成工具参数 # 当不启用Agent时手动生成工具参数
# TODO: 需要更精细的控制控件 # TODO: 需要更精细的控制控件
tool_input = {} tool_input = {}
@ -151,35 +183,22 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
else: else:
tool_input[k] = st.text_input(v["title"], v.get("default")) tool_input[k] = st.text_input(v["title"], v.get("default"))
uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False) uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
files_upload = process_files(files=[uploaded_file]) if uploaded_file else None files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
with tab2: with tab2:
# 会话 # 会话
conv_names = list(st.session_state["conversation_ids"].keys()) cols = st.columns(3)
index = 0 conv_names = chat_box.get_chat_names()
if st.session_state.get("cur_conv_name") in conv_names: conversation_name = sac.buttons(conv_names, label="当前会话:", key="cur_conv_name")
index = conv_names.index(st.session_state.get("cur_conv_name"))
conversation_name = st.selectbox("当前会话", conv_names, index=index)
chat_box.use_chat_name(conversation_name) chat_box.use_chat_name(conversation_name)
conversation_id = st.session_state["conversation_ids"][conversation_name] conversation_id = chat_box.context["uid"]
if cols[0].button("新建", on_click=add_conv):
# 模型 ...
platforms = ["所有"] + [x["platform_name"] for x in MODEL_PLATFORMS] if cols[1].button("重命名"):
platform = st.selectbox("选择模型平台", platforms) rename_conversation()
llm_models = list(get_config_models(model_type="llm", platform_name=None if platform=="所有" else platform)) if cols[2].button("删除", on_click=del_conv):
llm_model = st.selectbox("选择LLM模型", llm_models) ...
# 传入后端的内容
chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
for key in LLM_MODEL_CONFIG:
if LLM_MODEL_CONFIG[key]:
first_key = next(iter(LLM_MODEL_CONFIG[key]))
chat_model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
if llm_model is not None:
chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})
# Display chat messages from history on app rerun # Display chat messages from history on app rerun
chat_box.output_messages() chat_box.output_messages()
@ -203,10 +222,27 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
# "optional_text_label": "欢迎反馈您打分的理由", # "optional_text_label": "欢迎反馈您打分的理由",
# } # }
if prompt := st.chat_input(chat_input_placeholder, key="prompt"): # 传入后端的内容
if parse_command(text=prompt, modal=modal): chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
st.rerun() for key in LLM_MODEL_CONFIG:
else: if LLM_MODEL_CONFIG[key]:
first_key = next(iter(LLM_MODEL_CONFIG[key]))
chat_model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
llm_model = ctx.get("llm_model")
if llm_model is not None:
chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})
# chat input
with bottom():
cols = st.columns([1, 1, 15])
if cols[0].button(":atom_symbol:"):
widget_keys = ["platform", "llm_model", "temperature", "system_message"]
chat_box.context_to_session(include=widget_keys)
llm_model_setting()
# with cols[1]:
# mic_audio = audio_recorder("", icon_size="2x", key="mic_audio")
prompt = cols[2].chat_input(chat_input_placeholder, key="prompt")
if prompt:
history = get_messages_history( history = get_messages_history(
chat_model_config["llm_model"].get(next(iter(chat_model_config["llm_model"])), {}).get("history_len", 1) chat_model_config["llm_model"].get(next(iter(chat_model_config["llm_model"])), {}).get("history_len", 1)
) )
@ -343,12 +379,9 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
# chat_box.update_msg(text, element_index=0) # chat_box.update_msg(text, element_index=0)
# chat_box.update_msg(text, element_index=0, streaming=False) # chat_box.update_msg(text, element_index=0, streaming=False)
# chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False) # chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
if st.session_state.get("need_rerun"):
st.session_state["need_rerun"] = False
st.rerun()
now = datetime.now() now = datetime.now()
with tab1: with tab2:
cols = st.columns(2) cols = st.columns(2)
export_btn = cols[0] export_btn = cols[0]
if cols[1].button( if cols[1].button(
@ -356,9 +389,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
use_container_width=True, use_container_width=True,
): ):
chat_box.reset_history() chat_box.reset_history()
st.rerun() rerun()
warning_placeholder = st.empty()
export_btn.download_button( export_btn.download_button(
"导出记录", "导出记录",

View File

@ -690,6 +690,14 @@ def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
return "" return ""
def get_img_url(file_name: str) -> str:
'''
image url used in streamlit.
absolute local path not working on windows.
'''
return f"{api_address()}/img/{file_name}"
if __name__ == "__main__": if __name__ == "__main__":
api = ApiRequest() api = ApiRequest()
aapi = AsyncApiRequest() aapi = AsyncApiRequest()

View File

@ -49,9 +49,9 @@ python-multipart==0.0.9
# WebUI requirements # WebUI requirements
streamlit==1.30.0 streamlit==1.34.0
streamlit-option-menu==0.3.12 streamlit-antd-components==0.3.2
streamlit-antd-components==0.3.1 streamlit-chatbox==1.1.12
streamlit-chatbox==1.1.11
streamlit-modal==0.1.0
streamlit-aggrid==0.3.4.post3 streamlit-aggrid==0.3.4.post3
streamlit-extras==0.4.2
# audio-recorder-streamlit==0.0.8