mirror of
https://github.com/aimingmed/aimingmed-ai.git
synced 2026-01-19 13:23:23 +08:00
updated
This commit is contained in:
parent
9a5303e157
commit
3752d85fde
51
app/backend/tests/api/test_chatbot.py
Normal file
51
app/backend/tests/api/test_chatbot.py
Normal file
@ -0,0 +1,51 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import patch, MagicMock
|
||||
from fastapi import WebSocket
|
||||
import sys
|
||||
import types
|
||||
|
||||
# Patch langchain and other heavy dependencies for import
|
||||
sys.modules['langchain_deepseek'] = MagicMock()
|
||||
sys.modules['langchain_huggingface'] = MagicMock()
|
||||
sys.modules['langchain_community.vectorstores.chroma'] = MagicMock()
|
||||
sys.modules['langchain_community.tools.tavily_search'] = MagicMock()
|
||||
sys.modules['langchain_core.prompts'] = MagicMock()
|
||||
sys.modules['langchain_core.output_parsers'] = MagicMock()
|
||||
sys.modules['langchain.prompts'] = MagicMock()
|
||||
sys.modules['langchain.schema'] = MagicMock()
|
||||
sys.modules['langgraph.graph'] = MagicMock()
|
||||
|
||||
from api import chatbot
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from fastapi import FastAPI
|
||||
app = FastAPI()
|
||||
app.include_router(chatbot.router)
|
||||
return TestClient(app)
|
||||
|
||||
def test_router_exists():
|
||||
assert hasattr(chatbot, 'router')
|
||||
|
||||
def test_env_vars_loaded(monkeypatch):
|
||||
monkeypatch.setenv('DEEPSEEK_API_KEY', 'dummy')
|
||||
monkeypatch.setenv('TAVILY_API_KEY', 'dummy')
|
||||
# Re-import to trigger env loading
|
||||
import importlib
|
||||
importlib.reload(chatbot)
|
||||
assert True
|
||||
|
||||
def test_websocket_endpoint_accepts(monkeypatch):
|
||||
# Patch ConnectionManager
|
||||
mock_manager = MagicMock()
|
||||
monkeypatch.setattr(chatbot, 'manager', mock_manager)
|
||||
ws = MagicMock(spec=WebSocket)
|
||||
ws.receive_text = MagicMock(side_effect=[pytest.raises(StopIteration)])
|
||||
ws.accept = MagicMock()
|
||||
# Should not raise
|
||||
try:
|
||||
coro = chatbot.websocket_endpoint(ws)
|
||||
assert hasattr(coro, '__await__')
|
||||
except Exception as e:
|
||||
pytest.fail(f"websocket_endpoint raised: {e}")
|
||||
14
app/backend/tests/models/adaptive_rag/test_grading.py
Normal file
14
app/backend/tests/models/adaptive_rag/test_grading.py
Normal file
@ -0,0 +1,14 @@
|
||||
import pytest
|
||||
from models.adaptive_rag import grading
|
||||
|
||||
def test_grade_documents_class():
|
||||
doc = grading.GradeDocuments(binary_score='yes')
|
||||
assert doc.binary_score == 'yes'
|
||||
|
||||
def test_grade_hallucinations_class():
|
||||
doc = grading.GradeHallucinations(binary_score='no')
|
||||
assert doc.binary_score == 'no'
|
||||
|
||||
def test_grade_answer_class():
|
||||
doc = grading.GradeAnswer(binary_score='yes')
|
||||
assert doc.binary_score == 'yes'
|
||||
@ -0,0 +1,10 @@
|
||||
import pytest
|
||||
from models.adaptive_rag import prompts_library
|
||||
|
||||
def test_prompts_are_strings():
|
||||
assert isinstance(prompts_library.system_router, str)
|
||||
assert isinstance(prompts_library.system_retriever_grader, str)
|
||||
assert isinstance(prompts_library.system_hallucination_grader, str)
|
||||
assert isinstance(prompts_library.system_answer_grader, str)
|
||||
assert isinstance(prompts_library.system_question_rewriter, str)
|
||||
assert isinstance(prompts_library.qa_prompt_template, str)
|
||||
8
app/backend/tests/models/adaptive_rag/test_query.py
Normal file
8
app/backend/tests/models/adaptive_rag/test_query.py
Normal file
@ -0,0 +1,8 @@
|
||||
import pytest
|
||||
from models.adaptive_rag import query
|
||||
|
||||
def test_query_request_and_response():
|
||||
req = query.QueryRequest(query="What is AI?")
|
||||
assert req.query == "What is AI?"
|
||||
resp = query.QueryResponse(response="Artificial Intelligence")
|
||||
assert resp.response == "Artificial Intelligence"
|
||||
6
app/backend/tests/models/adaptive_rag/test_routing.py
Normal file
6
app/backend/tests/models/adaptive_rag/test_routing.py
Normal file
@ -0,0 +1,6 @@
|
||||
import pytest
|
||||
from models.adaptive_rag import routing
|
||||
|
||||
def test_route_query_class():
|
||||
route = routing.RouteQuery(datasource="vectorstore")
|
||||
assert route.datasource == "vectorstore"
|
||||
10
app/backend/tests/test_config_and_main.py
Normal file
10
app/backend/tests/test_config_and_main.py
Normal file
@ -0,0 +1,10 @@
|
||||
import pytest
|
||||
from importlib import import_module
|
||||
|
||||
def test_config_import():
|
||||
mod = import_module('config')
|
||||
assert mod is not None
|
||||
|
||||
def test_main_import():
|
||||
mod = import_module('main')
|
||||
assert mod is not None
|
||||
@ -0,0 +1,6 @@
|
||||
import pytest
|
||||
from importlib import import_module
|
||||
|
||||
def test_initialize_sentence_transformer_import():
|
||||
mod = import_module('utils.initialize_sentence_transformer')
|
||||
assert mod is not None
|
||||
Loading…
x
Reference in New Issue
Block a user