This commit is contained in:
leehk 2025-04-17 13:11:27 +08:00
parent 48695c964a
commit 63cbd0b9bb
10 changed files with 45 additions and 47 deletions

View File

@ -5,8 +5,6 @@ from decouple import config
from fastapi import APIRouter, WebSocket, WebSocketDisconnect from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from langchain_deepseek import ChatDeepSeek from langchain_deepseek import ChatDeepSeek
from models.adaptive_rag import grading, query, routing
from .utils import ConnectionManager from .utils import ConnectionManager
router = APIRouter() router = APIRouter()
@ -17,7 +15,7 @@ os.environ["TAVILY_API_KEY"] = config("TAVILY_API_KEY", cast=str)
# Initialize the DeepSeek chat model # Initialize the DeepSeek chat model
llm_chat = ChatDeepSeek( llm_chat = ChatDeepSeek(
model="deepseek-chat", model="deepseek-chat",
temperature=0, temperature=0,
max_tokens=None, max_tokens=None,
timeout=None, timeout=None,
@ -27,28 +25,36 @@ llm_chat = ChatDeepSeek(
# Initialize the connection manager # Initialize the connection manager
manager = ConnectionManager() manager = ConnectionManager()
@router.websocket("/ws") @router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
await manager.connect(websocket) await manager.connect(websocket)
try: try:
while True: while True:
data = await websocket.receive_text() data = await websocket.receive_text()
try:
data_json = json.loads(data)
if isinstance(data_json, list) and len(data_json) > 0 and 'content' in data_json[0]:
async for chunk in llm_chat.astream(data_json[0]['content']):
await manager.send_personal_message(json.dumps({"type": "message", "payload": chunk.content}), websocket)
else:
await manager.send_personal_message("Invalid message format", websocket)
except json.JSONDecodeError: try:
await manager.broadcast("Invalid JSON message") data_json = json.loads(data)
if (
isinstance(data_json, list)
and len(data_json) > 0
and "content" in data_json[0]
):
async for chunk in llm_chat.astream(data_json[0]["content"]):
await manager.send_personal_message(
json.dumps({"type": "message", "payload": chunk.content}),
websocket,
)
else:
await manager.send_personal_message(
"Invalid message format", websocket
)
except json.JSONDecodeError:
await manager.broadcast("Invalid JSON message")
except WebSocketDisconnect: except WebSocketDisconnect:
manager.disconnect(websocket) manager.disconnect(websocket)
await manager.broadcast("Client disconnected") await manager.broadcast("Client disconnected")
except WebSocketDisconnect: except WebSocketDisconnect:
manager.disconnect(websocket) manager.disconnect(websocket)
await manager.broadcast("Client disconnected") await manager.broadcast("Client disconnected")

View File

@ -22,4 +22,3 @@ class ConnectionManager:
json_message = {"type": "message", "payload": message} json_message = {"type": "message", "payload": message}
for connection in self.active_connections: for connection in self.active_connections:
await connection.send_text(json.dumps(json_message)) await connection.send_text(json.dumps(json_message))

View File

@ -14,4 +14,4 @@ class Settings(BaseSettings):
@lru_cache() @lru_cache()
def get_settings() -> BaseSettings: def get_settings() -> BaseSettings:
log.info("Loading config settings from the environment...") log.info("Loading config settings from the environment...")
return Settings() return Settings()

View File

@ -1,21 +1,19 @@
import logging import logging
import uvicorn from fastapi import FastAPI
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from api import chatbot, ping from api import chatbot, ping
from config import Settings, get_settings
log = logging.getLogger("uvicorn") log = logging.getLogger("uvicorn")
origins = ["http://localhost:8004"] origins = ["http://localhost:8004"]
def create_application() -> FastAPI: def create_application() -> FastAPI:
application = FastAPI() application = FastAPI()
application.include_router(ping.router, tags=["ping"]) application.include_router(ping.router, tags=["ping"])
application.include_router( application.include_router(chatbot.router, tags=["chatbot"])
chatbot.router, tags=["chatbot"])
return application return application
@ -28,7 +26,3 @@ app.add_middleware(
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
# if __name__ == "__main__":
# uvicorn.run("main:app", host="0.0.0.0", port=80, reload=True)

View File

@ -8,6 +8,7 @@ class GradeDocuments(BaseModel):
description="Documents are relevant to the question, 'yes' or 'no'" description="Documents are relevant to the question, 'yes' or 'no'"
) )
class GradeHallucinations(BaseModel): class GradeHallucinations(BaseModel):
"""Binary score for hallucination present in generation answer.""" """Binary score for hallucination present in generation answer."""
@ -15,9 +16,10 @@ class GradeHallucinations(BaseModel):
description="Answer is grounded in the facts, 'yes' or 'no'" description="Answer is grounded in the facts, 'yes' or 'no'"
) )
class GradeAnswer(BaseModel): class GradeAnswer(BaseModel):
"""Binary score to assess answer addresses question.""" """Binary score to assess answer addresses question."""
binary_score: str = Field( binary_score: str = Field(
description="Answer addresses the question, 'yes' or 'no'" description="Answer addresses the question, 'yes' or 'no'"
) )

View File

@ -4,6 +4,6 @@ from pydantic import BaseModel, Field
class QueryRequest(BaseModel): class QueryRequest(BaseModel):
query: str = Field(..., description="The question to ask the model") query: str = Field(..., description="The question to ask the model")
class QueryResponse(BaseModel): class QueryResponse(BaseModel):
response: str = Field(..., description="The model's response") response: str = Field(..., description="The model's response")

View File

@ -9,4 +9,4 @@ class RouteQuery(BaseModel):
datasource: Literal["vectorstore", "web_search"] = Field( datasource: Literal["vectorstore", "web_search"] = Field(
..., ...,
description="Given a user question choose to route it to web search or a vectorstore.", description="Given a user question choose to route it to web search or a vectorstore.",
) )

2
app/backend/setup.cfg Normal file
View File

@ -0,0 +1,2 @@
[flake8]
max-line-length = 119

View File

@ -1,11 +0,0 @@
import json
import os
import sys
import unittest
from unittest.mock import AsyncMock, MagicMock
from fastapi import WebSocket, WebSocketDisconnect
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from api.chatbot import llm_chat, manager, websocket_endpoint

View File

@ -5,11 +5,12 @@ from unittest.mock import AsyncMock, MagicMock
from fastapi import WebSocket from fastapi import WebSocket
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from api.utils import ConnectionManager from api.utils import ConnectionManager
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
# Test for ConnectionManager class
class TestConnectionManager(unittest.IsolatedAsyncioTestCase): class TestConnectionManager(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
self.manager = ConnectionManager() self.manager = ConnectionManager()
@ -38,8 +39,13 @@ class TestConnectionManager(unittest.IsolatedAsyncioTestCase):
self.manager.active_connections = [mock_websocket1, mock_websocket2] self.manager.active_connections = [mock_websocket1, mock_websocket2]
message = "Broadcast message" message = "Broadcast message"
await self.manager.broadcast(message) await self.manager.broadcast(message)
mock_websocket1.send_text.assert_awaited_once_with('{"type": "message", "payload": "Broadcast message"}') mock_websocket1.send_text.assert_awaited_once_with(
mock_websocket2.send_text.assert_awaited_once_with('{"type": "message", "payload": "Broadcast message"}') '{"type": "message", "payload": "Broadcast message"}'
)
mock_websocket2.send_text.assert_awaited_once_with(
'{"type": "message", "payload": "Broadcast message"}'
)
if __name__ == '__main__':
unittest.main() if __name__ == "__main__":
unittest.main()