leehk 2025-04-25 15:45:10 +08:00
parent 9a5303e157
commit 3752d85fde
7 changed files with 105 additions and 0 deletions

View File

@@ -0,0 +1,51 @@
import pytest
from fastapi.testclient import TestClient
from unittest.mock import patch, MagicMock
from fastapi import WebSocket
import sys
import types

# Stub langchain and other heavy dependencies so the API module imports quickly
sys.modules['langchain_deepseek'] = MagicMock()
sys.modules['langchain_huggingface'] = MagicMock()
sys.modules['langchain_community.vectorstores.chroma'] = MagicMock()
sys.modules['langchain_community.tools.tavily_search'] = MagicMock()
sys.modules['langchain_core.prompts'] = MagicMock()
sys.modules['langchain_core.output_parsers'] = MagicMock()
sys.modules['langchain.prompts'] = MagicMock()
sys.modules['langchain.schema'] = MagicMock()
sys.modules['langgraph.graph'] = MagicMock()

from api import chatbot

@pytest.fixture
def client():
    from fastapi import FastAPI
    app = FastAPI()
    app.include_router(chatbot.router)
    return TestClient(app)

def test_router_exists():
    assert hasattr(chatbot, 'router')

def test_env_vars_loaded(monkeypatch):
    monkeypatch.setenv('DEEPSEEK_API_KEY', 'dummy')
    monkeypatch.setenv('TAVILY_API_KEY', 'dummy')
    # Re-import to trigger env loading; reaching the assert means the reload did not raise
    import importlib
    importlib.reload(chatbot)
    assert True

def test_websocket_endpoint_accepts(monkeypatch):
    # Patch the module-level ConnectionManager so no real connections are made
    mock_manager = MagicMock()
    monkeypatch.setattr(chatbot, 'manager', mock_manager)
    ws = MagicMock(spec=WebSocket)
    ws.accept = MagicMock()
    # Raise on the first receive so any receive loop would terminate immediately
    ws.receive_text = MagicMock(side_effect=StopIteration)
    # Calling the endpoint should not raise; awaiting it is out of scope for this smoke test
    try:
        coro = chatbot.websocket_endpoint(ws)
        assert hasattr(coro, '__await__')
        coro.close()  # avoid a "coroutine was never awaited" warning
    except Exception as e:
        pytest.fail(f"websocket_endpoint raised: {e}")
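The last test only checks that calling the endpoint returns a coroutine. If the suite later wants to drive the endpoint for real, a sketch like the one below could live in the same module (so the dependency stubs above still apply). It assumes pytest-asyncio is installed and that websocket_endpoint awaits accept(), loops on receive_text(), and catches WebSocketDisconnect, the usual FastAPI pattern; none of that is confirmed by this diff.

from unittest.mock import AsyncMock
from fastapi import WebSocketDisconnect

@pytest.mark.asyncio  # requires the pytest-asyncio plugin (an assumption)
async def test_websocket_endpoint_runs(monkeypatch):
    # AsyncMock makes every manager method awaitable (connect, broadcast, ...)
    monkeypatch.setattr(chatbot, 'manager', AsyncMock())
    ws = MagicMock(spec=WebSocket)
    ws.accept = AsyncMock()
    # The first receive raises a disconnect so the endpoint's receive loop exits
    ws.receive_text = AsyncMock(side_effect=WebSocketDisconnect(code=1000))
    await chatbot.websocket_endpoint(ws)  # should return without raising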

View File

@@ -0,0 +1,14 @@
import pytest
from models.adaptive_rag import grading

def test_grade_documents_class():
    doc = grading.GradeDocuments(binary_score='yes')
    assert doc.binary_score == 'yes'

def test_grade_hallucinations_class():
    doc = grading.GradeHallucinations(binary_score='no')
    assert doc.binary_score == 'no'

def test_grade_answer_class():
    doc = grading.GradeAnswer(binary_score='yes')
    assert doc.binary_score == 'yes'
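These assertions pin down only the class names and the binary_score field. For context, the grading module presumably defines something equivalent to the sketch below; the Field descriptions are illustrative assumptions, not taken from the diff.

# Hypothetical shape of models/adaptive_rag/grading.py implied by the tests above
from pydantic import BaseModel, Field

class GradeDocuments(BaseModel):
    binary_score: str = Field(description="'yes' or 'no': is the retrieved document relevant?")

class GradeHallucinations(BaseModel):
    binary_score: str = Field(description="'yes' or 'no': is the generation grounded in the documents?")

class GradeAnswer(BaseModel):
    binary_score: str = Field(description="'yes' or 'no': does the answer address the question?")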

View File

@@ -0,0 +1,10 @@
import pytest
from models.adaptive_rag import prompts_library

def test_prompts_are_strings():
    assert isinstance(prompts_library.system_router, str)
    assert isinstance(prompts_library.system_retriever_grader, str)
    assert isinstance(prompts_library.system_hallucination_grader, str)
    assert isinstance(prompts_library.system_answer_grader, str)
    assert isinstance(prompts_library.system_question_rewriter, str)
    assert isinstance(prompts_library.qa_prompt_template, str)

View File

@@ -0,0 +1,8 @@
import pytest
from models.adaptive_rag import query

def test_query_request_and_response():
    req = query.QueryRequest(query="What is AI?")
    assert req.query == "What is AI?"
    resp = query.QueryResponse(response="Artificial Intelligence")
    assert resp.response == "Artificial Intelligence"
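The test implies only two single-field models; a minimal sketch of what models/adaptive_rag/query.py likely contains, with the field types assumed to be plain strings:

# Hypothetical shape of models/adaptive_rag/query.py implied by the test above
from pydantic import BaseModel

class QueryRequest(BaseModel):
    query: str

class QueryResponse(BaseModel):
    response: str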

View File

@@ -0,0 +1,6 @@
import pytest
from models.adaptive_rag import routing

def test_route_query_class():
    route = routing.RouteQuery(datasource="vectorstore")
    assert route.datasource == "vectorstore"
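Here the test fixes only the class name, the datasource field, and "vectorstore" as one accepted value; a minimal assumed sketch of models/adaptive_rag/routing.py:

# Hypothetical shape of models/adaptive_rag/routing.py; the full set of routes
# (e.g. any web-search fallback) is not visible in this diff
from pydantic import BaseModel

class RouteQuery(BaseModel):
    datasource: str  # "vectorstore" is the only value exercised by the test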

View File

@@ -0,0 +1,10 @@
import pytest
from importlib import import_module

def test_config_import():
    mod = import_module('config')
    assert mod is not None

def test_main_import():
    mod = import_module('main')
    assert mod is not None
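If config or main imports the same heavy langchain stack that the chatbot tests stub out inline, these import smoke tests would need those stubs installed first. One way to share that setup, sketched here rather than taken from the commit, is a tests/conftest.py mirroring the sys.modules patching above:

# Hypothetical tests/conftest.py; mirrors the stubbing at the top of the chatbot
# test file so that import_module('config') / import_module('main') do not pull
# in the real langchain/langgraph dependencies.
import sys
from unittest.mock import MagicMock

for name in [
    'langchain_deepseek',
    'langchain_huggingface',
    'langchain_community.vectorstores.chroma',
    'langchain_community.tools.tavily_search',
    'langchain_core.prompts',
    'langchain_core.output_parsers',
    'langchain.prompts',
    'langchain.schema',
    'langgraph.graph',
]:
    sys.modules.setdefault(name, MagicMock())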

View File

@@ -0,0 +1,6 @@
import pytest
from importlib import import_module

def test_initialize_sentence_transformer_import():
    mod = import_module('utils.initialize_sentence_transformer')
    assert mod is not None