修复:search_local_knowledge_base 工具返回值错误;/tools 路由错误;webui 中“正在思考”一直显示 (#3571)

This commit is contained in:
liunux4odoo 2024-03-29 14:30:16 +08:00 committed by GitHub
parent 42aa900566
commit a1429a350a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 10 additions and 19 deletions

View File

@ -52,4 +52,4 @@ def search_local_knowledgebase(
'''''' ''''''
tool_config = get_tool_config("search_local_knowledgebase") tool_config = get_tool_config("search_local_knowledgebase")
ret = search_knowledgebase(query=query, database=database, config=tool_config) ret = search_knowledgebase(query=query, database=database, config=tool_config)
return BaseToolOutput(ret, database=database) return KBToolOutput(ret, database=database)

View File

@ -99,7 +99,7 @@ async def chat_completions(
"status": None, "status": None,
} }
header = [{**extra_json, header = [{**extra_json,
"content": f"知识库参考资料:\n\n{tool_result}\n\n", "content": f"{tool_result}",
"tool_output":tool_result.data, "tool_output":tool_result.data,
"is_ref": True, "is_ref": True,
}] }]

View File

@ -11,7 +11,7 @@ from chatchat.server.utils import BaseResponse, get_tool, get_tool_config
tool_router = APIRouter(prefix="/tools", tags=["Toolkits"]) tool_router = APIRouter(prefix="/tools", tags=["Toolkits"])
@tool_router.get("/", response_model=BaseResponse) @tool_router.get("", response_model=BaseResponse)
async def list_tools(): async def list_tools():
tools = get_tool() tools = get_tool()
data = {t.name: { data = {t.name: {

View File

@ -254,19 +254,21 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
tool_choice=tool_choice, tool_choice=tool_choice,
extra_body=extra_body, extra_body=extra_body,
): ):
print("\n\n", d.status, "\n", d, "\n\n") # print("\n\n", d.status, "\n", d, "\n\n")
message_id = d.message_id message_id = d.message_id
metadata = { metadata = {
"message_id": message_id, "message_id": message_id,
} }
# clear initial message
if not started:
chat_box.update_msg("", streaming=False)
started = True
if d.status == AgentStatus.error: if d.status == AgentStatus.error:
st.error(d.choices[0].delta.content) st.error(d.choices[0].delta.content)
elif d.status == AgentStatus.llm_start: elif d.status == AgentStatus.llm_start:
if not started: chat_box.insert_msg("正在解读工具输出结果...")
started = True
else:
chat_box.insert_msg("正在解读工具输出结果...")
text = d.choices[0].delta.content or "" text = d.choices[0].delta.content or ""
elif d.status == AgentStatus.llm_new_token: elif d.status == AgentStatus.llm_new_token:
text += d.choices[0].delta.content or "" text += d.choices[0].delta.content or ""

View File

@ -50,14 +50,6 @@ class ApiRequest:
timeout=self.timeout) timeout=self.timeout)
return self._client return self._client
def _check_url(self, url: str) -> str:
'''
新版 httpx 强制要求 url / 结尾否则会返回 307
'''
if not url.endswith("/"):
url = url + "/"
return url
def get( def get(
self, self,
url: str, url: str,
@ -66,7 +58,6 @@ class ApiRequest:
stream: bool = False, stream: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Union[httpx.Response, Iterator[httpx.Response], None]: ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
url = self._check_url(url)
while retry > 0: while retry > 0:
try: try:
if stream: if stream:
@ -88,7 +79,6 @@ class ApiRequest:
stream: bool = False, stream: bool = False,
**kwargs: Any **kwargs: Any
) -> Union[httpx.Response, Iterator[httpx.Response], None]: ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
url = self._check_url(url)
while retry > 0: while retry > 0:
try: try:
# print(kwargs) # print(kwargs)
@ -111,7 +101,6 @@ class ApiRequest:
stream: bool = False, stream: bool = False,
**kwargs: Any **kwargs: Any
) -> Union[httpx.Response, Iterator[httpx.Response], None]: ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
url = self._check_url(url)
while retry > 0: while retry > 0:
try: try:
if stream: if stream: