This commit is contained in:
leehk 2025-04-17 13:11:27 +08:00
parent 48695c964a
commit 63cbd0b9bb
10 changed files with 45 additions and 47 deletions

View File

@ -5,8 +5,6 @@ from decouple import config
from fastapi import APIRouter, WebSocket, WebSocketDisconnect from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from langchain_deepseek import ChatDeepSeek from langchain_deepseek import ChatDeepSeek
from models.adaptive_rag import grading, query, routing
from .utils import ConnectionManager from .utils import ConnectionManager
router = APIRouter() router = APIRouter()
@ -27,28 +25,36 @@ llm_chat = ChatDeepSeek(
# Initialize the connection manager # Initialize the connection manager
manager = ConnectionManager() manager = ConnectionManager()
@router.websocket("/ws") @router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
await manager.connect(websocket) await manager.connect(websocket)
try: try:
while True: while True:
data = await websocket.receive_text() data = await websocket.receive_text()
try: try:
data_json = json.loads(data) data_json = json.loads(data)
if isinstance(data_json, list) and len(data_json) > 0 and 'content' in data_json[0]: if (
async for chunk in llm_chat.astream(data_json[0]['content']): isinstance(data_json, list)
await manager.send_personal_message(json.dumps({"type": "message", "payload": chunk.content}), websocket) and len(data_json) > 0
else: and "content" in data_json[0]
await manager.send_personal_message("Invalid message format", websocket) ):
async for chunk in llm_chat.astream(data_json[0]["content"]):
await manager.send_personal_message(
json.dumps({"type": "message", "payload": chunk.content}),
websocket,
)
else:
await manager.send_personal_message(
"Invalid message format", websocket
)
except json.JSONDecodeError: except json.JSONDecodeError:
await manager.broadcast("Invalid JSON message") await manager.broadcast("Invalid JSON message")
except WebSocketDisconnect: except WebSocketDisconnect:
manager.disconnect(websocket) manager.disconnect(websocket)
await manager.broadcast("Client disconnected") await manager.broadcast("Client disconnected")
except WebSocketDisconnect: except WebSocketDisconnect:
manager.disconnect(websocket) manager.disconnect(websocket)
await manager.broadcast("Client disconnected") await manager.broadcast("Client disconnected")

View File

@ -22,4 +22,3 @@ class ConnectionManager:
json_message = {"type": "message", "payload": message} json_message = {"type": "message", "payload": message}
for connection in self.active_connections: for connection in self.active_connections:
await connection.send_text(json.dumps(json_message)) await connection.send_text(json.dumps(json_message))

View File

@ -1,21 +1,19 @@
import logging import logging
import uvicorn from fastapi import FastAPI
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from api import chatbot, ping from api import chatbot, ping
from config import Settings, get_settings
log = logging.getLogger("uvicorn") log = logging.getLogger("uvicorn")
origins = ["http://localhost:8004"] origins = ["http://localhost:8004"]
def create_application() -> FastAPI: def create_application() -> FastAPI:
application = FastAPI() application = FastAPI()
application.include_router(ping.router, tags=["ping"]) application.include_router(ping.router, tags=["ping"])
application.include_router( application.include_router(chatbot.router, tags=["chatbot"])
chatbot.router, tags=["chatbot"])
return application return application
@ -28,7 +26,3 @@ app.add_middleware(
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
# if __name__ == "__main__":
# uvicorn.run("main:app", host="0.0.0.0", port=80, reload=True)

View File

@ -8,6 +8,7 @@ class GradeDocuments(BaseModel):
description="Documents are relevant to the question, 'yes' or 'no'" description="Documents are relevant to the question, 'yes' or 'no'"
) )
class GradeHallucinations(BaseModel): class GradeHallucinations(BaseModel):
"""Binary score for hallucination present in generation answer.""" """Binary score for hallucination present in generation answer."""
@ -15,6 +16,7 @@ class GradeHallucinations(BaseModel):
description="Answer is grounded in the facts, 'yes' or 'no'" description="Answer is grounded in the facts, 'yes' or 'no'"
) )
class GradeAnswer(BaseModel): class GradeAnswer(BaseModel):
"""Binary score to assess answer addresses question.""" """Binary score to assess answer addresses question."""

View File

@ -4,6 +4,6 @@ from pydantic import BaseModel, Field
class QueryRequest(BaseModel): class QueryRequest(BaseModel):
query: str = Field(..., description="The question to ask the model") query: str = Field(..., description="The question to ask the model")
class QueryResponse(BaseModel): class QueryResponse(BaseModel):
response: str = Field(..., description="The model's response") response: str = Field(..., description="The model's response")

2
app/backend/setup.cfg Normal file
View File

@ -0,0 +1,2 @@
[flake8]
max-line-length = 119

View File

@ -1,11 +0,0 @@
import json
import os
import sys
import unittest
from unittest.mock import AsyncMock, MagicMock
from fastapi import WebSocket, WebSocketDisconnect
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from api.chatbot import llm_chat, manager, websocket_endpoint

View File

@ -5,11 +5,12 @@ from unittest.mock import AsyncMock, MagicMock
from fastapi import WebSocket from fastapi import WebSocket
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from api.utils import ConnectionManager from api.utils import ConnectionManager
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
# Test for ConnectionManager class
class TestConnectionManager(unittest.IsolatedAsyncioTestCase): class TestConnectionManager(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.manager = ConnectionManager() self.manager = ConnectionManager()
@ -38,8 +39,13 @@ class TestConnectionManager(unittest.IsolatedAsyncioTestCase):
self.manager.active_connections = [mock_websocket1, mock_websocket2] self.manager.active_connections = [mock_websocket1, mock_websocket2]
message = "Broadcast message" message = "Broadcast message"
await self.manager.broadcast(message) await self.manager.broadcast(message)
mock_websocket1.send_text.assert_awaited_once_with('{"type": "message", "payload": "Broadcast message"}') mock_websocket1.send_text.assert_awaited_once_with(
mock_websocket2.send_text.assert_awaited_once_with('{"type": "message", "payload": "Broadcast message"}') '{"type": "message", "payload": "Broadcast message"}'
)
mock_websocket2.send_text.assert_awaited_once_with(
'{"type": "message", "payload": "Broadcast message"}'
)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main() unittest.main()