2025-04-14 06:32:28 +08:00

55 lines
1.7 KiB
Python

import json
import os
from decouple import config
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from langchain_deepseek import ChatDeepSeek
from models.adaptive_rag import grading, query, routing
from .utils import ConnectionManager
# Router object on which this module registers its endpoints.
router = APIRouter()

# Copy the API keys from the decouple-managed settings into the process
# environment. NOTE(review): presumably the DeepSeek client (and a Tavily
# integration elsewhere) read these variables — confirm against their docs.
for _env_name in ("DEEPSEEK_API_KEY", "TAVILY_API_KEY"):
    os.environ[_env_name] = config(_env_name, cast=str)
del _env_name  # keep the loop variable out of the module namespace

# Shared DeepSeek chat model used by the endpoints in this module.
# temperature=0 is passed for sampling; max_tokens and timeout are left unset
# (None), and max_retries is 2.
llm_chat = ChatDeepSeek(
    model="deepseek-chat",
    max_retries=2,
    temperature=0,
    timeout=None,
    max_tokens=None,
)

# Single ConnectionManager instance shared by every WebSocket handled here.
manager = ConnectionManager()
def _is_valid_message(message: object) -> bool:
    """Return True when *message* is a non-empty list whose first item is a dict with a 'content' key."""
    if not isinstance(message, list) or not message:
        return False
    first = message[0]
    # Require a dict: on a str, `in` would do a substring match, and on
    # non-iterables (int, None) it would raise TypeError.
    return isinstance(first, dict) and 'content' in first


@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Serve one WebSocket client: stream LLM replies for each received message.

    Each incoming text frame is expected to be a JSON array whose first
    element is an object with a ``content`` key. That value is passed to
    ``llm_chat.astream`` and every streamed chunk is sent back to the same
    client as ``{"type": "message", "payload": <chunk content>}``.

    Malformed input (invalid JSON or an unexpected shape) is reported to the
    sending client only, and the connection stays open for further messages.

    When the client disconnects, the socket is handed to
    ``manager.disconnect`` exactly once and "Client disconnected" is
    broadcast through the manager.

    Args:
        websocket: The WebSocket connection for this client.
    """
    await manager.connect(websocket)
    try:
        while True:
            raw = await websocket.receive_text()

            try:
                message = json.loads(raw)
            except json.JSONDecodeError:
                # Report the error to the sender only; other clients have no
                # use for it.
                await manager.send_personal_message("Invalid JSON message", websocket)
                continue

            if not _is_valid_message(message):
                await manager.send_personal_message("Invalid message format", websocket)
                continue

            async for chunk in llm_chat.astream(message[0]['content']):
                await manager.send_personal_message(
                    json.dumps({"type": "message", "payload": chunk.content}),
                    websocket,
                )
    except WebSocketDisconnect:
        # Single exit point for a disconnect, whether it surfaced while
        # receiving or while sending: leaving the loop here avoids reading
        # from a closed socket and calling manager.disconnect twice.
        manager.disconnect(websocket)
        await manager.broadcast("Client disconnected")