diff --git a/api.py b/api.py
index c1a590bb..1ca5d75c 100644
--- a/api.py
+++ b/api.py
@@ -22,6 +22,7 @@ from configs.model_config import (VS_ROOT_PATH, UPLOAD_ROOT_PATH, EMBEDDING_DEVI
 
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 
+
 class BaseResponse(BaseModel):
     code: int = pydantic.Field(200, description="HTTP status code")
     msg: str = pydantic.Field("success", description="HTTP status message")
@@ -87,40 +88,8 @@ def get_vs_path(local_doc_id: str):
 
 def get_file_path(local_doc_id: str, doc_name: str):
     return os.path.join(UPLOAD_ROOT_PATH, local_doc_id, doc_name)
 
-async def single_upload_file(
-        file: UploadFile = File(description="A single binary file"),
-        knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
-):
-    saved_path = get_folder_path(knowledge_base_id)
-    if not os.path.exists(saved_path):
-        os.makedirs(saved_path)
-    file_content = await file.read()  # 读取上传文件的内容
-
-    file_path = os.path.join(saved_path, file.filename)
-    if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content):
-        file_status = f"文件 {file.filename} 已存在。"
-        return BaseResponse(code=200, msg=file_status)
-
-    with open(file_path, "wb") as f:
-        f.write(file_content)
-
-    vs_path = get_vs_path(knowledge_base_id)
-    if os.path.exists(vs_path):
-        added_files = await local_doc_qa.add_files_to_knowledge_vector_store(vs_path, [file_path])
-        if len(added_files) > 0:
-            file_status = f"文件 {file.filename} 已上传并已加载知识库,请开始提问。"
-            return BaseResponse(code=200, msg=file_status)
-    else:
-        vs_path, loaded_files = await local_doc_qa.init_knowledge_vector_store([file_path], vs_path)
-        if len(loaded_files) > 0:
-            file_status = f"文件 {file.filename} 已上传至新的知识库,并已加载知识库,请开始提问。"
-            return BaseResponse(code=200, msg=file_status)
-
-    file_status = "文件上传失败,请重新上传"
-    return BaseResponse(code=500, msg=file_status)
-
-async def upload_file(
+async def upload_files(
         files: Annotated[
             List[UploadFile], File(description="Multiple files as UploadFile")
         ],
@@ -203,7 +172,7 @@ async def delete_docs(
     return BaseResponse()
 
 
-async def chat(
+async def local_doc_chat(
         knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
         question: str = Body(..., description="Question", example="工伤保险是什么?"),
         history: List[List[str]] = Body(
@@ -238,7 +207,8 @@
         source_documents=source_documents,
     )
 
-async def no_knowledge_chat(
+
+async def chat(
         question: str = Body(..., description="Question", example="工伤保险是什么?"),
         history: List[List[str]] = Body(
             [],
@@ -251,7 +221,6 @@
             ],
         ),
 ):
-
     for resp, history in local_doc_qa.llm._call(
         prompt=question, history=history, streaming=True
     ):
@@ -264,6 +233,7 @@
         source_documents=[],
     )
 
+
 async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
     await websocket.accept()
     vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)
@@ -329,16 +299,19 @@ def main():
         allow_credentials=True,
         allow_methods=["*"],
         allow_headers=["*"],
-    )
-    app.websocket("/chat-docs/stream-chat/{knowledge_base_id}")(stream_chat)
-    app.post("/chat-docs/chat", response_model=ChatMessage)(chat)
-    app.post("/chat-docs/chatno", response_model=ChatMessage)(no_knowledge_chat)
-    app.post("/chat-docs/upload", response_model=BaseResponse)(upload_file)
-    app.post("/chat-docs/uploadone", response_model=BaseResponse)(single_upload_file)
-    app.get("/chat-docs/list", response_model=ListDocsResponse)(list_docs)
-    app.delete("/chat-docs/delete", response_model=BaseResponse)(delete_docs)
+    )
+    app.websocket("/local_doc_qa/stream-chat/{knowledge_base_id}")(stream_chat)
+    app.get("/", response_model=BaseResponse)(document)
+    app.post("/chat", response_model=ChatMessage)(chat)
+
+    app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files)
+    app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat)
+    app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
+    app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs)
+
+    local_doc_qa = LocalDocQA()
    local_doc_qa.init_cfg(
        llm_model=LLM_MODEL,