diff --git a/chatchat-server/README.md b/chatchat-server/README.md index 2cbbd452..27ecdd18 100644 --- a/chatchat-server/README.md +++ b/chatchat-server/README.md @@ -12,12 +12,33 @@ Install Poetry: [documentation on how to install it.](https://python-poetry.org/ #### 本地开发环境安装 - 选择主项目目录 -``` +```shell cd chatchat ``` - 安装chatchat依赖(for running chatchat lint\tests): -``` +```shell poetry install --with lint,test -``` \ No newline at end of file +``` + +#### 格式化和代码检查 +在提交PR之前,请在本地运行以下命令;CI系统也会进行检查。 + +##### 代码格式化 +本项目使用ruff进行代码格式化。 + +要对某个库进行格式化,请在相应的库目录下运行相同的命令: +```shell +cd {model-providers|chatchat|chatchat-server|chatchat-frontend} +make format +``` + +此外,你可以使用format_diff命令仅对当前分支中与主分支相比已修改的文件进行格式化: + + +```shell + +make format_diff +``` +当你对项目的一部分进行了更改,并希望确保更改的部分格式正确,而不影响代码库的其他部分时,这个命令特别有用。 \ No newline at end of file diff --git a/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py b/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py index 1f83d94a..a34e8301 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py @@ -52,4 +52,4 @@ def search_local_knowledgebase( '''''' tool_config = get_tool_config("search_local_knowledgebase") ret = search_knowledgebase(query=query, database=database, config=tool_config) - return BaseToolOutput(ret, database=database) + return KBToolOutput(ret, database=database) diff --git a/chatchat-server/chatchat/server/api_server/api_schemas.py b/chatchat-server/chatchat/server/api_server/api_schemas.py index 973c4709..965be90d 100644 --- a/chatchat-server/chatchat/server/api_server/api_schemas.py +++ b/chatchat-server/chatchat/server/api_server/api_schemas.py @@ -13,7 +13,7 @@ from openai.types.chat import ( ) from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE -from chatchat.server.callback_handler.agent_callback_handler import AgentStatus +from chatchat.server.callback_handler.agent_callback_handler import AgentStatus # noaq from chatchat.server.pydantic_v2 import BaseModel, Field, AnyUrl from chatchat.server.utils import MsgType @@ -102,6 +102,11 @@ class OpenAIAudioSpeechInput(OpenAIBaseInput): speed: Optional[float] = None +# class OpenAIFileInput(OpenAIBaseInput): +# file: UploadFile # FileTypes +# purpose: Literal["fine-tune", "assistants"] = "assistants" + + class OpenAIBaseOutput(BaseModel): id: Optional[str] = None content: Optional[str] = None diff --git a/chatchat-server/chatchat/server/api_server/chat_routes.py b/chatchat-server/chatchat/server/api_server/chat_routes.py index 11a561ff..f935ab9c 100644 --- a/chatchat-server/chatchat/server/api_server/chat_routes.py +++ b/chatchat-server/chatchat/server/api_server/chat_routes.py @@ -99,7 +99,7 @@ async def chat_completions( "status": None, } header = [{**extra_json, - "content": f"知识库参考资料:\n\n{tool_result}\n\n", + "content": f"{tool_result}", "tool_output":tool_result.data, "is_ref": True, }] diff --git a/chatchat-server/chatchat/server/api_server/openai_routes.py b/chatchat-server/chatchat/server/api_server/openai_routes.py index 3e2a4d5a..c9e616e2 100644 --- a/chatchat-server/chatchat/server/api_server/openai_routes.py +++ b/chatchat-server/chatchat/server/api_server/openai_routes.py @@ -1,15 +1,22 @@ from __future__ import annotations import asyncio +import base64 from contextlib import asynccontextmanager +from datetime import datetime +import os +from pathlib import Path +import shutil from typing import Dict, Tuple, AsyncGenerator, Iterable from fastapi import APIRouter, Request +from fastapi.responses import FileResponse from openai import AsyncClient +from openai.types.file_object import FileObject from sse_starlette.sse import EventSourceResponse from .api_schemas import * -from chatchat.configs import logger +from chatchat.configs import logger, BASE_TEMP_DIR from chatchat.server.utils import get_model_info, get_config_platforms, get_OpenAIClient @@ -203,6 +210,96 @@ async def create_audio_speech( return await openai_request(client.audio.speech.create, body) -@openai_router.post("/files", deprecated="暂不支持") -async def files(): - ... +def _get_file_id( + purpose: str, + created_at: int, + filename: str, +) -> str: + today = datetime.fromtimestamp(created_at).strftime("%Y-%m-%d") + return base64.urlsafe_b64encode(f"{purpose}/{today}/{filename}".encode()).decode() + + +def _get_file_info(file_id: str) -> Dict: + splits = base64.urlsafe_b64decode(file_id).decode().split("/") + created_at = -1 + size = -1 + file_path = _get_file_path(file_id) + if os.path.isfile(file_path): + created_at = int(os.path.getmtime(file_path)) + size = os.path.getsize(file_path) + + return { + "purpose": splits[0], + "created_at": created_at, + "filename": splits[2], + "bytes": size, + } + + +def _get_file_path(file_id: str) -> str: + file_id = base64.urlsafe_b64decode(file_id).decode() + return os.path.join(BASE_TEMP_DIR, "openai_files", file_id) + + +@openai_router.post("/files") +async def files( + request: Request, + file: UploadFile, + purpose: str = "assistants", +) -> Dict: + created_at = int(datetime.now().timestamp()) + file_id = _get_file_id(purpose=purpose, created_at=created_at, filename=file.filename) + file_path = _get_file_path(file_id) + file_dir = os.path.dirname(file_path) + os.makedirs(file_dir, exist_ok=True) + with open(file_path, "wb") as fp: + shutil.copyfileobj(file.file, fp) + file.file.close() + + return dict( + id=file_id, + filename=file.filename, + bytes=file.size, + created_at=created_at, + object="file", + purpose=purpose, + ) + + +@openai_router.get("/files") +def list_files(purpose: str) -> Dict[str, List[Dict]]: + file_ids = [] + root_path = Path(BASE_TEMP_DIR) / "openai_files" / purpose + for dir, sub_dirs, files in os.walk(root_path): + dir = Path(dir).relative_to(root_path).as_posix() + for file in files: + file_id = base64.urlsafe_b64encode(f"{purpose}/{dir}/{file}".encode()).decode() + file_ids.append(file_id) + return {"data": [{**_get_file_info(x), "id":x, "object": "file"} for x in file_ids]} + + +@openai_router.get("/files/{file_id}") +def retrieve_file(file_id: str) -> Dict: + file_info = _get_file_info(file_id) + return {**file_info, "id": file_id, "object": "file"} + + +@openai_router.get("/files/{file_id}/content") +def retrieve_file_content(file_id: str) -> Dict: + file_path = _get_file_path(file_id) + return FileResponse(file_path) + + +@openai_router.delete("/files/{file_id}") +def delete_file(file_id: str) -> Dict: + file_path = _get_file_path(file_id) + deleted = False + + try: + if os.path.isfile(file_path): + os.remove(file_path) + deleted = True + except: + ... + + return {"id": file_id, "deleted": deleted, "object": "file"} diff --git a/chatchat-server/chatchat/server/api_server/tool_routes.py b/chatchat-server/chatchat/server/api_server/tool_routes.py index 7355d6b8..3093b40e 100644 --- a/chatchat-server/chatchat/server/api_server/tool_routes.py +++ b/chatchat-server/chatchat/server/api_server/tool_routes.py @@ -11,7 +11,7 @@ from chatchat.server.utils import BaseResponse, get_tool, get_tool_config tool_router = APIRouter(prefix="/tools", tags=["Toolkits"]) -@tool_router.get("/", response_model=BaseResponse) +@tool_router.get("", response_model=BaseResponse) async def list_tools(): tools = get_tool() data = {t.name: { diff --git a/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py b/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py index 4a9b3efb..2a12e3ee 100644 --- a/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py +++ b/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py @@ -254,19 +254,21 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): tool_choice=tool_choice, extra_body=extra_body, ): - print("\n\n", d.status, "\n", d, "\n\n") + # print("\n\n", d.status, "\n", d, "\n\n") message_id = d.message_id metadata = { "message_id": message_id, } + # clear initial message + if not started: + chat_box.update_msg("", streaming=False) + started = True + if d.status == AgentStatus.error: st.error(d.choices[0].delta.content) elif d.status == AgentStatus.llm_start: - if not started: - started = True - else: - chat_box.insert_msg("正在解读工具输出结果...") + chat_box.insert_msg("正在解读工具输出结果...") text = d.choices[0].delta.content or "" elif d.status == AgentStatus.llm_new_token: text += d.choices[0].delta.content or "" diff --git a/chatchat-server/chatchat/webui_pages/utils.py b/chatchat-server/chatchat/webui_pages/utils.py index 53fdb2fc..43836077 100644 --- a/chatchat-server/chatchat/webui_pages/utils.py +++ b/chatchat-server/chatchat/webui_pages/utils.py @@ -50,14 +50,6 @@ class ApiRequest: timeout=self.timeout) return self._client - def _check_url(self, url: str) -> str: - ''' - 新版 httpx 强制要求 url 以 / 结尾,否则会返回 307 - ''' - if not url.endswith("/"): - url = url + "/" - return url - def get( self, url: str, @@ -66,7 +58,6 @@ class ApiRequest: stream: bool = False, **kwargs: Any, ) -> Union[httpx.Response, Iterator[httpx.Response], None]: - url = self._check_url(url) while retry > 0: try: if stream: @@ -88,7 +79,6 @@ class ApiRequest: stream: bool = False, **kwargs: Any ) -> Union[httpx.Response, Iterator[httpx.Response], None]: - url = self._check_url(url) while retry > 0: try: # print(kwargs) @@ -111,7 +101,6 @@ class ApiRequest: stream: bool = False, **kwargs: Any ) -> Union[httpx.Response, Iterator[httpx.Response], None]: - url = self._check_url(url) while retry > 0: try: if stream: