commit 63cbd0b9bb
parent 48695c964a
Author: leehk
Date:   2025-04-17 13:11:27 +08:00

10 changed files with 45 additions and 47 deletions

View File

@@ -5,8 +5,6 @@ from decouple import config
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from langchain_deepseek import ChatDeepSeek
from models.adaptive_rag import grading, query, routing
from .utils import ConnectionManager
router = APIRouter()
@@ -17,7 +15,7 @@ os.environ["TAVILY_API_KEY"] = config("TAVILY_API_KEY", cast=str)
# Initialize the DeepSeek chat model
llm_chat = ChatDeepSeek(
model="deepseek-chat",
model="deepseek-chat",
temperature=0,
max_tokens=None,
timeout=None,
@@ -27,28 +25,36 @@ llm_chat = ChatDeepSeek(
# Initialize the connection manager
manager = ConnectionManager()
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await manager.connect(websocket)
try:
while True:
data = await websocket.receive_text()
try:
data_json = json.loads(data)
if isinstance(data_json, list) and len(data_json) > 0 and 'content' in data_json[0]:
async for chunk in llm_chat.astream(data_json[0]['content']):
await manager.send_personal_message(json.dumps({"type": "message", "payload": chunk.content}), websocket)
else:
await manager.send_personal_message("Invalid message format", websocket)
data = await websocket.receive_text()
except json.JSONDecodeError:
await manager.broadcast("Invalid JSON message")
try:
data_json = json.loads(data)
if (
isinstance(data_json, list)
and len(data_json) > 0
and "content" in data_json[0]
):
async for chunk in llm_chat.astream(data_json[0]["content"]):
await manager.send_personal_message(
json.dumps({"type": "message", "payload": chunk.content}),
websocket,
)
else:
await manager.send_personal_message(
"Invalid message format", websocket
)
except json.JSONDecodeError:
await manager.broadcast("Invalid JSON message")
except WebSocketDisconnect:
manager.disconnect(websocket)
await manager.broadcast("Client disconnected")
except WebSocketDisconnect:
manager.disconnect(websocket)
await manager.broadcast("Client disconnected")
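For context, the handler above expects the client to send a JSON-encoded list whose first element carries a "content" key, and it streams the model output back as one JSON message per chunk. A minimal client sketch follows; the websockets dependency, host, and port are assumptions and are not part of this commit:

# Hypothetical client for the /ws endpoint above; host and port are assumptions.
import asyncio
import json

import websockets  # assumed client-side dependency, not part of this commit


async def ask(question: str) -> None:
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        # The handler expects a JSON list whose first item carries a "content" key.
        await ws.send(json.dumps([{"content": question}]))
        # The server streams one {"type": "message", "payload": ...} object per chunk
        # and keeps the socket open, so stop reading once you have enough output.
        async for raw in ws:
            print(json.loads(raw)["payload"], end="", flush=True)


if __name__ == "__main__":
    asyncio.run(ask("What is adaptive RAG?"))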

View File

@@ -22,4 +22,3 @@ class ConnectionManager:
json_message = {"type": "message", "payload": message}
for connection in self.active_connections:
await connection.send_text(json.dumps(json_message))
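The hunk above only shows the tail of broadcast. For orientation, here is a sketch of how the full ConnectionManager presumably looks, inferred from its call sites in the chatbot handler and the tests; only the broadcast body is taken from the diff verbatim, the rest is an assumption:

# Sketch of ConnectionManager, inferred from how it is called in the chatbot
# module and the tests; only broadcast() mirrors the hunk above exactly.
import json
from typing import List

from fastapi import WebSocket


class ConnectionManager:
    def __init__(self) -> None:
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket) -> None:
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket) -> None:
        self.active_connections.remove(websocket)

    async def send_personal_message(self, message: str, websocket: WebSocket) -> None:
        await websocket.send_text(message)

    async def broadcast(self, message: str) -> None:
        json_message = {"type": "message", "payload": message}
        for connection in self.active_connections:
            await connection.send_text(json.dumps(json_message))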

View File

@@ -14,4 +14,4 @@ class Settings(BaseSettings):
@lru_cache()
def get_settings() -> BaseSettings:
log.info("Loading config settings from the environment...")
return Settings()
return Settings()

View File

@@ -1,21 +1,19 @@
import logging
import uvicorn
from fastapi import Depends, FastAPI
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from api import chatbot, ping
from config import Settings, get_settings
log = logging.getLogger("uvicorn")
origins = ["http://localhost:8004"]
def create_application() -> FastAPI:
application = FastAPI()
application.include_router(ping.router, tags=["ping"])
application.include_router(
chatbot.router, tags=["chatbot"])
application.include_router(chatbot.router, tags=["chatbot"])
return application
@@ -28,7 +26,3 @@ app.add_middleware(
allow_methods=["*"],
allow_headers=["*"],
)
# if __name__ == "__main__":
# uvicorn.run("main:app", host="0.0.0.0", port=80, reload=True)

View File

@@ -8,6 +8,7 @@ class GradeDocuments(BaseModel):
description="Documents are relevant to the question, 'yes' or 'no'"
)
class GradeHallucinations(BaseModel):
"""Binary score for hallucination present in generation answer."""
@@ -15,9 +16,10 @@ class GradeHallucinations(BaseModel):
description="Answer is grounded in the facts, 'yes' or 'no'"
)
class GradeAnswer(BaseModel):
"""Binary score to assess answer addresses question."""
binary_score: str = Field(
description="Answer addresses the question, 'yes' or 'no'"
)
)
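Schemas like these are typically bound to the chat model as structured output so the grader returns a parsed object rather than free text. A hedged sketch of how GradeAnswer might be wired up follows; the prompt wording and example inputs are assumptions, while GradeAnswer and llm_chat refer to the objects defined in this commit's modules:

# Illustrative only: binding a grader schema to the chat model via structured
# output. Prompt text and inputs are assumptions, not taken from this repo.
from langchain_core.prompts import ChatPromptTemplate

answer_grader_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Decide whether the answer addresses the question; reply 'yes' or 'no'."),
        ("human", "Question: {question}\n\nAnswer: {generation}"),
    ]
)
# llm_chat is the ChatDeepSeek instance from the chatbot module; GradeAnswer is defined above.
answer_grader = answer_grader_prompt | llm_chat.with_structured_output(GradeAnswer)
grade = answer_grader.invoke(
    {"question": "What port does the API use?", "generation": "The API listens on port 80."}
)
print(grade.binary_score)  # "yes" or "no"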

View File

@@ -4,6 +4,6 @@ from pydantic import BaseModel, Field
class QueryRequest(BaseModel):
query: str = Field(..., description="The question to ask the model")
class QueryResponse(BaseModel):
response: str = Field(..., description="The model's response")

View File

@@ -9,4 +9,4 @@ class RouteQuery(BaseModel):
datasource: Literal["vectorstore", "web_search"] = Field(
...,
description="Given a user question choose to route it to web search or a vectorstore.",
)
)

app/backend/setup.cfg (new file)
View File

@@ -0,0 +1,2 @@
[flake8]
max-line-length = 119

View File

@@ -1,11 +0,0 @@
import json
import os
import sys
import unittest
from unittest.mock import AsyncMock, MagicMock
from fastapi import WebSocket, WebSocketDisconnect
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from api.chatbot import llm_chat, manager, websocket_endpoint

View File

@@ -5,11 +5,12 @@ from unittest.mock import AsyncMock, MagicMock
from fastapi import WebSocket
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from api.utils import ConnectionManager
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
# Test for ConnectionManager class
class TestConnectionManager(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.manager = ConnectionManager()
@@ -38,8 +39,13 @@ class TestConnectionManager(unittest.IsolatedAsyncioTestCase):
self.manager.active_connections = [mock_websocket1, mock_websocket2]
message = "Broadcast message"
await self.manager.broadcast(message)
mock_websocket1.send_text.assert_awaited_once_with('{"type": "message", "payload": "Broadcast message"}')
mock_websocket2.send_text.assert_awaited_once_with('{"type": "message", "payload": "Broadcast message"}')
mock_websocket1.send_text.assert_awaited_once_with(
'{"type": "message", "payload": "Broadcast message"}'
)
mock_websocket2.send_text.assert_awaited_once_with(
'{"type": "message", "payload": "Broadcast message"}'
)
if __name__ == '__main__':
unittest.main()
if __name__ == "__main__":
unittest.main()