mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-08 16:10:18 +08:00
update api.py
This commit is contained in:
parent
5575079e6b
commit
a8ac40163c
61
api.py
61
api.py
@ -22,6 +22,7 @@ from configs.model_config import (VS_ROOT_PATH, UPLOAD_ROOT_PATH, EMBEDDING_DEVI
|
|||||||
|
|
||||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||||
|
|
||||||
|
|
||||||
class BaseResponse(BaseModel):
|
class BaseResponse(BaseModel):
|
||||||
code: int = pydantic.Field(200, description="HTTP status code")
|
code: int = pydantic.Field(200, description="HTTP status code")
|
||||||
msg: str = pydantic.Field("success", description="HTTP status message")
|
msg: str = pydantic.Field("success", description="HTTP status message")
|
||||||
@ -87,40 +88,8 @@ def get_vs_path(local_doc_id: str):
|
|||||||
def get_file_path(local_doc_id: str, doc_name: str):
|
def get_file_path(local_doc_id: str, doc_name: str):
|
||||||
return os.path.join(UPLOAD_ROOT_PATH, local_doc_id, doc_name)
|
return os.path.join(UPLOAD_ROOT_PATH, local_doc_id, doc_name)
|
||||||
|
|
||||||
async def single_upload_file(
|
|
||||||
file: UploadFile = File(description="A single binary file"),
|
|
||||||
knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
|
|
||||||
):
|
|
||||||
saved_path = get_folder_path(knowledge_base_id)
|
|
||||||
if not os.path.exists(saved_path):
|
|
||||||
os.makedirs(saved_path)
|
|
||||||
|
|
||||||
file_content = await file.read() # 读取上传文件的内容
|
async def upload_files(
|
||||||
|
|
||||||
file_path = os.path.join(saved_path, file.filename)
|
|
||||||
if os.path.exists(file_path) and os.path.getsize(file_path) == len(file_content):
|
|
||||||
file_status = f"文件 {file.filename} 已存在。"
|
|
||||||
return BaseResponse(code=200, msg=file_status)
|
|
||||||
|
|
||||||
with open(file_path, "wb") as f:
|
|
||||||
f.write(file_content)
|
|
||||||
|
|
||||||
vs_path = get_vs_path(knowledge_base_id)
|
|
||||||
if os.path.exists(vs_path):
|
|
||||||
added_files = await local_doc_qa.add_files_to_knowledge_vector_store(vs_path, [file_path])
|
|
||||||
if len(added_files) > 0:
|
|
||||||
file_status = f"文件 {file.filename} 已上传并已加载知识库,请开始提问。"
|
|
||||||
return BaseResponse(code=200, msg=file_status)
|
|
||||||
else:
|
|
||||||
vs_path, loaded_files = await local_doc_qa.init_knowledge_vector_store([file_path], vs_path)
|
|
||||||
if len(loaded_files) > 0:
|
|
||||||
file_status = f"文件 {file.filename} 已上传至新的知识库,并已加载知识库,请开始提问。"
|
|
||||||
return BaseResponse(code=200, msg=file_status)
|
|
||||||
|
|
||||||
file_status = "文件上传失败,请重新上传"
|
|
||||||
return BaseResponse(code=500, msg=file_status)
|
|
||||||
|
|
||||||
async def upload_file(
|
|
||||||
files: Annotated[
|
files: Annotated[
|
||||||
List[UploadFile], File(description="Multiple files as UploadFile")
|
List[UploadFile], File(description="Multiple files as UploadFile")
|
||||||
],
|
],
|
||||||
@ -203,7 +172,7 @@ async def delete_docs(
|
|||||||
return BaseResponse()
|
return BaseResponse()
|
||||||
|
|
||||||
|
|
||||||
async def chat(
|
async def local_doc_chat(
|
||||||
knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
|
knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
|
||||||
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
||||||
history: List[List[str]] = Body(
|
history: List[List[str]] = Body(
|
||||||
@ -238,7 +207,8 @@ async def chat(
|
|||||||
source_documents=source_documents,
|
source_documents=source_documents,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def no_knowledge_chat(
|
|
||||||
|
async def chat(
|
||||||
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
question: str = Body(..., description="Question", example="工伤保险是什么?"),
|
||||||
history: List[List[str]] = Body(
|
history: List[List[str]] = Body(
|
||||||
[],
|
[],
|
||||||
@ -251,7 +221,6 @@ async def no_knowledge_chat(
|
|||||||
],
|
],
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
|
|
||||||
for resp, history in local_doc_qa.llm._call(
|
for resp, history in local_doc_qa.llm._call(
|
||||||
prompt=question, history=history, streaming=True
|
prompt=question, history=history, streaming=True
|
||||||
):
|
):
|
||||||
@ -264,6 +233,7 @@ async def no_knowledge_chat(
|
|||||||
source_documents=[],
|
source_documents=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
|
async def stream_chat(websocket: WebSocket, knowledge_base_id: str):
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)
|
vs_path = os.path.join(VS_ROOT_PATH, knowledge_base_id)
|
||||||
@ -329,16 +299,19 @@ def main():
|
|||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
app.websocket("/chat-docs/stream-chat/{knowledge_base_id}")(stream_chat)
|
app.websocket("/local_doc_qa/stream-chat/{knowledge_base_id}")(stream_chat)
|
||||||
app.post("/chat-docs/chat", response_model=ChatMessage)(chat)
|
|
||||||
app.post("/chat-docs/chatno", response_model=ChatMessage)(no_knowledge_chat)
|
|
||||||
app.post("/chat-docs/upload", response_model=BaseResponse)(upload_file)
|
|
||||||
app.post("/chat-docs/uploadone", response_model=BaseResponse)(single_upload_file)
|
|
||||||
app.get("/chat-docs/list", response_model=ListDocsResponse)(list_docs)
|
|
||||||
app.delete("/chat-docs/delete", response_model=BaseResponse)(delete_docs)
|
|
||||||
app.get("/", response_model=BaseResponse)(document)
|
app.get("/", response_model=BaseResponse)(document)
|
||||||
|
|
||||||
|
app.post("/chat", response_model=ChatMessage)(chat)
|
||||||
|
|
||||||
|
app.post("/local_doc_qa/upload_files", response_model=BaseResponse)(upload_files)
|
||||||
|
app.post("/local_doc_qa/local_doc_chat", response_model=ChatMessage)(local_doc_chat)
|
||||||
|
app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
|
||||||
|
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs)
|
||||||
|
|
||||||
|
|
||||||
local_doc_qa = LocalDocQA()
|
local_doc_qa = LocalDocQA()
|
||||||
local_doc_qa.init_cfg(
|
local_doc_qa.init_cfg(
|
||||||
llm_model=LLM_MODEL,
|
llm_model=LLM_MODEL,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user