From a1429a350a1433d548680f2a712b2aec205562e4 Mon Sep 17 00:00:00 2001 From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com> Date: Fri, 29 Mar 2024 14:30:16 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=EF=BC=9Asearch=5Flocal=5Fkno?= =?UTF-8?q?wledge=5Fbase=20=E5=B7=A5=E5=85=B7=E8=BF=94=E5=9B=9E=E5=80=BC?= =?UTF-8?q?=E9=94=99=E8=AF=AF=EF=BC=9B/tools=20=E8=B7=AF=E7=94=B1=E9=94=99?= =?UTF-8?q?=E8=AF=AF=EF=BC=9Bwebui=20=E4=B8=AD=E2=80=9C=E6=AD=A3=E5=9C=A8?= =?UTF-8?q?=E6=80=9D=E8=80=83=E2=80=9D=E4=B8=80=E7=9B=B4=E6=98=BE=E7=A4=BA?= =?UTF-8?q?=20(#3571)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../tools_factory/search_local_knowledgebase.py | 2 +- .../chatchat/server/api_server/chat_routes.py | 2 +- .../chatchat/server/api_server/tool_routes.py | 2 +- .../chatchat/webui_pages/dialogue/dialogue.py | 12 +++++++----- chatchat-server/chatchat/webui_pages/utils.py | 11 ----------- 5 files changed, 10 insertions(+), 19 deletions(-) diff --git a/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py b/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py index 1f83d94a..a34e8301 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py @@ -52,4 +52,4 @@ def search_local_knowledgebase( '''''' tool_config = get_tool_config("search_local_knowledgebase") ret = search_knowledgebase(query=query, database=database, config=tool_config) - return BaseToolOutput(ret, database=database) + return KBToolOutput(ret, database=database) diff --git a/chatchat-server/chatchat/server/api_server/chat_routes.py b/chatchat-server/chatchat/server/api_server/chat_routes.py index 11a561ff..f935ab9c 100644 --- a/chatchat-server/chatchat/server/api_server/chat_routes.py +++ b/chatchat-server/chatchat/server/api_server/chat_routes.py @@ -99,7 +99,7 @@ async def chat_completions( "status": None, } header = [{**extra_json, - "content": f"知识库参考资料:\n\n{tool_result}\n\n", + "content": f"{tool_result}", "tool_output":tool_result.data, "is_ref": True, }] diff --git a/chatchat-server/chatchat/server/api_server/tool_routes.py b/chatchat-server/chatchat/server/api_server/tool_routes.py index 7355d6b8..3093b40e 100644 --- a/chatchat-server/chatchat/server/api_server/tool_routes.py +++ b/chatchat-server/chatchat/server/api_server/tool_routes.py @@ -11,7 +11,7 @@ from chatchat.server.utils import BaseResponse, get_tool, get_tool_config tool_router = APIRouter(prefix="/tools", tags=["Toolkits"]) -@tool_router.get("/", response_model=BaseResponse) +@tool_router.get("", response_model=BaseResponse) async def list_tools(): tools = get_tool() data = {t.name: { diff --git a/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py b/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py index 4a9b3efb..2a12e3ee 100644 --- a/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py +++ b/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py @@ -254,19 +254,21 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): tool_choice=tool_choice, extra_body=extra_body, ): - print("\n\n", d.status, "\n", d, "\n\n") + # print("\n\n", d.status, "\n", d, "\n\n") message_id = d.message_id metadata = { "message_id": message_id, } + # clear initial message + if not started: + chat_box.update_msg("", streaming=False) + started = True + if d.status == AgentStatus.error: st.error(d.choices[0].delta.content) elif d.status == AgentStatus.llm_start: - if not started: - started = True - else: - chat_box.insert_msg("正在解读工具输出结果...") + chat_box.insert_msg("正在解读工具输出结果...") text = d.choices[0].delta.content or "" elif d.status == AgentStatus.llm_new_token: text += d.choices[0].delta.content or "" diff --git a/chatchat-server/chatchat/webui_pages/utils.py b/chatchat-server/chatchat/webui_pages/utils.py index 53fdb2fc..43836077 100644 --- a/chatchat-server/chatchat/webui_pages/utils.py +++ b/chatchat-server/chatchat/webui_pages/utils.py @@ -50,14 +50,6 @@ class ApiRequest: timeout=self.timeout) return self._client - def _check_url(self, url: str) -> str: - ''' - 新版 httpx 强制要求 url 以 / 结尾,否则会返回 307 - ''' - if not url.endswith("/"): - url = url + "/" - return url - def get( self, url: str, @@ -66,7 +58,6 @@ class ApiRequest: stream: bool = False, **kwargs: Any, ) -> Union[httpx.Response, Iterator[httpx.Response], None]: - url = self._check_url(url) while retry > 0: try: if stream: @@ -88,7 +79,6 @@ class ApiRequest: stream: bool = False, **kwargs: Any ) -> Union[httpx.Response, Iterator[httpx.Response], None]: - url = self._check_url(url) while retry > 0: try: # print(kwargs) @@ -111,7 +101,6 @@ class ApiRequest: stream: bool = False, **kwargs: Any ) -> Union[httpx.Response, Iterator[httpx.Response], None]: - url = self._check_url(url) while retry > 0: try: if stream: