diff --git a/requirements.txt b/requirements.txt
index b5cd996d..b8633d82 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,6 +12,7 @@ torch==2.1.0 ##on Windows system, install the cuda version manually from https:
 torchvision #on Windows system, install the cuda version manually from https://pytorch.org/
 torchaudio #on Windows system, install the cuda version manually from https://pytorch.org/
 fastapi>=0.104
+sse_starlette
 nltk>=3.8.1
 uvicorn>=0.24.0.post1
 starlette~=0.27.0
diff --git a/requirements_api.txt b/requirements_api.txt
index ec1005f9..e5d77b64 100644
--- a/requirements_api.txt
+++ b/requirements_api.txt
@@ -12,6 +12,7 @@ torch==2.1.0 ##on Windows system, install the cuda version manually from https:
 torchvision #on Windows system, install the cuda version manually from https://pytorch.org/
 torchaudio #on Windows system, install the cuda version manually from https://pytorch.org/
 fastapi>=0.104
+sse_starlette
 nltk>=3.8.1
 uvicorn>=0.24.0.post1
 starlette~=0.27.0
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index a3ab68b1..c37e030b 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -1,5 +1,5 @@
 from fastapi import Body, Request
-from fastapi.responses import StreamingResponse
+from sse_starlette.sse import EventSourceResponse
 from fastapi.concurrency import run_in_threadpool
 from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE)
 from server.utils import wrap_done, get_ChatOpenAI
@@ -119,9 +119,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                              ensure_ascii=False)
         await task
 
-    return StreamingResponse(knowledge_base_chat_iterator(query=query,
-                                                          top_k=top_k,
-                                                          history=history,
-                                                          model_name=model_name,
-                                                          prompt_name=prompt_name),
-                             media_type="text/event-stream")
+    return EventSourceResponse(knowledge_base_chat_iterator(query=query,
+                                                            top_k=top_k,
+                                                            history=history,
+                                                            model_name=model_name,
+                                                            prompt_name=prompt_name))
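
For reviewers unfamiliar with sse_starlette, here is a minimal, self-contained sketch of the behaviour this patch switches to. It is not part of the change, and the endpoint and payload names are hypothetical; it only illustrates why `EventSourceResponse` can drop the explicit `media_type="text/event-stream"` argument that `StreamingResponse` needed.

```python
# Minimal sketch (not part of the patch; endpoint and payload are made up).
# EventSourceResponse wraps each yielded chunk in the SSE wire format
# ("data: ...\n\n"), sends periodic ping events, and sets the
# text/event-stream media type itself, whereas StreamingResponse sends
# raw chunks and needs media_type= passed explicitly.
import asyncio
import json

from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

app = FastAPI()


@app.get("/demo-stream")  # hypothetical endpoint, for illustration only
async def demo_stream():
    async def event_iterator():
        for i in range(3):
            # A yielded str becomes the "data:" field of one SSE event;
            # yielding a dict instead allows named fields ("event", "id", ...).
            yield json.dumps({"answer": f"token-{i}"}, ensure_ascii=False)
            await asyncio.sleep(0.1)

    return EventSourceResponse(event_iterator())
```

On the wire, each yield then arrives as a well-formed `data:` event, so browser `EventSource` clients and other SSE-aware consumers can read the stream directly, which plain `StreamingResponse` chunks do not guarantee.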