一些bug

This commit is contained in:
glide-the 2024-05-07 19:32:54 +08:00
parent b659a1badc
commit 53380d76c2
2 changed files with 26 additions and 18 deletions

View File

@ -32,6 +32,7 @@ ollama:
model_type: 'llm'
model_credentials:
base_url: 'http://172.21.80.1:11434'
mode: 'completion'

View File

@ -11,13 +11,13 @@ import streamlit_antd_components as sac
from streamlit_chatbox import *
from streamlit_extras.bottom_container import bottom
from chatchat.configs import (LLM_MODEL_CONFIG, TEMPERATURE, MODEL_PLATFORMS, DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL)
from chatchat.configs import (LLM_MODEL_CONFIG, TEMPERATURE, MODEL_PLATFORMS, DEFAULT_LLM_MODEL,
DEFAULT_EMBEDDING_MODEL)
from chatchat.server.callback_handler.agent_callback_handler import AgentStatus
from chatchat.server.utils import MsgType, get_config_models
from chatchat.webui_pages.utils import *
from chatchat.webui_pages.dialogue.utils import process_files
chat_box = ChatBox(
assistant_avatar=get_img_url("chatchat_icon_blue_square_v2.png")
)
@ -73,7 +73,7 @@ def upload_temp_docs(files, _api: ApiRequest) -> str:
return _api.upload_temp_docs(files).get("data", {}).get("id")
def add_conv(name: str=""):
def add_conv(name: str = ""):
conv_names = chat_box.get_chat_names()
if not name:
i = len(conv_names) + 1
@ -89,7 +89,7 @@ def add_conv(name: str=""):
st.session_state["cur_conv_name"] = name
def del_conv(name: str=None):
def del_conv(name: str = None):
conv_names = chat_box.get_chat_names()
name = name or chat_box.cur_chat_name
if len(conv_names) == 1:
@ -102,7 +102,7 @@ def del_conv(name: str=None):
st.session_state["cur_conv_name"] = chat_box.cur_chat_name
def clear_conv(name: str=None):
def clear_conv(name: str = None):
chat_box.reset_history(name=name or None)
@ -112,8 +112,8 @@ def list_tools(_api: ApiRequest):
def dialogue_page(
api: ApiRequest,
is_lite: bool = False,
api: ApiRequest,
is_lite: bool = False,
):
ctx = chat_box.context
ctx.setdefault("uid", uuid.uuid4().hex)
@ -132,7 +132,7 @@ def dialogue_page(
cols = st.columns(3)
platforms = ["所有"] + [x["platform_name"] for x in MODEL_PLATFORMS]
platform = cols[0].selectbox("选择模型平台", platforms, key="platform")
llm_models = list(get_config_models(model_type="llm", platform_name=None if platform=="所有" else platform))
llm_models = list(get_config_models(model_type="llm", platform_name=None if platform == "所有" else platform))
llm_model = cols[1].selectbox("选择LLM模型", llm_models, key="llm_model")
temperature = cols[2].slider("Temperature", 0.0, 1.0, key="temperature")
system_message = st.text_area("System Message:", key="system_message")
@ -158,10 +158,13 @@ def dialogue_page(
tool_names = ["None"] + list(tools)
if use_agent:
# selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", check_all=True, key="selected_tools")
selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"], key="selected_tools")
selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"],
key="selected_tools")
else:
# selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", key="selected_tool")
selected_tool = st.selectbox("选择工具", tool_names, format_func=lambda x: tools.get(x,{"title": "None"})["title"], key="selected_tool")
selected_tool = st.selectbox("选择工具", tool_names,
format_func=lambda x: tools.get(x, {"title": "None"})["title"],
key="selected_tool")
selected_tools = [selected_tool]
selected_tool_configs = {name: tool["config"] for name, tool in tools.items() if name in selected_tools}
@ -277,11 +280,11 @@ def dialogue_page(
tool_input[k] = prompt
extra_body = dict(
metadata=files_upload,
chat_model_config=chat_model_config,
conversation_id=conversation_id,
tool_input = tool_input,
)
metadata=files_upload,
chat_model_config=chat_model_config,
conversation_id=conversation_id,
tool_input=tool_input,
)
for d in client.chat.completions.create(
messages=messages,
model=llm_model,
@ -289,7 +292,7 @@ def dialogue_page(
tools=tools,
tool_choice=tool_choice,
extra_body=extra_body,
):
):
# print("\n\n", d.status, "\n", d, "\n\n")
message_id = d.message_id
metadata = {
@ -335,9 +338,10 @@ def dialogue_page(
elif d.status == AgentStatus.agent_finish:
text = d.choices[0].delta.content or ""
chat_box.update_msg(text.replace("\n", "\n\n"))
elif d.status == None: # not agent chat
elif d.status == None: # not agent chat
if getattr(d, "is_ref", False):
chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete", title="参考资料"))
chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete",
title="参考资料"))
chat_box.insert_msg("")
else:
text += d.choices[0].delta.content or ""
@ -400,3 +404,6 @@ def dialogue_page(
)
# st.write(chat_box.history)
save_session()