mirror of
https://github.com/aimingmed/aimingmed-ai.git
synced 2026-02-06 07:03:18 +08:00
update
This commit is contained in:
parent
48695c964a
commit
63cbd0b9bb
@ -5,8 +5,6 @@ from decouple import config
|
|||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||||
from langchain_deepseek import ChatDeepSeek
|
from langchain_deepseek import ChatDeepSeek
|
||||||
|
|
||||||
from models.adaptive_rag import grading, query, routing
|
|
||||||
|
|
||||||
from .utils import ConnectionManager
|
from .utils import ConnectionManager
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@ -17,7 +15,7 @@ os.environ["TAVILY_API_KEY"] = config("TAVILY_API_KEY", cast=str)
|
|||||||
|
|
||||||
# Initialize the DeepSeek chat model
|
# Initialize the DeepSeek chat model
|
||||||
llm_chat = ChatDeepSeek(
|
llm_chat = ChatDeepSeek(
|
||||||
model="deepseek-chat",
|
model="deepseek-chat",
|
||||||
temperature=0,
|
temperature=0,
|
||||||
max_tokens=None,
|
max_tokens=None,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
@ -27,28 +25,36 @@ llm_chat = ChatDeepSeek(
|
|||||||
# Initialize the connection manager
|
# Initialize the connection manager
|
||||||
manager = ConnectionManager()
|
manager = ConnectionManager()
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/ws")
|
@router.websocket("/ws")
|
||||||
async def websocket_endpoint(websocket: WebSocket):
|
async def websocket_endpoint(websocket: WebSocket):
|
||||||
await manager.connect(websocket)
|
await manager.connect(websocket)
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
data = await websocket.receive_text()
|
data = await websocket.receive_text()
|
||||||
|
|
||||||
try:
|
|
||||||
data_json = json.loads(data)
|
|
||||||
if isinstance(data_json, list) and len(data_json) > 0 and 'content' in data_json[0]:
|
|
||||||
async for chunk in llm_chat.astream(data_json[0]['content']):
|
|
||||||
await manager.send_personal_message(json.dumps({"type": "message", "payload": chunk.content}), websocket)
|
|
||||||
else:
|
|
||||||
await manager.send_personal_message("Invalid message format", websocket)
|
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
try:
|
||||||
await manager.broadcast("Invalid JSON message")
|
data_json = json.loads(data)
|
||||||
|
if (
|
||||||
|
isinstance(data_json, list)
|
||||||
|
and len(data_json) > 0
|
||||||
|
and "content" in data_json[0]
|
||||||
|
):
|
||||||
|
async for chunk in llm_chat.astream(data_json[0]["content"]):
|
||||||
|
await manager.send_personal_message(
|
||||||
|
json.dumps({"type": "message", "payload": chunk.content}),
|
||||||
|
websocket,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await manager.send_personal_message(
|
||||||
|
"Invalid message format", websocket
|
||||||
|
)
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
await manager.broadcast("Invalid JSON message")
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
manager.disconnect(websocket)
|
manager.disconnect(websocket)
|
||||||
await manager.broadcast("Client disconnected")
|
await manager.broadcast("Client disconnected")
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
manager.disconnect(websocket)
|
manager.disconnect(websocket)
|
||||||
await manager.broadcast("Client disconnected")
|
await manager.broadcast("Client disconnected")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -22,4 +22,3 @@ class ConnectionManager:
|
|||||||
json_message = {"type": "message", "payload": message}
|
json_message = {"type": "message", "payload": message}
|
||||||
for connection in self.active_connections:
|
for connection in self.active_connections:
|
||||||
await connection.send_text(json.dumps(json_message))
|
await connection.send_text(json.dumps(json_message))
|
||||||
|
|
||||||
|
|||||||
@ -14,4 +14,4 @@ class Settings(BaseSettings):
|
|||||||
@lru_cache()
|
@lru_cache()
|
||||||
def get_settings() -> BaseSettings:
|
def get_settings() -> BaseSettings:
|
||||||
log.info("Loading config settings from the environment...")
|
log.info("Loading config settings from the environment...")
|
||||||
return Settings()
|
return Settings()
|
||||||
|
|||||||
@ -1,21 +1,19 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import uvicorn
|
from fastapi import FastAPI
|
||||||
from fastapi import Depends, FastAPI
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
from api import chatbot, ping
|
from api import chatbot, ping
|
||||||
from config import Settings, get_settings
|
|
||||||
|
|
||||||
log = logging.getLogger("uvicorn")
|
log = logging.getLogger("uvicorn")
|
||||||
|
|
||||||
origins = ["http://localhost:8004"]
|
origins = ["http://localhost:8004"]
|
||||||
|
|
||||||
|
|
||||||
def create_application() -> FastAPI:
|
def create_application() -> FastAPI:
|
||||||
application = FastAPI()
|
application = FastAPI()
|
||||||
application.include_router(ping.router, tags=["ping"])
|
application.include_router(ping.router, tags=["ping"])
|
||||||
application.include_router(
|
application.include_router(chatbot.router, tags=["chatbot"])
|
||||||
chatbot.router, tags=["chatbot"])
|
|
||||||
return application
|
return application
|
||||||
|
|
||||||
|
|
||||||
@ -28,7 +26,3 @@ app.add_middleware(
|
|||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# uvicorn.run("main:app", host="0.0.0.0", port=80, reload=True)
|
|
||||||
@ -8,6 +8,7 @@ class GradeDocuments(BaseModel):
|
|||||||
description="Documents are relevant to the question, 'yes' or 'no'"
|
description="Documents are relevant to the question, 'yes' or 'no'"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class GradeHallucinations(BaseModel):
|
class GradeHallucinations(BaseModel):
|
||||||
"""Binary score for hallucination present in generation answer."""
|
"""Binary score for hallucination present in generation answer."""
|
||||||
|
|
||||||
@ -15,9 +16,10 @@ class GradeHallucinations(BaseModel):
|
|||||||
description="Answer is grounded in the facts, 'yes' or 'no'"
|
description="Answer is grounded in the facts, 'yes' or 'no'"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class GradeAnswer(BaseModel):
|
class GradeAnswer(BaseModel):
|
||||||
"""Binary score to assess answer addresses question."""
|
"""Binary score to assess answer addresses question."""
|
||||||
|
|
||||||
binary_score: str = Field(
|
binary_score: str = Field(
|
||||||
description="Answer addresses the question, 'yes' or 'no'"
|
description="Answer addresses the question, 'yes' or 'no'"
|
||||||
)
|
)
|
||||||
|
|||||||
@ -4,6 +4,6 @@ from pydantic import BaseModel, Field
|
|||||||
class QueryRequest(BaseModel):
|
class QueryRequest(BaseModel):
|
||||||
query: str = Field(..., description="The question to ask the model")
|
query: str = Field(..., description="The question to ask the model")
|
||||||
|
|
||||||
|
|
||||||
class QueryResponse(BaseModel):
|
class QueryResponse(BaseModel):
|
||||||
response: str = Field(..., description="The model's response")
|
response: str = Field(..., description="The model's response")
|
||||||
|
|
||||||
|
|||||||
@ -9,4 +9,4 @@ class RouteQuery(BaseModel):
|
|||||||
datasource: Literal["vectorstore", "web_search"] = Field(
|
datasource: Literal["vectorstore", "web_search"] = Field(
|
||||||
...,
|
...,
|
||||||
description="Given a user question choose to route it to web search or a vectorstore.",
|
description="Given a user question choose to route it to web search or a vectorstore.",
|
||||||
)
|
)
|
||||||
|
|||||||
2
app/backend/setup.cfg
Normal file
2
app/backend/setup.cfg
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
[flake8]
|
||||||
|
max-line-length = 119
|
||||||
@ -1,11 +0,0 @@
|
|||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import unittest
|
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
|
||||||
|
|
||||||
from fastapi import WebSocket, WebSocketDisconnect
|
|
||||||
|
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
|
|
||||||
|
|
||||||
from api.chatbot import llm_chat, manager, websocket_endpoint
|
|
||||||
@ -5,11 +5,12 @@ from unittest.mock import AsyncMock, MagicMock
|
|||||||
|
|
||||||
from fastapi import WebSocket
|
from fastapi import WebSocket
|
||||||
|
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
|
|
||||||
|
|
||||||
from api.utils import ConnectionManager
|
from api.utils import ConnectionManager
|
||||||
|
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
|
||||||
|
|
||||||
|
|
||||||
|
# Test for ConnectionManager class
|
||||||
class TestConnectionManager(unittest.IsolatedAsyncioTestCase):
|
class TestConnectionManager(unittest.IsolatedAsyncioTestCase):
|
||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
self.manager = ConnectionManager()
|
self.manager = ConnectionManager()
|
||||||
@ -38,8 +39,13 @@ class TestConnectionManager(unittest.IsolatedAsyncioTestCase):
|
|||||||
self.manager.active_connections = [mock_websocket1, mock_websocket2]
|
self.manager.active_connections = [mock_websocket1, mock_websocket2]
|
||||||
message = "Broadcast message"
|
message = "Broadcast message"
|
||||||
await self.manager.broadcast(message)
|
await self.manager.broadcast(message)
|
||||||
mock_websocket1.send_text.assert_awaited_once_with('{"type": "message", "payload": "Broadcast message"}')
|
mock_websocket1.send_text.assert_awaited_once_with(
|
||||||
mock_websocket2.send_text.assert_awaited_once_with('{"type": "message", "payload": "Broadcast message"}')
|
'{"type": "message", "payload": "Broadcast message"}'
|
||||||
|
)
|
||||||
|
mock_websocket2.send_text.assert_awaited_once_with(
|
||||||
|
'{"type": "message", "payload": "Broadcast message"}'
|
||||||
|
)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main()
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user