mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 21:37:20 +08:00
添加 openai 兼容的 files 接口 (#3573)
This commit is contained in:
parent
a1429a350a
commit
27f0f512a3
@ -13,7 +13,7 @@ from openai.types.chat import (
|
||||
)
|
||||
|
||||
from chatchat.configs import DEFAULT_LLM_MODEL, TEMPERATURE
|
||||
from chatchat.server.callback_handler.agent_callback_handler import AgentStatus
|
||||
from chatchat.server.callback_handler.agent_callback_handler import AgentStatus # noaq
|
||||
from chatchat.server.pydantic_v2 import BaseModel, Field, AnyUrl
|
||||
from chatchat.server.utils import MsgType
|
||||
|
||||
@ -102,6 +102,11 @@ class OpenAIAudioSpeechInput(OpenAIBaseInput):
|
||||
speed: Optional[float] = None
|
||||
|
||||
|
||||
# class OpenAIFileInput(OpenAIBaseInput):
|
||||
# file: UploadFile # FileTypes
|
||||
# purpose: Literal["fine-tune", "assistants"] = "assistants"
|
||||
|
||||
|
||||
class OpenAIBaseOutput(BaseModel):
|
||||
id: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
|
||||
@ -1,15 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from typing import Dict, Tuple, AsyncGenerator, Iterable
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import FileResponse
|
||||
from openai import AsyncClient
|
||||
from openai.types.file_object import FileObject
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from .api_schemas import *
|
||||
from chatchat.configs import logger
|
||||
from chatchat.configs import logger, BASE_TEMP_DIR
|
||||
from chatchat.server.utils import get_model_info, get_config_platforms, get_OpenAIClient
|
||||
|
||||
|
||||
@ -203,6 +210,96 @@ async def create_audio_speech(
|
||||
return await openai_request(client.audio.speech.create, body)
|
||||
|
||||
|
||||
@openai_router.post("/files", deprecated="暂不支持")
|
||||
async def files():
|
||||
...
|
||||
def _get_file_id(
|
||||
purpose: str,
|
||||
created_at: int,
|
||||
filename: str,
|
||||
) -> str:
|
||||
today = datetime.fromtimestamp(created_at).strftime("%Y-%m-%d")
|
||||
return base64.urlsafe_b64encode(f"{purpose}/{today}/{filename}".encode()).decode()
|
||||
|
||||
|
||||
def _get_file_info(file_id: str) -> Dict:
|
||||
splits = base64.urlsafe_b64decode(file_id).decode().split("/")
|
||||
created_at = -1
|
||||
size = -1
|
||||
file_path = _get_file_path(file_id)
|
||||
if os.path.isfile(file_path):
|
||||
created_at = int(os.path.getmtime(file_path))
|
||||
size = os.path.getsize(file_path)
|
||||
|
||||
return {
|
||||
"purpose": splits[0],
|
||||
"created_at": created_at,
|
||||
"filename": splits[2],
|
||||
"bytes": size,
|
||||
}
|
||||
|
||||
|
||||
def _get_file_path(file_id: str) -> str:
|
||||
file_id = base64.urlsafe_b64decode(file_id).decode()
|
||||
return os.path.join(BASE_TEMP_DIR, "openai_files", file_id)
|
||||
|
||||
|
||||
@openai_router.post("/files")
|
||||
async def files(
|
||||
request: Request,
|
||||
file: UploadFile,
|
||||
purpose: str = "assistants",
|
||||
) -> Dict:
|
||||
created_at = int(datetime.now().timestamp())
|
||||
file_id = _get_file_id(purpose=purpose, created_at=created_at, filename=file.filename)
|
||||
file_path = _get_file_path(file_id)
|
||||
file_dir = os.path.dirname(file_path)
|
||||
os.makedirs(file_dir, exist_ok=True)
|
||||
with open(file_path, "wb") as fp:
|
||||
shutil.copyfileobj(file.file, fp)
|
||||
file.file.close()
|
||||
|
||||
return dict(
|
||||
id=file_id,
|
||||
filename=file.filename,
|
||||
bytes=file.size,
|
||||
created_at=created_at,
|
||||
object="file",
|
||||
purpose=purpose,
|
||||
)
|
||||
|
||||
|
||||
@openai_router.get("/files")
|
||||
def list_files(purpose: str) -> Dict[str, List[Dict]]:
|
||||
file_ids = []
|
||||
root_path = Path(BASE_TEMP_DIR) / "openai_files" / purpose
|
||||
for dir, sub_dirs, files in os.walk(root_path):
|
||||
dir = Path(dir).relative_to(root_path).as_posix()
|
||||
for file in files:
|
||||
file_id = base64.urlsafe_b64encode(f"{purpose}/{dir}/{file}".encode()).decode()
|
||||
file_ids.append(file_id)
|
||||
return {"data": [{**_get_file_info(x), "id":x, "object": "file"} for x in file_ids]}
|
||||
|
||||
|
||||
@openai_router.get("/files/{file_id}")
|
||||
def retrieve_file(file_id: str) -> Dict:
|
||||
file_info = _get_file_info(file_id)
|
||||
return {**file_info, "id": file_id, "object": "file"}
|
||||
|
||||
|
||||
@openai_router.get("/files/{file_id}/content")
|
||||
def retrieve_file_content(file_id: str) -> Dict:
|
||||
file_path = _get_file_path(file_id)
|
||||
return FileResponse(file_path)
|
||||
|
||||
|
||||
@openai_router.delete("/files/{file_id}")
|
||||
def delete_file(file_id: str) -> Dict:
|
||||
file_path = _get_file_path(file_id)
|
||||
deleted = False
|
||||
|
||||
try:
|
||||
if os.path.isfile(file_path):
|
||||
os.remove(file_path)
|
||||
deleted = True
|
||||
except:
|
||||
...
|
||||
|
||||
return {"id": file_id, "deleted": deleted, "object": "file"}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user