From 63cbd0b9bb6b4104e4e5914eba8929c33e668cdb Mon Sep 17 00:00:00 2001 From: leehk Date: Thu, 17 Apr 2025 13:11:27 +0800 Subject: [PATCH] update --- app/backend/api/chatbot.py | 38 +++++++++++++--------- app/backend/api/utils.py | 1 - app/backend/config.py | 2 +- app/backend/main.py | 12 ++----- app/backend/models/adaptive_rag/grading.py | 4 ++- app/backend/models/adaptive_rag/query.py | 2 +- app/backend/models/adaptive_rag/routing.py | 2 +- app/backend/setup.cfg | 2 ++ app/backend/tests/api/test_chatbot.py | 11 ------- app/backend/tests/api/test_utils.py | 18 ++++++---- 10 files changed, 45 insertions(+), 47 deletions(-) create mode 100644 app/backend/setup.cfg delete mode 100644 app/backend/tests/api/test_chatbot.py diff --git a/app/backend/api/chatbot.py b/app/backend/api/chatbot.py index 7eaf282..97ed103 100644 --- a/app/backend/api/chatbot.py +++ b/app/backend/api/chatbot.py @@ -5,8 +5,6 @@ from decouple import config from fastapi import APIRouter, WebSocket, WebSocketDisconnect from langchain_deepseek import ChatDeepSeek -from models.adaptive_rag import grading, query, routing - from .utils import ConnectionManager router = APIRouter() @@ -17,7 +15,7 @@ os.environ["TAVILY_API_KEY"] = config("TAVILY_API_KEY", cast=str) # Initialize the DeepSeek chat model llm_chat = ChatDeepSeek( - model="deepseek-chat", + model="deepseek-chat", temperature=0, max_tokens=None, timeout=None, @@ -27,28 +25,36 @@ llm_chat = ChatDeepSeek( # Initialize the connection manager manager = ConnectionManager() + @router.websocket("/ws") async def websocket_endpoint(websocket: WebSocket): await manager.connect(websocket) try: while True: - data = await websocket.receive_text() - - try: - data_json = json.loads(data) - if isinstance(data_json, list) and len(data_json) > 0 and 'content' in data_json[0]: - async for chunk in llm_chat.astream(data_json[0]['content']): - await manager.send_personal_message(json.dumps({"type": "message", "payload": chunk.content}), websocket) - else: - await manager.send_personal_message("Invalid message format", websocket) + data = await websocket.receive_text() - except json.JSONDecodeError: - await manager.broadcast("Invalid JSON message") + try: + data_json = json.loads(data) + if ( + isinstance(data_json, list) + and len(data_json) > 0 + and "content" in data_json[0] + ): + async for chunk in llm_chat.astream(data_json[0]["content"]): + await manager.send_personal_message( + json.dumps({"type": "message", "payload": chunk.content}), + websocket, + ) + else: + await manager.send_personal_message( + "Invalid message format", websocket + ) + + except json.JSONDecodeError: + await manager.broadcast("Invalid JSON message") except WebSocketDisconnect: manager.disconnect(websocket) await manager.broadcast("Client disconnected") except WebSocketDisconnect: manager.disconnect(websocket) await manager.broadcast("Client disconnected") - - diff --git a/app/backend/api/utils.py b/app/backend/api/utils.py index 54767aa..a58c747 100644 --- a/app/backend/api/utils.py +++ b/app/backend/api/utils.py @@ -22,4 +22,3 @@ class ConnectionManager: json_message = {"type": "message", "payload": message} for connection in self.active_connections: await connection.send_text(json.dumps(json_message)) - diff --git a/app/backend/config.py b/app/backend/config.py index 32ef3d6..a5f6943 100644 --- a/app/backend/config.py +++ b/app/backend/config.py @@ -14,4 +14,4 @@ class Settings(BaseSettings): @lru_cache() def get_settings() -> BaseSettings: log.info("Loading config settings from the environment...") - return Settings() \ No newline at end of file + return Settings() diff --git a/app/backend/main.py b/app/backend/main.py index b12cc4f..59a6ea7 100644 --- a/app/backend/main.py +++ b/app/backend/main.py @@ -1,21 +1,19 @@ import logging -import uvicorn -from fastapi import Depends, FastAPI +from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from api import chatbot, ping -from config import Settings, get_settings log = logging.getLogger("uvicorn") origins = ["http://localhost:8004"] + def create_application() -> FastAPI: application = FastAPI() application.include_router(ping.router, tags=["ping"]) - application.include_router( - chatbot.router, tags=["chatbot"]) + application.include_router(chatbot.router, tags=["chatbot"]) return application @@ -28,7 +26,3 @@ app.add_middleware( allow_methods=["*"], allow_headers=["*"], ) - - -# if __name__ == "__main__": -# uvicorn.run("main:app", host="0.0.0.0", port=80, reload=True) \ No newline at end of file diff --git a/app/backend/models/adaptive_rag/grading.py b/app/backend/models/adaptive_rag/grading.py index 6365ea2..aeeaecd 100644 --- a/app/backend/models/adaptive_rag/grading.py +++ b/app/backend/models/adaptive_rag/grading.py @@ -8,6 +8,7 @@ class GradeDocuments(BaseModel): description="Documents are relevant to the question, 'yes' or 'no'" ) + class GradeHallucinations(BaseModel): """Binary score for hallucination present in generation answer.""" @@ -15,9 +16,10 @@ class GradeHallucinations(BaseModel): description="Answer is grounded in the facts, 'yes' or 'no'" ) + class GradeAnswer(BaseModel): """Binary score to assess answer addresses question.""" binary_score: str = Field( description="Answer addresses the question, 'yes' or 'no'" - ) \ No newline at end of file + ) diff --git a/app/backend/models/adaptive_rag/query.py b/app/backend/models/adaptive_rag/query.py index 7c85eee..b26b7ad 100644 --- a/app/backend/models/adaptive_rag/query.py +++ b/app/backend/models/adaptive_rag/query.py @@ -4,6 +4,6 @@ from pydantic import BaseModel, Field class QueryRequest(BaseModel): query: str = Field(..., description="The question to ask the model") + class QueryResponse(BaseModel): response: str = Field(..., description="The model's response") - diff --git a/app/backend/models/adaptive_rag/routing.py b/app/backend/models/adaptive_rag/routing.py index 569daeb..05ed8f2 100644 --- a/app/backend/models/adaptive_rag/routing.py +++ b/app/backend/models/adaptive_rag/routing.py @@ -9,4 +9,4 @@ class RouteQuery(BaseModel): datasource: Literal["vectorstore", "web_search"] = Field( ..., description="Given a user question choose to route it to web search or a vectorstore.", - ) \ No newline at end of file + ) diff --git a/app/backend/setup.cfg b/app/backend/setup.cfg new file mode 100644 index 0000000..ec4d2a5 --- /dev/null +++ b/app/backend/setup.cfg @@ -0,0 +1,2 @@ +[flake8] +max-line-length = 119 \ No newline at end of file diff --git a/app/backend/tests/api/test_chatbot.py b/app/backend/tests/api/test_chatbot.py deleted file mode 100644 index f861c03..0000000 --- a/app/backend/tests/api/test_chatbot.py +++ /dev/null @@ -1,11 +0,0 @@ -import json -import os -import sys -import unittest -from unittest.mock import AsyncMock, MagicMock - -from fastapi import WebSocket, WebSocketDisconnect - -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) - -from api.chatbot import llm_chat, manager, websocket_endpoint diff --git a/app/backend/tests/api/test_utils.py b/app/backend/tests/api/test_utils.py index 81f168c..e65fb95 100644 --- a/app/backend/tests/api/test_utils.py +++ b/app/backend/tests/api/test_utils.py @@ -5,11 +5,12 @@ from unittest.mock import AsyncMock, MagicMock from fastapi import WebSocket -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) - from api.utils import ConnectionManager +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) + +# Test for ConnectionManager class class TestConnectionManager(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): self.manager = ConnectionManager() @@ -38,8 +39,13 @@ class TestConnectionManager(unittest.IsolatedAsyncioTestCase): self.manager.active_connections = [mock_websocket1, mock_websocket2] message = "Broadcast message" await self.manager.broadcast(message) - mock_websocket1.send_text.assert_awaited_once_with('{"type": "message", "payload": "Broadcast message"}') - mock_websocket2.send_text.assert_awaited_once_with('{"type": "message", "payload": "Broadcast message"}') + mock_websocket1.send_text.assert_awaited_once_with( + '{"type": "message", "payload": "Broadcast message"}' + ) + mock_websocket2.send_text.assert_awaited_once_with( + '{"type": "message", "payload": "Broadcast message"}' + ) -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main()