mirror of
https://github.com/aimingmed/aimingmed-ai.git
synced 2026-01-28 01:43:18 +08:00
update
This commit is contained in:
parent
48695c964a
commit
63cbd0b9bb
@ -5,8 +5,6 @@ from decouple import config
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
|
||||
from models.adaptive_rag import grading, query, routing
|
||||
|
||||
from .utils import ConnectionManager
|
||||
|
||||
router = APIRouter()
|
||||
@ -17,7 +15,7 @@ os.environ["TAVILY_API_KEY"] = config("TAVILY_API_KEY", cast=str)
|
||||
|
||||
# Initialize the DeepSeek chat model
|
||||
llm_chat = ChatDeepSeek(
|
||||
model="deepseek-chat",
|
||||
model="deepseek-chat",
|
||||
temperature=0,
|
||||
max_tokens=None,
|
||||
timeout=None,
|
||||
@ -27,28 +25,36 @@ llm_chat = ChatDeepSeek(
|
||||
# Initialize the connection manager
|
||||
manager = ConnectionManager()
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await manager.connect(websocket)
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
|
||||
try:
|
||||
data_json = json.loads(data)
|
||||
if isinstance(data_json, list) and len(data_json) > 0 and 'content' in data_json[0]:
|
||||
async for chunk in llm_chat.astream(data_json[0]['content']):
|
||||
await manager.send_personal_message(json.dumps({"type": "message", "payload": chunk.content}), websocket)
|
||||
else:
|
||||
await manager.send_personal_message("Invalid message format", websocket)
|
||||
data = await websocket.receive_text()
|
||||
|
||||
except json.JSONDecodeError:
|
||||
await manager.broadcast("Invalid JSON message")
|
||||
try:
|
||||
data_json = json.loads(data)
|
||||
if (
|
||||
isinstance(data_json, list)
|
||||
and len(data_json) > 0
|
||||
and "content" in data_json[0]
|
||||
):
|
||||
async for chunk in llm_chat.astream(data_json[0]["content"]):
|
||||
await manager.send_personal_message(
|
||||
json.dumps({"type": "message", "payload": chunk.content}),
|
||||
websocket,
|
||||
)
|
||||
else:
|
||||
await manager.send_personal_message(
|
||||
"Invalid message format", websocket
|
||||
)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
await manager.broadcast("Invalid JSON message")
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket)
|
||||
await manager.broadcast("Client disconnected")
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket)
|
||||
await manager.broadcast("Client disconnected")
|
||||
|
||||
|
||||
|
||||
@ -22,4 +22,3 @@ class ConnectionManager:
|
||||
json_message = {"type": "message", "payload": message}
|
||||
for connection in self.active_connections:
|
||||
await connection.send_text(json.dumps(json_message))
|
||||
|
||||
|
||||
@ -14,4 +14,4 @@ class Settings(BaseSettings):
|
||||
@lru_cache()
|
||||
def get_settings() -> BaseSettings:
|
||||
log.info("Loading config settings from the environment...")
|
||||
return Settings()
|
||||
return Settings()
|
||||
|
||||
@ -1,21 +1,19 @@
|
||||
import logging
|
||||
|
||||
import uvicorn
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from api import chatbot, ping
|
||||
from config import Settings, get_settings
|
||||
|
||||
log = logging.getLogger("uvicorn")
|
||||
|
||||
origins = ["http://localhost:8004"]
|
||||
|
||||
|
||||
def create_application() -> FastAPI:
|
||||
application = FastAPI()
|
||||
application.include_router(ping.router, tags=["ping"])
|
||||
application.include_router(
|
||||
chatbot.router, tags=["chatbot"])
|
||||
application.include_router(chatbot.router, tags=["chatbot"])
|
||||
return application
|
||||
|
||||
|
||||
@ -28,7 +26,3 @@ app.add_middleware(
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# uvicorn.run("main:app", host="0.0.0.0", port=80, reload=True)
|
||||
@ -8,6 +8,7 @@ class GradeDocuments(BaseModel):
|
||||
description="Documents are relevant to the question, 'yes' or 'no'"
|
||||
)
|
||||
|
||||
|
||||
class GradeHallucinations(BaseModel):
|
||||
"""Binary score for hallucination present in generation answer."""
|
||||
|
||||
@ -15,9 +16,10 @@ class GradeHallucinations(BaseModel):
|
||||
description="Answer is grounded in the facts, 'yes' or 'no'"
|
||||
)
|
||||
|
||||
|
||||
class GradeAnswer(BaseModel):
|
||||
"""Binary score to assess answer addresses question."""
|
||||
|
||||
binary_score: str = Field(
|
||||
description="Answer addresses the question, 'yes' or 'no'"
|
||||
)
|
||||
)
|
||||
|
||||
@ -4,6 +4,6 @@ from pydantic import BaseModel, Field
|
||||
class QueryRequest(BaseModel):
|
||||
query: str = Field(..., description="The question to ask the model")
|
||||
|
||||
|
||||
class QueryResponse(BaseModel):
|
||||
response: str = Field(..., description="The model's response")
|
||||
|
||||
|
||||
@ -9,4 +9,4 @@ class RouteQuery(BaseModel):
|
||||
datasource: Literal["vectorstore", "web_search"] = Field(
|
||||
...,
|
||||
description="Given a user question choose to route it to web search or a vectorstore.",
|
||||
)
|
||||
)
|
||||
|
||||
2
app/backend/setup.cfg
Normal file
2
app/backend/setup.cfg
Normal file
@ -0,0 +1,2 @@
|
||||
[flake8]
|
||||
max-line-length = 119
|
||||
@ -1,11 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
|
||||
|
||||
from api.chatbot import llm_chat, manager, websocket_endpoint
|
||||
@ -5,11 +5,12 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
|
||||
|
||||
from api.utils import ConnectionManager
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
|
||||
|
||||
|
||||
# Test for ConnectionManager class
|
||||
class TestConnectionManager(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
self.manager = ConnectionManager()
|
||||
@ -38,8 +39,13 @@ class TestConnectionManager(unittest.IsolatedAsyncioTestCase):
|
||||
self.manager.active_connections = [mock_websocket1, mock_websocket2]
|
||||
message = "Broadcast message"
|
||||
await self.manager.broadcast(message)
|
||||
mock_websocket1.send_text.assert_awaited_once_with('{"type": "message", "payload": "Broadcast message"}')
|
||||
mock_websocket2.send_text.assert_awaited_once_with('{"type": "message", "payload": "Broadcast message"}')
|
||||
mock_websocket1.send_text.assert_awaited_once_with(
|
||||
'{"type": "message", "payload": "Broadcast message"}'
|
||||
)
|
||||
mock_websocket2.send_text.assert_awaited_once_with(
|
||||
'{"type": "message", "payload": "Broadcast message"}'
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user