From a1429a350a1433d548680f2a712b2aec205562e4 Mon Sep 17 00:00:00 2001 From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com> Date: Fri, 29 Mar 2024 14:30:16 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=EF=BC=9Asearch=5Flocal?= =?UTF-8?q?=5Fknowledge=5Fbase=20=E5=B7=A5=E5=85=B7=E8=BF=94=E5=9B=9E?= =?UTF-8?q?=E5=80=BC=E9=94=99=E8=AF=AF=EF=BC=9B/tools=20=E8=B7=AF=E7=94=B1?= =?UTF-8?q?=E9=94=99=E8=AF=AF=EF=BC=9Bwebui=20=E4=B8=AD=E2=80=9C=E6=AD=A3?= =?UTF-8?q?=E5=9C=A8=E6=80=9D=E8=80=83=E2=80=9D=E4=B8=80=E7=9B=B4=E6=98=BE?= =?UTF-8?q?=E7=A4=BA=20(#3571)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../tools_factory/search_local_knowledgebase.py | 2 +- .../chatchat/server/api_server/chat_routes.py | 2 +- .../chatchat/server/api_server/tool_routes.py | 2 +- .../chatchat/webui_pages/dialogue/dialogue.py | 12 +++++++----- chatchat-server/chatchat/webui_pages/utils.py | 11 ----------- 5 files changed, 10 insertions(+), 19 deletions(-) diff --git a/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py b/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py index 1f83d94a..a34e8301 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py @@ -52,4 +52,4 @@ def search_local_knowledgebase( '''''' tool_config = get_tool_config("search_local_knowledgebase") ret = search_knowledgebase(query=query, database=database, config=tool_config) - return BaseToolOutput(ret, database=database) + return KBToolOutput(ret, database=database) diff --git a/chatchat-server/chatchat/server/api_server/chat_routes.py b/chatchat-server/chatchat/server/api_server/chat_routes.py index 11a561ff..f935ab9c 100644 --- a/chatchat-server/chatchat/server/api_server/chat_routes.py +++ b/chatchat-server/chatchat/server/api_server/chat_routes.py @@ -99,7 +99,7 @@ async def chat_completions( "status": None, } header = [{**extra_json, - "content": f"知识库参考资料:\n\n{tool_result}\n\n", + "content": f"{tool_result}", "tool_output":tool_result.data, "is_ref": True, }] diff --git a/chatchat-server/chatchat/server/api_server/tool_routes.py b/chatchat-server/chatchat/server/api_server/tool_routes.py index 7355d6b8..3093b40e 100644 --- a/chatchat-server/chatchat/server/api_server/tool_routes.py +++ b/chatchat-server/chatchat/server/api_server/tool_routes.py @@ -11,7 +11,7 @@ from chatchat.server.utils import BaseResponse, get_tool, get_tool_config tool_router = APIRouter(prefix="/tools", tags=["Toolkits"]) -@tool_router.get("/", response_model=BaseResponse) +@tool_router.get("", response_model=BaseResponse) async def list_tools(): tools = get_tool() data = {t.name: { diff --git a/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py b/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py index 4a9b3efb..2a12e3ee 100644 --- a/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py +++ b/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py @@ -254,19 +254,21 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): tool_choice=tool_choice, extra_body=extra_body, ): - print("\n\n", d.status, "\n", d, "\n\n") + # print("\n\n", d.status, "\n", d, "\n\n") message_id = d.message_id metadata = { "message_id": message_id, } + # clear initial message + if not started: + chat_box.update_msg("", streaming=False) + started = True + if d.status == AgentStatus.error: st.error(d.choices[0].delta.content) elif d.status == AgentStatus.llm_start: - if not started: - started = True - else: - chat_box.insert_msg("正在解读工具输出结果...") + chat_box.insert_msg("正在解读工具输出结果...") text = d.choices[0].delta.content or "" elif d.status == AgentStatus.llm_new_token: text += d.choices[0].delta.content or "" diff --git a/chatchat-server/chatchat/webui_pages/utils.py b/chatchat-server/chatchat/webui_pages/utils.py index 53fdb2fc..43836077 100644 --- a/chatchat-server/chatchat/webui_pages/utils.py +++ b/chatchat-server/chatchat/webui_pages/utils.py @@ -50,14 +50,6 @@ class ApiRequest: timeout=self.timeout) return self._client - def _check_url(self, url: str) -> str: - ''' - 新版 httpx 强制要求 url 以 / 结尾,否则会返回 307 - ''' - if not url.endswith("/"): - url = url + "/" - return url - def get( self, url: str, @@ -66,7 +58,6 @@ class ApiRequest: stream: bool = False, **kwargs: Any, ) -> Union[httpx.Response, Iterator[httpx.Response], None]: - url = self._check_url(url) while retry > 0: try: if stream: @@ -88,7 +79,6 @@ class ApiRequest: stream: bool = False, **kwargs: Any ) -> Union[httpx.Response, Iterator[httpx.Response], None]: - url = self._check_url(url) while retry > 0: try: # print(kwargs) @@ -111,7 +101,6 @@ class ApiRequest: stream: bool = False, **kwargs: Any ) -> Union[httpx.Response, Iterator[httpx.Response], None]: - url = self._check_url(url) while retry > 0: try: if stream: From 27f0f512a37db4b816da543ca566b0e70d395f4d Mon Sep 17 00:00:00 2001 From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com> Date: Fri, 29 Mar 2024 18:07:07 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20openai=20=E5=85=BC?= =?UTF-8?q?=E5=AE=B9=E7=9A=84=20files=20=E6=8E=A5=E5=8F=A3=20(#3573)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chatchat/server/api_server/api_schemas.py | 7 +- .../server/api_server/openai_routes.py | 105 +++++++++++++++++- 2 files changed, 107 insertions(+), 5 deletions(-) diff --git a/chatchat-server/chatchat/server/api_server/api_schemas.py b/chatchat-server/chatchat/server/api_server/api_schemas.py index 973c4709..965be90d 100644 --- a/chatchat-server/chatchat/server/api_server/api_schemas.py +++ b/chatchat-server/chatchat/server/api_server/api_schemas.py @@ -13,7 +13,7 @@ from openai.types.chat import ( ) from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE -from chatchat.server.callback_handler.agent_callback_handler import AgentStatus +from chatchat.server.callback_handler.agent_callback_handler import AgentStatus # noaq from chatchat.server.pydantic_v2 import BaseModel, Field, AnyUrl from chatchat.server.utils import MsgType @@ -102,6 +102,11 @@ class OpenAIAudioSpeechInput(OpenAIBaseInput): speed: Optional[float] = None +# class OpenAIFileInput(OpenAIBaseInput): +# file: UploadFile # FileTypes +# purpose: Literal["fine-tune", "assistants"] = "assistants" + + class OpenAIBaseOutput(BaseModel): id: Optional[str] = None content: Optional[str] = None diff --git a/chatchat-server/chatchat/server/api_server/openai_routes.py b/chatchat-server/chatchat/server/api_server/openai_routes.py index 3e2a4d5a..c9e616e2 100644 --- a/chatchat-server/chatchat/server/api_server/openai_routes.py +++ b/chatchat-server/chatchat/server/api_server/openai_routes.py @@ -1,15 +1,22 @@ from __future__ import annotations import asyncio +import base64 from contextlib import asynccontextmanager +from datetime import datetime +import os +from pathlib import Path +import shutil from typing import Dict, Tuple, AsyncGenerator, Iterable from fastapi import APIRouter, Request +from fastapi.responses import FileResponse from openai import AsyncClient +from openai.types.file_object import FileObject from sse_starlette.sse import EventSourceResponse from .api_schemas import * -from chatchat.configs import logger +from chatchat.configs import logger, BASE_TEMP_DIR from chatchat.server.utils import get_model_info, get_config_platforms, get_OpenAIClient @@ -203,6 +210,96 @@ async def create_audio_speech( return await openai_request(client.audio.speech.create, body) -@openai_router.post("/files", deprecated="暂不支持") -async def files(): - ... +def _get_file_id( + purpose: str, + created_at: int, + filename: str, +) -> str: + today = datetime.fromtimestamp(created_at).strftime("%Y-%m-%d") + return base64.urlsafe_b64encode(f"{purpose}/{today}/{filename}".encode()).decode() + + +def _get_file_info(file_id: str) -> Dict: + splits = base64.urlsafe_b64decode(file_id).decode().split("/") + created_at = -1 + size = -1 + file_path = _get_file_path(file_id) + if os.path.isfile(file_path): + created_at = int(os.path.getmtime(file_path)) + size = os.path.getsize(file_path) + + return { + "purpose": splits[0], + "created_at": created_at, + "filename": splits[2], + "bytes": size, + } + + +def _get_file_path(file_id: str) -> str: + file_id = base64.urlsafe_b64decode(file_id).decode() + return os.path.join(BASE_TEMP_DIR, "openai_files", file_id) + + +@openai_router.post("/files") +async def files( + request: Request, + file: UploadFile, + purpose: str = "assistants", +) -> Dict: + created_at = int(datetime.now().timestamp()) + file_id = _get_file_id(purpose=purpose, created_at=created_at, filename=file.filename) + file_path = _get_file_path(file_id) + file_dir = os.path.dirname(file_path) + os.makedirs(file_dir, exist_ok=True) + with open(file_path, "wb") as fp: + shutil.copyfileobj(file.file, fp) + file.file.close() + + return dict( + id=file_id, + filename=file.filename, + bytes=file.size, + created_at=created_at, + object="file", + purpose=purpose, + ) + + +@openai_router.get("/files") +def list_files(purpose: str) -> Dict[str, List[Dict]]: + file_ids = [] + root_path = Path(BASE_TEMP_DIR) / "openai_files" / purpose + for dir, sub_dirs, files in os.walk(root_path): + dir = Path(dir).relative_to(root_path).as_posix() + for file in files: + file_id = base64.urlsafe_b64encode(f"{purpose}/{dir}/{file}".encode()).decode() + file_ids.append(file_id) + return {"data": [{**_get_file_info(x), "id":x, "object": "file"} for x in file_ids]} + + +@openai_router.get("/files/{file_id}") +def retrieve_file(file_id: str) -> Dict: + file_info = _get_file_info(file_id) + return {**file_info, "id": file_id, "object": "file"} + + +@openai_router.get("/files/{file_id}/content") +def retrieve_file_content(file_id: str) -> Dict: + file_path = _get_file_path(file_id) + return FileResponse(file_path) + + +@openai_router.delete("/files/{file_id}") +def delete_file(file_id: str) -> Dict: + file_path = _get_file_path(file_id) + deleted = False + + try: + if os.path.isfile(file_path): + os.remove(file_path) + deleted = True + except: + ... + + return {"id": file_id, "deleted": deleted, "object": "file"} From 6e9e31a32c2519559c8baf7f87caf23af0841744 Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Fri, 29 Mar 2024 18:26:50 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E5=92=8C?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E6=A3=80=E6=9F=A5=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- chatchat-server/README.md | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/chatchat-server/README.md b/chatchat-server/README.md index 2cbbd452..27ecdd18 100644 --- a/chatchat-server/README.md +++ b/chatchat-server/README.md @@ -12,12 +12,33 @@ Install Poetry: [documentation on how to install it.](https://python-poetry.org/ #### 本地开发环境安装 - 选择主项目目录 -``` +```shell cd chatchat ``` - 安装chatchat依赖(for running chatchat lint\tests): -``` +```shell poetry install --with lint,test -``` \ No newline at end of file +``` + +#### 格式化和代码检查 +在提交PR之前,请在本地运行以下命令;CI系统也会进行检查。 + +##### 代码格式化 +本项目使用ruff进行代码格式化。 + +要对某个库进行格式化,请在相应的库目录下运行相同的命令: +```shell +cd {model-providers|chatchat|chatchat-server|chatchat-frontend} +make format +``` + +此外,你可以使用format_diff命令仅对当前分支中与主分支相比已修改的文件进行格式化: + + +```shell + +make format_diff +``` +当你对项目的一部分进行了更改,并希望确保更改的部分格式正确,而不影响代码库的其他部分时,这个命令特别有用。 \ No newline at end of file