添加 openai 兼容的 files 接口 (#3573)

This commit is contained in:
liunux4odoo 2024-03-29 18:07:07 +08:00 committed by GitHub
parent a1429a350a
commit 27f0f512a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 107 additions and 5 deletions

View File

@ -13,7 +13,7 @@ from openai.types.chat import (
)
from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE
from chatchat.server.callback_handler.agent_callback_handler import AgentStatus
from chatchat.server.callback_handler.agent_callback_handler import AgentStatus # noqa
from chatchat.server.pydantic_v2 import BaseModel, Field, AnyUrl
from chatchat.server.utils import MsgType
@ -102,6 +102,11 @@ class OpenAIAudioSpeechInput(OpenAIBaseInput):
speed: Optional[float] = None
# class OpenAIFileInput(OpenAIBaseInput):
# file: UploadFile # FileTypes
# purpose: Literal["fine-tune", "assistants"] = "assistants"
class OpenAIBaseOutput(BaseModel):
id: Optional[str] = None
content: Optional[str] = None

View File

@ -1,15 +1,22 @@
from __future__ import annotations
import asyncio
import base64
from contextlib import asynccontextmanager
from datetime import datetime
import os
from pathlib import Path
import shutil
from typing import Dict, Tuple, AsyncGenerator, Iterable
from fastapi import APIRouter, Request
from fastapi.responses import FileResponse
from openai import AsyncClient
from openai.types.file_object import FileObject
from sse_starlette.sse import EventSourceResponse
from .api_schemas import *
from chatchat.configs import logger
from chatchat.configs import logger, BASE_TEMP_DIR
from chatchat.server.utils import get_model_info, get_config_platforms, get_OpenAIClient
@ -203,6 +210,96 @@ async def create_audio_speech(
return await openai_request(client.audio.speech.create, body)
# Placeholder handler registered on POST /files: it takes no arguments and its
# body is only `...`. The `deprecated` value is a Chinese string meaning
# "not supported yet".
# NOTE(review): a full `files` handler on the same route appears further down in
# this text; this stub looks like the pre-change side of the diff — confirm
# before keeping both definitions.
@openai_router.post("/files", deprecated="暂不支持")
async def files():
...
def _get_file_id(
    purpose: str,
    created_at: int,
    filename: str,
) -> str:
    """Build the opaque id under which an uploaded file is addressed.

    The id is the URL-safe base64 encoding of the relative storage path
    ``<purpose>/<YYYY-MM-DD>/<filename>``.

    Args:
        purpose: First path component of the encoded path.
        created_at: POSIX timestamp; only its calendar day (in the platform's
            local time) ends up in the id.
        filename: Last path component of the encoded path.

    Returns:
        The base64 text form of the relative path.
    """
    day = datetime.fromtimestamp(created_at).strftime("%Y-%m-%d")
    relative_path = f"{purpose}/{day}/{filename}"
    encoded = base64.urlsafe_b64encode(relative_path.encode())
    return encoded.decode()
def _get_file_info(file_id: str) -> Dict:
    """Describe a stored file from its id.

    The id decodes to ``<purpose>/<date>/<filename>``. ``purpose`` and
    ``filename`` are taken from that decoded text; size and timestamp are read
    from the file on disk.

    Args:
        file_id: URL-safe base64 id as produced by ``_get_file_id``.

    Returns:
        A dict with keys ``purpose``, ``created_at``, ``filename`` and
        ``bytes``. ``created_at`` is the file's modification time as an int
        and ``bytes`` is its size; both are ``-1`` when no regular file exists
        at the resolved path.

    Raises:
        IndexError: If the decoded id has fewer than three ``/``-separated
            parts.
    """
    # maxsplit=2 keeps everything after the date as the filename, so a name
    # that itself contains "/" (e.g. a file found in a nested directory) is
    # reported whole instead of being cut at its first separator.
    splits = base64.urlsafe_b64decode(file_id).decode().split("/", 2)
    created_at = -1
    size = -1
    file_path = _get_file_path(file_id)
    if os.path.isfile(file_path):
        created_at = int(os.path.getmtime(file_path))
        size = os.path.getsize(file_path)

    return {
        "purpose": splits[0],
        "created_at": created_at,
        "filename": splits[2],
        "bytes": size,
    }
def _get_file_path(file_id: str) -> str:
    """Resolve a file id to its path under ``BASE_TEMP_DIR/openai_files``.

    Args:
        file_id: URL-safe base64 encoding of the file's relative path.

    Returns:
        The storage directory joined with the decoded relative path.

    Raises:
        ValueError: If the id is not valid base64, or if the decoded path
            would resolve to a location outside the storage directory.
    """
    relative = base64.urlsafe_b64decode(file_id).decode()
    storage_root = os.path.join(BASE_TEMP_DIR, "openai_files")
    file_path = os.path.join(storage_root, relative)

    # The id arrives straight from the request URL. A decoded value such as
    # "../../x" or an absolute path (which makes os.path.join discard the
    # root) would otherwise let callers read or delete arbitrary files.
    real_root = os.path.realpath(storage_root)
    real_path = os.path.realpath(file_path)
    common = os.path.commonpath([real_root, real_path])
    if os.path.normcase(common) != os.path.normcase(real_root):
        raise ValueError(f"invalid file id: {file_id!r}")

    return file_path
@openai_router.post("/files")
async def files(
    request: Request,
    file: UploadFile,
    purpose: str = "assistants",
) -> Dict:
    """Store an uploaded file and return its file-object description.

    The file is written to the path derived from ``purpose``, the current
    date and the upload's filename.

    Args:
        request: The incoming request (not used by the handler body).
        file: The uploaded file.
        purpose: Logical category the file is stored under.

    Returns:
        A dict with ``id``, ``filename``, ``bytes``, ``created_at``,
        ``object`` (always ``"file"``) and ``purpose``.
    """
    created_at = int(datetime.now().timestamp())
    # The filename is supplied by the client; keep only its last component so
    # it cannot inject extra directories (or "..") into the storage path.
    filename = os.path.basename(file.filename or "")
    file_id = _get_file_id(purpose=purpose, created_at=created_at, filename=filename)
    file_path = _get_file_path(file_id)
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    try:
        with open(file_path, "wb") as fp:
            shutil.copyfileobj(file.file, fp)
    finally:
        # Release the upload's underlying file even when the copy fails.
        file.file.close()

    return dict(
        id=file_id,
        filename=filename,
        bytes=file.size,
        created_at=created_at,
        object="file",
        purpose=purpose,
    )
@openai_router.get("/files")
def list_files(purpose: str) -> Dict[str, List[Dict]]:
    """List every stored file under the given purpose.

    Walks ``BASE_TEMP_DIR/openai_files/<purpose>``, rebuilds each file's id
    from its path relative to that directory, and returns one description per
    file under the ``"data"`` key.
    """
    purpose_root = Path(BASE_TEMP_DIR) / "openai_files" / purpose

    collected_ids = []
    for current_dir, _sub_dirs, names in os.walk(purpose_root):
        relative_dir = Path(current_dir).relative_to(purpose_root).as_posix()
        collected_ids.extend(
            base64.urlsafe_b64encode(
                f"{purpose}/{relative_dir}/{name}".encode()
            ).decode()
            for name in names
        )

    entries = []
    for fid in collected_ids:
        entries.append({**_get_file_info(fid), "id": fid, "object": "file"})
    return {"data": entries}
@openai_router.get("/files/{file_id}")
def retrieve_file(file_id: str) -> Dict:
    """Return the description of one stored file.

    The result is whatever ``_get_file_info`` reports for the id, with the
    ``id`` and ``object`` (``"file"``) entries set on top.
    """
    description = dict(_get_file_info(file_id))
    description["id"] = file_id
    description["object"] = "file"
    return description
@openai_router.get("/files/{file_id}/content")
def retrieve_file_content(file_id: str) -> FileResponse:
    """Return the stored file's content as a file response.

    Args:
        file_id: URL-safe base64 id of the file.

    Returns:
        A ``FileResponse`` for the path the id resolves to. The annotation
        names the response class because the handler never returns a dict.
    """
    file_path = _get_file_path(file_id)
    return FileResponse(file_path)
@openai_router.delete("/files/{file_id}")
def delete_file(file_id: str) -> Dict:
    """Delete a stored file, best effort.

    Args:
        file_id: URL-safe base64 id of the file.

    Returns:
        A dict with ``id``, ``deleted`` and ``object`` (``"file"``).
        ``deleted`` is True only when a regular file existed at the resolved
        path and was removed.
    """
    file_path = _get_file_path(file_id)
    deleted = False

    try:
        if os.path.isfile(file_path):
            os.remove(file_path)
            deleted = True
    except OSError as e:
        # Still best effort: a filesystem failure is reported as
        # deleted=False, but it is logged rather than silently discarded, and
        # non-filesystem errors are no longer swallowed.
        logger.warning(f"failed to delete file {file_path}: {e}")

    return {"id": file_id, "deleted": deleted, "object": "file"}