Merge branch 'dev' into dev_model_providers

This commit is contained in:
glide-the 2024-03-31 15:08:56 +08:00
commit b8d748b668
8 changed files with 141 additions and 27 deletions

View File

@ -12,12 +12,33 @@ Install Poetry: [documentation on how to install it.](https://python-poetry.org/
#### 本地开发环境安装
- 选择主项目目录
```
```shell
cd chatchat
```
- 安装chatchat依赖(for running chatchat lint\tests):
```
```shell
poetry install --with lint,test
```
```
#### 格式化和代码检查
在提交PR之前,请在本地运行以下命令;CI系统也会进行检查。
##### 代码格式化
本项目使用ruff进行代码格式化。
要对某个库进行格式化,请在相应的库目录下运行相同的命令:
```shell
cd {model-providers|chatchat|chatchat-server|chatchat-frontend}
make format
```
此外,你可以使用format_diff命令仅对当前分支中与主分支相比已修改的文件进行格式化:
```shell
make format_diff
```
当你对项目的一部分进行了更改,并希望确保更改的部分格式正确,而不影响代码库的其他部分时,这个命令特别有用。

View File

@ -52,4 +52,4 @@ def search_local_knowledgebase(
''''''
tool_config = get_tool_config("search_local_knowledgebase")
ret = search_knowledgebase(query=query, database=database, config=tool_config)
return BaseToolOutput(ret, database=database)
return KBToolOutput(ret, database=database)

View File

@ -13,7 +13,7 @@ from openai.types.chat import (
)
from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE
from chatchat.server.callback_handler.agent_callback_handler import AgentStatus
from chatchat.server.callback_handler.agent_callback_handler import AgentStatus # noaq
from chatchat.server.pydantic_v2 import BaseModel, Field, AnyUrl
from chatchat.server.utils import MsgType
@ -102,6 +102,11 @@ class OpenAIAudioSpeechInput(OpenAIBaseInput):
speed: Optional[float] = None
# class OpenAIFileInput(OpenAIBaseInput):
# file: UploadFile # FileTypes
# purpose: Literal["fine-tune", "assistants"] = "assistants"
class OpenAIBaseOutput(BaseModel):
id: Optional[str] = None
content: Optional[str] = None

View File

@ -99,7 +99,7 @@ async def chat_completions(
"status": None,
}
header = [{**extra_json,
"content": f"知识库参考资料:\n\n{tool_result}\n\n",
"content": f"{tool_result}",
"tool_output":tool_result.data,
"is_ref": True,
}]

View File

@ -1,15 +1,22 @@
from __future__ import annotations
import asyncio
import base64
from contextlib import asynccontextmanager
from datetime import datetime
import os
from pathlib import Path
import shutil
from typing import Dict, Tuple, AsyncGenerator, Iterable
from fastapi import APIRouter, Request
from fastapi.responses import FileResponse
from openai import AsyncClient
from openai.types.file_object import FileObject
from sse_starlette.sse import EventSourceResponse
from .api_schemas import *
from chatchat.configs import logger
from chatchat.configs import logger, BASE_TEMP_DIR
from chatchat.server.utils import get_model_info, get_config_platforms, get_OpenAIClient
@ -203,6 +210,96 @@ async def create_audio_speech(
return await openai_request(client.audio.speech.create, body)
@openai_router.post("/files", deprecated="暂不支持")
async def files():
...
def _get_file_id(
purpose: str,
created_at: int,
filename: str,
) -> str:
today = datetime.fromtimestamp(created_at).strftime("%Y-%m-%d")
return base64.urlsafe_b64encode(f"{purpose}/{today}/{filename}".encode()).decode()
def _get_file_info(file_id: str) -> Dict:
splits = base64.urlsafe_b64decode(file_id).decode().split("/")
created_at = -1
size = -1
file_path = _get_file_path(file_id)
if os.path.isfile(file_path):
created_at = int(os.path.getmtime(file_path))
size = os.path.getsize(file_path)
return {
"purpose": splits[0],
"created_at": created_at,
"filename": splits[2],
"bytes": size,
}
def _get_file_path(file_id: str) -> str:
file_id = base64.urlsafe_b64decode(file_id).decode()
return os.path.join(BASE_TEMP_DIR, "openai_files", file_id)
@openai_router.post("/files")
async def files(
request: Request,
file: UploadFile,
purpose: str = "assistants",
) -> Dict:
created_at = int(datetime.now().timestamp())
file_id = _get_file_id(purpose=purpose, created_at=created_at, filename=file.filename)
file_path = _get_file_path(file_id)
file_dir = os.path.dirname(file_path)
os.makedirs(file_dir, exist_ok=True)
with open(file_path, "wb") as fp:
shutil.copyfileobj(file.file, fp)
file.file.close()
return dict(
id=file_id,
filename=file.filename,
bytes=file.size,
created_at=created_at,
object="file",
purpose=purpose,
)
@openai_router.get("/files")
def list_files(purpose: str) -> Dict[str, List[Dict]]:
file_ids = []
root_path = Path(BASE_TEMP_DIR) / "openai_files" / purpose
for dir, sub_dirs, files in os.walk(root_path):
dir = Path(dir).relative_to(root_path).as_posix()
for file in files:
file_id = base64.urlsafe_b64encode(f"{purpose}/{dir}/{file}".encode()).decode()
file_ids.append(file_id)
return {"data": [{**_get_file_info(x), "id":x, "object": "file"} for x in file_ids]}
@openai_router.get("/files/{file_id}")
def retrieve_file(file_id: str) -> Dict:
file_info = _get_file_info(file_id)
return {**file_info, "id": file_id, "object": "file"}
@openai_router.get("/files/{file_id}/content")
def retrieve_file_content(file_id: str) -> Dict:
file_path = _get_file_path(file_id)
return FileResponse(file_path)
@openai_router.delete("/files/{file_id}")
def delete_file(file_id: str) -> Dict:
file_path = _get_file_path(file_id)
deleted = False
try:
if os.path.isfile(file_path):
os.remove(file_path)
deleted = True
except:
...
return {"id": file_id, "deleted": deleted, "object": "file"}

View File

@ -11,7 +11,7 @@ from chatchat.server.utils import BaseResponse, get_tool, get_tool_config
tool_router = APIRouter(prefix="/tools", tags=["Toolkits"])
@tool_router.get("/", response_model=BaseResponse)
@tool_router.get("", response_model=BaseResponse)
async def list_tools():
tools = get_tool()
data = {t.name: {

View File

@ -254,19 +254,21 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
tool_choice=tool_choice,
extra_body=extra_body,
):
print("\n\n", d.status, "\n", d, "\n\n")
# print("\n\n", d.status, "\n", d, "\n\n")
message_id = d.message_id
metadata = {
"message_id": message_id,
}
# clear initial message
if not started:
chat_box.update_msg("", streaming=False)
started = True
if d.status == AgentStatus.error:
st.error(d.choices[0].delta.content)
elif d.status == AgentStatus.llm_start:
if not started:
started = True
else:
chat_box.insert_msg("正在解读工具输出结果...")
chat_box.insert_msg("正在解读工具输出结果...")
text = d.choices[0].delta.content or ""
elif d.status == AgentStatus.llm_new_token:
text += d.choices[0].delta.content or ""

View File

@ -50,14 +50,6 @@ class ApiRequest:
timeout=self.timeout)
return self._client
def _check_url(self, url: str) -> str:
'''
新版 httpx 强制要求 url / 结尾否则会返回 307
'''
if not url.endswith("/"):
url = url + "/"
return url
def get(
self,
url: str,
@ -66,7 +58,6 @@ class ApiRequest:
stream: bool = False,
**kwargs: Any,
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
url = self._check_url(url)
while retry > 0:
try:
if stream:
@ -88,7 +79,6 @@ class ApiRequest:
stream: bool = False,
**kwargs: Any
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
url = self._check_url(url)
while retry > 0:
try:
# print(kwargs)
@ -111,7 +101,6 @@ class ApiRequest:
stream: bool = False,
**kwargs: Any
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
url = self._check_url(url)
while retry > 0:
try:
if stream: