mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-03-21 09:36:39 +08:00
- 新功能 (#3944)
- streamlit 更新到 1.34,webui 支持 Dialog 操作
- streamlit-chatbox 更新到 1.1.12,更好的多会话支持
- 开发者
- 在 API 中增加项目图片路由(/img/{file_name}),方便前端使用
This commit is contained in:
parent
98cb1aaf79
commit
9fddd07afd
@ -11,8 +11,11 @@ langchain.verbose = False
|
||||
|
||||
# 通常情况下不需要更改以下内容
|
||||
|
||||
# chatchat 项目根目录
|
||||
CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent)
|
||||
|
||||
# 用户数据根目录
|
||||
DATA_PATH = str(Path(__file__).absolute().parent.parent / "data")
|
||||
DATA_PATH = os.path.join(CHATCHAT_ROOT, "data")
|
||||
if not os.path.exists(DATA_PATH):
|
||||
os.mkdir(DATA_PATH)
|
||||
|
||||
|
||||
@ -19,6 +19,10 @@ PROMPT_TEMPLATES = {
|
||||
'{history}\n'
|
||||
'Human: {input}\n'
|
||||
'AI:',
|
||||
"rag":
|
||||
'【指令】根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。\n\n'
|
||||
'【已知信息】{context}\n\n'
|
||||
'【问题】{question}\n',
|
||||
},
|
||||
"action_model": {
|
||||
"GPT-4":
|
||||
|
||||
@ -16,7 +16,7 @@ from openai.types.file_object import FileObject
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from .api_schemas import *
|
||||
from chatchat.configs import logger, BASE_TEMP_DIR
|
||||
from chatchat.configs import logger, BASE_TEMP_DIR, log_verbose
|
||||
from chatchat.server.utils import get_model_info, get_config_platforms, get_OpenAIClient
|
||||
|
||||
|
||||
@ -126,6 +126,8 @@ async def create_chat_completions(
|
||||
request: Request,
|
||||
body: OpenAIChatInput,
|
||||
):
|
||||
if log_verbose:
|
||||
print(body)
|
||||
async with get_model_client(body.model) as client:
|
||||
result = await openai_request(client.chat.completions.create, body)
|
||||
return result
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import FastAPI, Body
|
||||
@ -7,7 +8,7 @@ from fastapi.staticfiles import StaticFiles
|
||||
from starlette.responses import RedirectResponse
|
||||
import uvicorn
|
||||
|
||||
from chatchat.configs import VERSION, MEDIA_PATH
|
||||
from chatchat.configs import VERSION, MEDIA_PATH, CHATCHAT_ROOT
|
||||
from chatchat.configs.server_config import OPEN_CROSS_DOMAIN
|
||||
from chatchat.server.api_server.chat_routes import chat_router
|
||||
from chatchat.server.api_server.kb_routes import kb_router
|
||||
@ -55,6 +56,10 @@ def create_app(run_mode: str=None):
|
||||
# 媒体文件
|
||||
app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media")
|
||||
|
||||
# 项目相关图片
|
||||
img_dir = os.path.join(CHATCHAT_ROOT, "img")
|
||||
app.mount("/img", StaticFiles(directory=img_dir), name="img")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
||||
@ -202,6 +202,16 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
|
||||
message_id=message_id,
|
||||
)
|
||||
yield ret.model_dump_json()
|
||||
yield OpenAIChatOutput( # return blank text lastly
|
||||
id=f"chat{uuid.uuid4()}",
|
||||
object="chat.completion.chunk",
|
||||
content="",
|
||||
role="assistant",
|
||||
model=models["llm_model"].model_name,
|
||||
status = data["status"],
|
||||
message_type = data["message_type"],
|
||||
message_id=message_id,
|
||||
)
|
||||
await task
|
||||
|
||||
if stream:
|
||||
|
||||
@ -664,9 +664,6 @@ def get_httpx_client(
|
||||
# construct Client
|
||||
kwargs.update(timeout=timeout, proxies=default_proxies)
|
||||
|
||||
if log_verbose:
|
||||
logger.info(f'{get_httpx_client.__class__.__name__}:kwargs: {kwargs}')
|
||||
|
||||
if use_async:
|
||||
return httpx.AsyncClient(**kwargs)
|
||||
else:
|
||||
|
||||
@ -1,30 +1,23 @@
|
||||
import streamlit as st
|
||||
|
||||
# from chatchat.webui_pages.loom_view_client import update_store
|
||||
# from chatchat.webui_pages.openai_plugins import openai_plugins_page
|
||||
from chatchat.webui_pages.utils import *
|
||||
from streamlit_option_menu import option_menu
|
||||
from chatchat.webui_pages.dialogue.dialogue import dialogue_page, chat_box
|
||||
from chatchat.webui_pages.knowledge_base.knowledge_base import knowledge_base_page
|
||||
import os
|
||||
import sys
|
||||
|
||||
import streamlit as st
|
||||
import streamlit_antd_components as sac
|
||||
|
||||
from chatchat.configs import VERSION
|
||||
from chatchat.server.utils import api_address
|
||||
from chatchat.webui_pages.utils import *
|
||||
from chatchat.webui_pages.dialogue.dialogue import dialogue_page, chat_box
|
||||
from chatchat.webui_pages.knowledge_base.knowledge_base import knowledge_base_page
|
||||
|
||||
|
||||
# def on_change(key):
|
||||
# if key:
|
||||
# update_store()
|
||||
img_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
api = ApiRequest(base_url=api_address())
|
||||
|
||||
if __name__ == "__main__":
|
||||
is_lite = "lite" in sys.argv
|
||||
is_lite = "lite" in sys.argv # TODO: remove lite mode
|
||||
|
||||
st.set_page_config(
|
||||
"Langchain-Chatchat WebUI",
|
||||
os.path.join(img_dir, "img", "chatchat_icon_blue_square_v2.png"),
|
||||
get_img_url("chatchat_icon_blue_square_v2.png"),
|
||||
initial_sidebar_state="expanded",
|
||||
menu_items={
|
||||
'Get Help': 'https://github.com/chatchat-space/Langchain-Chatchat',
|
||||
@ -32,66 +25,47 @@ if __name__ == "__main__":
|
||||
'About': f"""欢迎使用 Langchain-Chatchat WebUI {VERSION}!"""
|
||||
},
|
||||
layout="wide"
|
||||
|
||||
)
|
||||
|
||||
# use the following code to set the app to wide mode and the html markdown to increase the sidebar width
|
||||
st.markdown(
|
||||
"""
|
||||
<style>
|
||||
[data-testid="stSidebar"][aria-expanded="true"] > div:first-child{
|
||||
width: 350px;
|
||||
[data-testid="stSidebarUserContent"] {
|
||||
padding-top: 20px;
|
||||
}
|
||||
[data-testid="stSidebar"][aria-expanded="false"] > div:first-child{
|
||||
width: 600px;
|
||||
margin-left: -600px;
|
||||
.block-container {
|
||||
padding-top: 25px;
|
||||
}
|
||||
[data-testid="stBottomBlockContainer"] {
|
||||
padding-bottom: 20px;
|
||||
}
|
||||
|
||||
""",
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
pages = {
|
||||
"对话": {
|
||||
"icon": "chat",
|
||||
"func": dialogue_page,
|
||||
},
|
||||
"知识库管理": {
|
||||
"icon": "hdd-stack",
|
||||
"func": knowledge_base_page,
|
||||
},
|
||||
# "模型服务": {
|
||||
# "icon": "hdd-stack",
|
||||
# "func": openai_plugins_page,
|
||||
# },
|
||||
}
|
||||
# 更新状态
|
||||
# if "status" not in st.session_state \
|
||||
# or "run_plugins_list" not in st.session_state \
|
||||
# or "launch_subscribe_info" not in st.session_state \
|
||||
# or "list_running_models" not in st.session_state \
|
||||
# or "model_config" not in st.session_state:
|
||||
# update_store()
|
||||
|
||||
with st.sidebar:
|
||||
st.image(
|
||||
os.path.join(img_dir, "img", 'logo-long-chatchat-trans-v2.png'),
|
||||
get_img_url('logo-long-chatchat-trans-v2.png'),
|
||||
use_column_width=True
|
||||
)
|
||||
st.caption(
|
||||
f"""<p align="right">当前版本:{VERSION}</p>""",
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
options = list(pages)
|
||||
icons = [x["icon"] for x in pages.values()]
|
||||
|
||||
default_index = 0
|
||||
selected_page = option_menu(
|
||||
menu_title="",
|
||||
selected_page = sac.menu(
|
||||
[
|
||||
sac.MenuItem("对话", icon="chat"),
|
||||
sac.MenuItem("知识库管理", icon="hdd-stack"),
|
||||
],
|
||||
key="selected_page",
|
||||
options=options,
|
||||
icons=icons,
|
||||
# menu_icon="chat-quote",
|
||||
default_index=default_index,
|
||||
open_index=0
|
||||
)
|
||||
|
||||
if selected_page in pages:
|
||||
pages[selected_page]["func"](api=api, is_lite=is_lite)
|
||||
sac.divider()
|
||||
|
||||
if selected_page == "知识库管理":
|
||||
knowledge_base_page(api=api, is_lite=is_lite)
|
||||
else:
|
||||
dialogue_page(api=api, is_lite=is_lite)
|
||||
|
||||
@ -1,37 +1,46 @@
|
||||
import base64
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import List, Dict
|
||||
|
||||
import streamlit as st
|
||||
from streamlit_antd_components.utils import ParseItems
|
||||
|
||||
# from audio_recorder_streamlit import audio_recorder
|
||||
import openai
|
||||
import streamlit as st
|
||||
import streamlit_antd_components as sac
|
||||
from streamlit_chatbox import *
|
||||
from streamlit_modal import Modal
|
||||
from datetime import datetime
|
||||
from streamlit_extras.bottom_container import bottom
|
||||
|
||||
from chatchat.configs import (LLM_MODEL_CONFIG, SUPPORT_AGENT_MODELS, MODEL_PLATFORMS)
|
||||
from chatchat.configs import (LLM_MODEL_CONFIG, TEMPERATURE, MODEL_PLATFORMS, DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL)
|
||||
from chatchat.server.callback_handler.agent_callback_handler import AgentStatus
|
||||
from chatchat.server.utils import MsgType, get_config_models
|
||||
from chatchat.server.utils import get_tool_config
|
||||
from chatchat.webui_pages.utils import *
|
||||
from chatchat.webui_pages.dialogue.utils import process_files
|
||||
|
||||
|
||||
img_dir = (Path(__file__).absolute().parent.parent.parent)
|
||||
|
||||
chat_box = ChatBox(
|
||||
assistant_avatar=os.path.join(
|
||||
img_dir,
|
||||
"img",
|
||||
"chatchat_icon_blue_square_v2.png"
|
||||
)
|
||||
assistant_avatar=get_img_url("chatchat_icon_blue_square_v2.png")
|
||||
)
|
||||
|
||||
|
||||
def save_session():
|
||||
'''save session state to chat context'''
|
||||
chat_box.context_from_session(exclude=["selected_page", "prompt", "cur_conv_name"])
|
||||
|
||||
|
||||
def restore_session():
|
||||
'''restore sesstion state from chat context'''
|
||||
chat_box.context_to_session(exclude=["selected_page", "prompt", "cur_conv_name"])
|
||||
|
||||
|
||||
def rerun():
|
||||
'''
|
||||
save chat context before rerun
|
||||
'''
|
||||
save_session()
|
||||
st.rerun()
|
||||
|
||||
|
||||
def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
|
||||
'''
|
||||
返回消息历史。
|
||||
@ -49,7 +58,10 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) ->
|
||||
"content": "\n\n".join(content),
|
||||
}
|
||||
|
||||
return chat_box.filter_history(history_len=history_len, filter=filter)
|
||||
messages = chat_box.filter_history(history_len=history_len, filter=filter)
|
||||
if sys_msg := st.session_state.get("system_message"):
|
||||
messages = [{"role": "system", "content": sys_msg}] + messages
|
||||
return messages
|
||||
|
||||
|
||||
@st.cache_data
|
||||
@ -61,80 +73,100 @@ def upload_temp_docs(files, _api: ApiRequest) -> str:
|
||||
return _api.upload_temp_docs(files).get("data", {}).get("id")
|
||||
|
||||
|
||||
def parse_command(text: str, modal: Modal) -> bool:
|
||||
'''
|
||||
检查用户是否输入了自定义命令,当前支持:
|
||||
/new {session_name}。如果未提供名称,默认为“会话X”
|
||||
/del {session_name}。如果未提供名称,在会话数量>1的情况下,删除当前会话。
|
||||
/clear {session_name}。如果未提供名称,默认清除当前会话
|
||||
/stop {session_name}。如果未提供名称,默认停止当前会话
|
||||
/help。查看命令帮助
|
||||
返回值:输入的是命令返回True,否则返回False
|
||||
'''
|
||||
if m := re.match(r"/([^\s]+)\s*(.*)", text):
|
||||
cmd, name = m.groups()
|
||||
name = name.strip()
|
||||
conv_names = chat_box.get_chat_names()
|
||||
if cmd == "help":
|
||||
modal.open()
|
||||
elif cmd == "new":
|
||||
if not name:
|
||||
i = 1
|
||||
while True:
|
||||
name = f"会话{i}"
|
||||
if name not in conv_names:
|
||||
break
|
||||
i += 1
|
||||
if name in st.session_state["conversation_ids"]:
|
||||
st.error(f"该会话名称 “{name}” 已存在")
|
||||
time.sleep(1)
|
||||
else:
|
||||
st.session_state["conversation_ids"][name] = uuid.uuid4().hex
|
||||
st.session_state["cur_conv_name"] = name
|
||||
elif cmd == "del":
|
||||
name = name or st.session_state.get("cur_conv_name")
|
||||
if len(conv_names) == 1:
|
||||
st.error("这是最后一个会话,无法删除")
|
||||
time.sleep(1)
|
||||
elif not name or name not in st.session_state["conversation_ids"]:
|
||||
st.error(f"无效的会话名称:“{name}”")
|
||||
time.sleep(1)
|
||||
else:
|
||||
st.session_state["conversation_ids"].pop(name, None)
|
||||
chat_box.del_chat_name(name)
|
||||
st.session_state["cur_conv_name"] = ""
|
||||
elif cmd == "clear":
|
||||
chat_box.reset_history(name=name or None)
|
||||
return True
|
||||
return False
|
||||
def add_conv(name: str=""):
|
||||
conv_names = chat_box.get_chat_names()
|
||||
if not name:
|
||||
i = len(conv_names) + 1
|
||||
while True:
|
||||
name = f"会话{i}"
|
||||
if name not in conv_names:
|
||||
break
|
||||
i += 1
|
||||
if name in conv_names:
|
||||
sac.alert("创建新会话出错", f"该会话名称 “{name}” 已存在", color="error", closable=True)
|
||||
else:
|
||||
chat_box.use_chat_name(name)
|
||||
st.session_state["cur_conv_name"] = name
|
||||
|
||||
|
||||
def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
st.session_state.setdefault("conversation_ids", {})
|
||||
st.session_state["conversation_ids"].setdefault(chat_box.cur_chat_name, uuid.uuid4().hex)
|
||||
st.session_state.setdefault("file_chat_id", None)
|
||||
def del_conv(name: str=None):
|
||||
conv_names = chat_box.get_chat_names()
|
||||
name = name or chat_box.cur_chat_name
|
||||
if len(conv_names) == 1:
|
||||
sac.alert("删除会话出错", f"这是最后一个会话,无法删除", color="error", closable=True)
|
||||
elif not name or name not in conv_names:
|
||||
sac.alert("删除会话出错", f"无效的会话名称:“{name}”", color="error", closable=True)
|
||||
else:
|
||||
chat_box.del_chat_name(name)
|
||||
restore_session()
|
||||
st.session_state["cur_conv_name"] = chat_box.cur_chat_name
|
||||
|
||||
# 弹出自定义命令帮助信息
|
||||
modal = Modal("自定义命令", key="cmd_help", max_width="500")
|
||||
if modal.is_open():
|
||||
with modal.container():
|
||||
cmds = [x for x in parse_command.__doc__.split("\n") if x.strip().startswith("/")]
|
||||
st.write("\n\n".join(cmds))
|
||||
|
||||
def clear_conv(name: str=None):
|
||||
chat_box.reset_history(name=name or None)
|
||||
|
||||
|
||||
@st.cache_data
|
||||
def list_tools(_api: ApiRequest):
|
||||
return _api.list_tools()
|
||||
|
||||
|
||||
def dialogue_page(
|
||||
api: ApiRequest,
|
||||
is_lite: bool = False,
|
||||
):
|
||||
ctx = chat_box.context
|
||||
ctx.setdefault("uid", uuid.uuid4().hex)
|
||||
ctx.setdefault("file_chat_id", None)
|
||||
ctx.setdefault("llm_model", DEFAULT_LLM_MODEL)
|
||||
ctx.setdefault("temperature", TEMPERATURE)
|
||||
st.session_state.setdefault("cur_conv_name", chat_box.cur_chat_name)
|
||||
restore_session()
|
||||
|
||||
# st.write(chat_box.cur_chat_name)
|
||||
# st.write(st.session_state)
|
||||
|
||||
@st.experimental_dialog("模型配置", width="large")
|
||||
def llm_model_setting():
|
||||
# 模型
|
||||
cols = st.columns(3)
|
||||
platforms = ["所有"] + [x["platform_name"] for x in MODEL_PLATFORMS]
|
||||
platform = cols[0].selectbox("选择模型平台", platforms, key="platform")
|
||||
llm_models = list(get_config_models(model_type="llm", platform_name=None if platform=="所有" else platform))
|
||||
llm_model = cols[1].selectbox("选择LLM模型", llm_models, key="llm_model")
|
||||
temperature = cols[2].slider("Temperature", 0.0, 1.0, key="temperature")
|
||||
system_message = st.text_area("System Message:", key="system_message")
|
||||
if st.button("OK"):
|
||||
rerun()
|
||||
|
||||
@st.experimental_dialog("重命名会话")
|
||||
def rename_conversation():
|
||||
name = st.text_input("会话名称")
|
||||
if st.button("OK"):
|
||||
chat_box.change_chat_name(name)
|
||||
restore_session()
|
||||
st.session_state["cur_conv_name"] = name
|
||||
rerun()
|
||||
|
||||
with st.sidebar:
|
||||
tab1, tab2 = st.tabs(["对话设置", "模型设置"])
|
||||
tab1, tab2 = st.tabs(["工具设置", "会话设置"])
|
||||
|
||||
with tab1:
|
||||
use_agent = st.checkbox("启用Agent", True, help="请确保选择的模型具备Agent能力")
|
||||
use_agent = st.checkbox("启用Agent", help="请确保选择的模型具备Agent能力", key="use_agent")
|
||||
# 选择工具
|
||||
tools = api.list_tools()
|
||||
tools = list_tools(api)
|
||||
tool_names = ["None"] + list(tools)
|
||||
if use_agent:
|
||||
selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"])
|
||||
# selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", check_all=True, key="selected_tools")
|
||||
selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"], key="selected_tools")
|
||||
else:
|
||||
selected_tool = st.selectbox("选择工具", list(tools), format_func=lambda x: tools[x]["title"])
|
||||
# selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", key="selected_tool")
|
||||
selected_tool = st.selectbox("选择工具", tool_names, format_func=lambda x: tools.get(x,{"title": "None"})["title"], key="selected_tool")
|
||||
selected_tools = [selected_tool]
|
||||
selected_tool_configs = {name: tool["config"] for name, tool in tools.items() if name in selected_tools}
|
||||
|
||||
if "None" in selected_tools:
|
||||
selected_tools.remove("None")
|
||||
# 当不启用Agent时,手动生成工具参数
|
||||
# TODO: 需要更精细的控制控件
|
||||
tool_input = {}
|
||||
@ -151,35 +183,22 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
else:
|
||||
tool_input[k] = st.text_input(v["title"], v.get("default"))
|
||||
|
||||
|
||||
uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
|
||||
files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
|
||||
|
||||
|
||||
with tab2:
|
||||
# 会话
|
||||
conv_names = list(st.session_state["conversation_ids"].keys())
|
||||
index = 0
|
||||
if st.session_state.get("cur_conv_name") in conv_names:
|
||||
index = conv_names.index(st.session_state.get("cur_conv_name"))
|
||||
conversation_name = st.selectbox("当前会话", conv_names, index=index)
|
||||
cols = st.columns(3)
|
||||
conv_names = chat_box.get_chat_names()
|
||||
conversation_name = sac.buttons(conv_names, label="当前会话:", key="cur_conv_name")
|
||||
chat_box.use_chat_name(conversation_name)
|
||||
conversation_id = st.session_state["conversation_ids"][conversation_name]
|
||||
|
||||
# 模型
|
||||
platforms = ["所有"] + [x["platform_name"] for x in MODEL_PLATFORMS]
|
||||
platform = st.selectbox("选择模型平台", platforms)
|
||||
llm_models = list(get_config_models(model_type="llm", platform_name=None if platform=="所有" else platform))
|
||||
llm_model = st.selectbox("选择LLM模型", llm_models)
|
||||
|
||||
# 传入后端的内容
|
||||
chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
|
||||
for key in LLM_MODEL_CONFIG:
|
||||
if LLM_MODEL_CONFIG[key]:
|
||||
first_key = next(iter(LLM_MODEL_CONFIG[key]))
|
||||
chat_model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
|
||||
|
||||
if llm_model is not None:
|
||||
chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})
|
||||
conversation_id = chat_box.context["uid"]
|
||||
if cols[0].button("新建", on_click=add_conv):
|
||||
...
|
||||
if cols[1].button("重命名"):
|
||||
rename_conversation()
|
||||
if cols[2].button("删除", on_click=del_conv):
|
||||
...
|
||||
|
||||
# Display chat messages from history on app rerun
|
||||
chat_box.output_messages()
|
||||
@ -203,152 +222,166 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
# "optional_text_label": "欢迎反馈您打分的理由",
|
||||
# }
|
||||
|
||||
if prompt := st.chat_input(chat_input_placeholder, key="prompt"):
|
||||
if parse_command(text=prompt, modal=modal):
|
||||
st.rerun()
|
||||
# 传入后端的内容
|
||||
chat_model_config = {key: {} for key in LLM_MODEL_CONFIG.keys()}
|
||||
for key in LLM_MODEL_CONFIG:
|
||||
if LLM_MODEL_CONFIG[key]:
|
||||
first_key = next(iter(LLM_MODEL_CONFIG[key]))
|
||||
chat_model_config[key][first_key] = LLM_MODEL_CONFIG[key][first_key]
|
||||
llm_model = ctx.get("llm_model")
|
||||
if llm_model is not None:
|
||||
chat_model_config['llm_model'][llm_model] = LLM_MODEL_CONFIG['llm_model'].get(llm_model, {})
|
||||
|
||||
# chat input
|
||||
with bottom():
|
||||
cols = st.columns([1, 1, 15])
|
||||
if cols[0].button(":atom_symbol:"):
|
||||
widget_keys = ["platform", "llm_model", "temperature", "system_message"]
|
||||
chat_box.context_to_session(include=widget_keys)
|
||||
llm_model_setting()
|
||||
# with cols[1]:
|
||||
# mic_audio = audio_recorder("", icon_size="2x", key="mic_audio")
|
||||
prompt = cols[2].chat_input(chat_input_placeholder, key="prompt")
|
||||
if prompt:
|
||||
history = get_messages_history(
|
||||
chat_model_config["llm_model"].get(next(iter(chat_model_config["llm_model"])), {}).get("history_len", 1)
|
||||
)
|
||||
chat_box.user_say(prompt)
|
||||
if files_upload:
|
||||
if files_upload["images"]:
|
||||
st.markdown(f'<img src="data:image/jpeg;base64,{files_upload["images"][0]}" width="300">',
|
||||
unsafe_allow_html=True)
|
||||
elif files_upload["videos"]:
|
||||
st.markdown(
|
||||
f'<video width="400" height="300" controls><source src="data:video/mp4;base64,{files_upload["videos"][0]}" type="video/mp4"></video>',
|
||||
unsafe_allow_html=True)
|
||||
elif files_upload["audios"]:
|
||||
st.markdown(
|
||||
f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>',
|
||||
unsafe_allow_html=True)
|
||||
|
||||
chat_box.ai_say("正在思考...")
|
||||
text = ""
|
||||
started = False
|
||||
|
||||
client = openai.Client(base_url=f"{api_address()}/chat", api_key="NONE")
|
||||
messages = history + [{"role": "user", "content": prompt}]
|
||||
tools = list(selected_tool_configs)
|
||||
if len(selected_tools) == 1:
|
||||
tool_choice = selected_tools[0]
|
||||
else:
|
||||
history = get_messages_history(
|
||||
chat_model_config["llm_model"].get(next(iter(chat_model_config["llm_model"])), {}).get("history_len", 1)
|
||||
)
|
||||
chat_box.user_say(prompt)
|
||||
if files_upload:
|
||||
if files_upload["images"]:
|
||||
st.markdown(f'<img src="data:image/jpeg;base64,{files_upload["images"][0]}" width="300">',
|
||||
unsafe_allow_html=True)
|
||||
elif files_upload["videos"]:
|
||||
st.markdown(
|
||||
f'<video width="400" height="300" controls><source src="data:video/mp4;base64,{files_upload["videos"][0]}" type="video/mp4"></video>',
|
||||
unsafe_allow_html=True)
|
||||
elif files_upload["audios"]:
|
||||
st.markdown(
|
||||
f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>',
|
||||
unsafe_allow_html=True)
|
||||
tool_choice = None
|
||||
# 如果 tool_input 中有空的字段,设为用户输入
|
||||
for k in tool_input:
|
||||
if tool_input[k] in [None, ""]:
|
||||
tool_input[k] = prompt
|
||||
|
||||
chat_box.ai_say("正在思考...")
|
||||
text = ""
|
||||
started = False
|
||||
extra_body = dict(
|
||||
metadata=files_upload,
|
||||
chat_model_config=chat_model_config,
|
||||
conversation_id=conversation_id,
|
||||
tool_input = tool_input,
|
||||
)
|
||||
for d in client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model,
|
||||
stream=True,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
extra_body=extra_body,
|
||||
):
|
||||
# print("\n\n", d.status, "\n", d, "\n\n")
|
||||
message_id = d.message_id
|
||||
metadata = {
|
||||
"message_id": message_id,
|
||||
}
|
||||
|
||||
client = openai.Client(base_url=f"{api_address()}/chat", api_key="NONE")
|
||||
messages = history + [{"role": "user", "content": prompt}]
|
||||
tools = list(selected_tool_configs)
|
||||
if len(selected_tools) == 1:
|
||||
tool_choice = selected_tools[0]
|
||||
else:
|
||||
tool_choice = None
|
||||
# 如果 tool_input 中有空的字段,设为用户输入
|
||||
for k in tool_input:
|
||||
if tool_input[k] in [None, ""]:
|
||||
tool_input[k] = prompt
|
||||
# clear initial message
|
||||
if not started:
|
||||
chat_box.update_msg("", streaming=False)
|
||||
started = True
|
||||
|
||||
extra_body = dict(
|
||||
metadata=files_upload,
|
||||
chat_model_config=chat_model_config,
|
||||
conversation_id=conversation_id,
|
||||
tool_input = tool_input,
|
||||
)
|
||||
for d in client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=llm_model,
|
||||
stream=True,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
extra_body=extra_body,
|
||||
):
|
||||
# print("\n\n", d.status, "\n", d, "\n\n")
|
||||
message_id = d.message_id
|
||||
metadata = {
|
||||
"message_id": message_id,
|
||||
}
|
||||
|
||||
# clear initial message
|
||||
if not started:
|
||||
chat_box.update_msg("", streaming=False)
|
||||
started = True
|
||||
|
||||
if d.status == AgentStatus.error:
|
||||
st.error(d.choices[0].delta.content)
|
||||
elif d.status == AgentStatus.llm_start:
|
||||
chat_box.insert_msg("正在解读工具输出结果...")
|
||||
text = d.choices[0].delta.content or ""
|
||||
elif d.status == AgentStatus.llm_new_token:
|
||||
if d.status == AgentStatus.error:
|
||||
st.error(d.choices[0].delta.content)
|
||||
elif d.status == AgentStatus.llm_start:
|
||||
chat_box.insert_msg("正在解读工具输出结果...")
|
||||
text = d.choices[0].delta.content or ""
|
||||
elif d.status == AgentStatus.llm_new_token:
|
||||
text += d.choices[0].delta.content or ""
|
||||
chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True, metadata=metadata)
|
||||
elif d.status == AgentStatus.llm_end:
|
||||
text += d.choices[0].delta.content or ""
|
||||
chat_box.update_msg(text.replace("\n", "\n\n"), streaming=False, metadata=metadata)
|
||||
# tool 的输出与 llm 输出重复了
|
||||
# elif d.status == AgentStatus.tool_start:
|
||||
# formatted_data = {
|
||||
# "Function": d.choices[0].delta.tool_calls[0].function.name,
|
||||
# "function_input": d.choices[0].delta.tool_calls[0].function.arguments,
|
||||
# }
|
||||
# formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
|
||||
# text = """\n```{}\n```\n""".format(formatted_json)
|
||||
# chat_box.insert_msg( # TODO: insert text directly not shown
|
||||
# Markdown(text, title="Function call", in_expander=True, expanded=True, state="running"))
|
||||
# elif d.status == AgentStatus.tool_end:
|
||||
# tool_output = d.choices[0].delta.tool_calls[0].tool_output
|
||||
# if d.message_type == MsgType.IMAGE:
|
||||
# for url in json.loads(tool_output).get("images", []):
|
||||
# url = f"{api.base_url}/media/{url}"
|
||||
# chat_box.insert_msg(Image(url))
|
||||
# chat_box.update_msg(expanded=False, state="complete")
|
||||
# else:
|
||||
# text += """\n```\nObservation:\n{}\n```\n""".format(tool_output)
|
||||
# chat_box.update_msg(text, streaming=False, expanded=False, state="complete")
|
||||
elif d.status == AgentStatus.agent_finish:
|
||||
text = d.choices[0].delta.content or ""
|
||||
chat_box.update_msg(text.replace("\n", "\n\n"))
|
||||
elif d.status == None: # not agent chat
|
||||
if getattr(d, "is_ref", False):
|
||||
chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete", title="参考资料"))
|
||||
chat_box.insert_msg("")
|
||||
else:
|
||||
text += d.choices[0].delta.content or ""
|
||||
chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True, metadata=metadata)
|
||||
elif d.status == AgentStatus.llm_end:
|
||||
text += d.choices[0].delta.content or ""
|
||||
chat_box.update_msg(text.replace("\n", "\n\n"), streaming=False, metadata=metadata)
|
||||
# tool 的输出与 llm 输出重复了
|
||||
# elif d.status == AgentStatus.tool_start:
|
||||
# formatted_data = {
|
||||
# "Function": d.choices[0].delta.tool_calls[0].function.name,
|
||||
# "function_input": d.choices[0].delta.tool_calls[0].function.arguments,
|
||||
# }
|
||||
# formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
|
||||
# text = """\n```{}\n```\n""".format(formatted_json)
|
||||
# chat_box.insert_msg( # TODO: insert text directly not shown
|
||||
# Markdown(text, title="Function call", in_expander=True, expanded=True, state="running"))
|
||||
# elif d.status == AgentStatus.tool_end:
|
||||
# tool_output = d.choices[0].delta.tool_calls[0].tool_output
|
||||
# if d.message_type == MsgType.IMAGE:
|
||||
# for url in json.loads(tool_output).get("images", []):
|
||||
# url = f"{api.base_url}/media/{url}"
|
||||
# chat_box.insert_msg(Image(url))
|
||||
# chat_box.update_msg(expanded=False, state="complete")
|
||||
# else:
|
||||
# text += """\n```\nObservation:\n{}\n```\n""".format(tool_output)
|
||||
# chat_box.update_msg(text, streaming=False, expanded=False, state="complete")
|
||||
elif d.status == AgentStatus.agent_finish:
|
||||
text = d.choices[0].delta.content or ""
|
||||
chat_box.update_msg(text.replace("\n", "\n\n"))
|
||||
elif d.status == None: # not agent chat
|
||||
if getattr(d, "is_ref", False):
|
||||
chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete", title="参考资料"))
|
||||
chat_box.insert_msg("")
|
||||
else:
|
||||
text += d.choices[0].delta.content or ""
|
||||
chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True, metadata=metadata)
|
||||
chat_box.update_msg(text, streaming=False, metadata=metadata)
|
||||
chat_box.update_msg(text, streaming=False, metadata=metadata)
|
||||
|
||||
if os.path.exists("tmp/image.jpg"):
|
||||
with open("tmp/image.jpg", "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read()).decode()
|
||||
img_tag = f'<img src="data:image/jpeg;base64,{encoded_string}" width="300">'
|
||||
st.markdown(img_tag, unsafe_allow_html=True)
|
||||
os.remove("tmp/image.jpg")
|
||||
# chat_box.show_feedback(**feedback_kwargs,
|
||||
# key=message_id,
|
||||
# on_submit=on_feedback,
|
||||
# kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
|
||||
if os.path.exists("tmp/image.jpg"):
|
||||
with open("tmp/image.jpg", "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read()).decode()
|
||||
img_tag = f'<img src="data:image/jpeg;base64,{encoded_string}" width="300">'
|
||||
st.markdown(img_tag, unsafe_allow_html=True)
|
||||
os.remove("tmp/image.jpg")
|
||||
# chat_box.show_feedback(**feedback_kwargs,
|
||||
# key=message_id,
|
||||
# on_submit=on_feedback,
|
||||
# kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
|
||||
|
||||
# elif dialogue_mode == "文件对话":
|
||||
# if st.session_state["file_chat_id"] is None:
|
||||
# st.error("请先上传文件再进行对话")
|
||||
# st.stop()
|
||||
# chat_box.ai_say([
|
||||
# f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
|
||||
# Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
|
||||
# ])
|
||||
# text = ""
|
||||
# for d in api.file_chat(prompt,
|
||||
# knowledge_id=st.session_state["file_chat_id"],
|
||||
# top_k=kb_top_k,
|
||||
# score_threshold=score_threshold,
|
||||
# history=history,
|
||||
# model=llm_model,
|
||||
# prompt_name=prompt_template_name,
|
||||
# temperature=temperature):
|
||||
# if error_msg := check_error_msg(d):
|
||||
# st.error(error_msg)
|
||||
# elif chunk := d.get("answer"):
|
||||
# text += chunk
|
||||
# chat_box.update_msg(text, element_index=0)
|
||||
# chat_box.update_msg(text, element_index=0, streaming=False)
|
||||
# chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
|
||||
if st.session_state.get("need_rerun"):
|
||||
st.session_state["need_rerun"] = False
|
||||
st.rerun()
|
||||
# elif dialogue_mode == "文件对话":
|
||||
# if st.session_state["file_chat_id"] is None:
|
||||
# st.error("请先上传文件再进行对话")
|
||||
# st.stop()
|
||||
# chat_box.ai_say([
|
||||
# f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
|
||||
# Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
|
||||
# ])
|
||||
# text = ""
|
||||
# for d in api.file_chat(prompt,
|
||||
# knowledge_id=st.session_state["file_chat_id"],
|
||||
# top_k=kb_top_k,
|
||||
# score_threshold=score_threshold,
|
||||
# history=history,
|
||||
# model=llm_model,
|
||||
# prompt_name=prompt_template_name,
|
||||
# temperature=temperature):
|
||||
# if error_msg := check_error_msg(d):
|
||||
# st.error(error_msg)
|
||||
# elif chunk := d.get("answer"):
|
||||
# text += chunk
|
||||
# chat_box.update_msg(text, element_index=0)
|
||||
# chat_box.update_msg(text, element_index=0, streaming=False)
|
||||
# chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
|
||||
|
||||
now = datetime.now()
|
||||
with tab1:
|
||||
with tab2:
|
||||
cols = st.columns(2)
|
||||
export_btn = cols[0]
|
||||
if cols[1].button(
|
||||
@ -356,9 +389,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
||||
use_container_width=True,
|
||||
):
|
||||
chat_box.reset_history()
|
||||
st.rerun()
|
||||
|
||||
warning_placeholder = st.empty()
|
||||
rerun()
|
||||
|
||||
export_btn.download_button(
|
||||
"导出记录",
|
||||
|
||||
@ -690,6 +690,14 @@ def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def get_img_url(file_name: str) -> str:
|
||||
'''
|
||||
image url used in streamlit.
|
||||
absolute local path not working on windows.
|
||||
'''
|
||||
return f"{api_address()}/img/{file_name}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
api = ApiRequest()
|
||||
aapi = AsyncApiRequest()
|
||||
|
||||
@ -49,9 +49,9 @@ python-multipart==0.0.9
|
||||
|
||||
# WebUI requirements
|
||||
|
||||
streamlit==1.30.0
|
||||
streamlit-option-menu==0.3.12
|
||||
streamlit-antd-components==0.3.1
|
||||
streamlit-chatbox==1.1.11
|
||||
streamlit-modal==0.1.0
|
||||
streamlit==1.34.0
|
||||
streamlit-antd-components==0.3.2
|
||||
streamlit-chatbox==1.1.12
|
||||
streamlit-aggrid==0.3.4.post3
|
||||
streamlit-extras==0.4.2
|
||||
# audio-recorder-streamlit==0.0.8
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user