mirror of
https://github.com/aimingmed/aimingmed-ai.git
synced 2026-01-19 13:23:23 +08:00
Merge pull request #77 from aimingmed/feature/incorporate-functional-adaptive-rag
Feature/incorporate functional adaptive rag
This commit is contained in:
commit
40cd845ef0
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@ -48,7 +48,7 @@ jobs:
|
||||
image_config:
|
||||
- IMAGE_NAME: backend-aimingmedai
|
||||
BUILD_CONTEXT: ./app/backend
|
||||
DOCKERFILE: ./app/backend/Dockerfile.prod
|
||||
DOCKERFILE: ./app/backend/Dockerfile
|
||||
- IMAGE_NAME: frontend-aimingmedai
|
||||
BUILD_CONTEXT: ./app/frontend
|
||||
DOCKERFILE: ./app/frontend/Dockerfile.test
|
||||
|
||||
2
.github/workflows/template_unit_pytest.yml
vendored
2
.github/workflows/template_unit_pytest.yml
vendored
@ -54,7 +54,7 @@ jobs:
|
||||
ls -al
|
||||
echo "Pipfile content:"
|
||||
cat Pipfile
|
||||
pipenv sync -d
|
||||
pipenv install --dev --skip-lock
|
||||
|
||||
- name: Run tests with pytest
|
||||
working-directory: ${{ env.WORKING_DIR }}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
# pull official base image
|
||||
FROM python:3.11-slim-bullseye
|
||||
FROM python:3.11-slim-bookworm
|
||||
|
||||
# create directory for the app user
|
||||
RUN mkdir -p /home/app
|
||||
@ -18,26 +18,39 @@ ENV PYTHONDONTWRITEBYTECODE=1
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV ENVIRONMENT=dev
|
||||
ENV TESTING=1
|
||||
ENV CUDA_VISIBLE_DEVICES=""
|
||||
|
||||
COPY Pipfile $APP_HOME/
|
||||
RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pipenv && rm -rf ~/.cache/pip
|
||||
RUN pipenv install --deploy --dev --no-cache-dir
|
||||
RUN pipenv run pip install torch --force-reinstall --no-cache-dir
|
||||
|
||||
# remove all cached files not needed to save space
|
||||
RUN pip cache purge
|
||||
RUN rm -rf /root/.cache
|
||||
|
||||
# add app
|
||||
COPY . $APP_HOME
|
||||
|
||||
# install python dependencies
|
||||
RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pipenv && rm -rf ~/.cache/pip
|
||||
RUN pipenv install --deploy --dev
|
||||
# Create cache directory and set permissions
|
||||
RUN mkdir -p /home/app/.cache/huggingface
|
||||
RUN chown -R app:app /home/app/.cache/huggingface
|
||||
|
||||
# chown all the files to the app user
|
||||
RUN chown -R app:app $APP_HOME
|
||||
|
||||
# change to the app user
|
||||
USER app
|
||||
|
||||
# pytest
|
||||
RUN pipenv run pytest tests --disable-warnings
|
||||
# Run python to initialize download of SentenceTransformer model
|
||||
RUN pipenv run python utils/initialize_sentence_transformer.py
|
||||
|
||||
# pytest
|
||||
RUN export DEEPSEEK_API_KEY=sk-XXXXXXXXXX; export TAVILY_API_KEY=tvly-dev-wXXXXXX;\
|
||||
pipenv run pytest tests --disable-warnings
|
||||
|
||||
# expose the port the app runs on
|
||||
EXPOSE 80
|
||||
|
||||
# run uvicorn
|
||||
CMD ["pipenv", "run", "uvicorn", "main:app", "--reload", "--workers", "1", "--host", "0.0.0.0", "--port", "80"]
|
||||
|
||||
|
||||
@ -13,12 +13,23 @@ ENV PYTHONDONTWRITEBYTECODE=1
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV ENVIRONMENT=dev
|
||||
ENV TESTING=1
|
||||
ENV CUDA_VISIBLE_DEVICES=""
|
||||
|
||||
# install python dependencies
|
||||
RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pipenv && rm -rf ~/.cache/pip
|
||||
COPY ./Pipfile .
|
||||
RUN pipenv install --deploy --dev
|
||||
RUN pipenv install --deploy --dev --no-cache-dir
|
||||
RUN pipenv run pip install torch --force-reinstall --no-cache-dir
|
||||
|
||||
# remove all cached files not needed to save space
|
||||
RUN pip cache purge
|
||||
RUN rm -rf /root/.cache
|
||||
|
||||
# Create cache directory and set permissions
|
||||
RUN mkdir -p /home/app/.cache/huggingface
|
||||
RUN chown -R app:app /home/app/.cache/huggingface
|
||||
RUN chown -R app:app $APP_HOME
|
||||
#
|
||||
# add app
|
||||
COPY . /usr/src/app
|
||||
RUN export DEEPSEEK_API_KEY=sk-XXXXXXXXXX; export TAVILY_API_KEY=tvly-dev-wXXXXXX;\
|
||||
|
||||
@ -7,15 +7,20 @@ name = "pypi"
|
||||
fastapi = "==0.115.9"
|
||||
starlette = "==0.45.3"
|
||||
uvicorn = {version = "==0.26.0", extras = ["standard"]}
|
||||
pydantic-settings = "==2.1.0"
|
||||
pydantic-settings = "*"
|
||||
gunicorn = "==21.0.1"
|
||||
python-decouple = "==3.8"
|
||||
pyyaml = "==6.0.1"
|
||||
docker = "*"
|
||||
chromadb = "*"
|
||||
docker = "==6.1.3"
|
||||
chromadb = "==0.6.3"
|
||||
langchain = "==0.3.20"
|
||||
langgraph = "==0.3.5"
|
||||
langchain-community = "==0.3.19"
|
||||
tavily-python = "==0.5.1"
|
||||
langchain_huggingface = "==0.1.2"
|
||||
langchain-deepseek = "==0.1.2"
|
||||
torch = "*"
|
||||
sentence-transformers = "*"
|
||||
langchain = "*"
|
||||
langchain-deepseek = "*"
|
||||
|
||||
[dev-packages]
|
||||
httpx = "==0.26.0"
|
||||
|
||||
1312
app/backend/Pipfile.lock
generated
1312
app/backend/Pipfile.lock
generated
File diff suppressed because it is too large
Load Diff
@ -1,9 +1,45 @@
|
||||
import json
|
||||
import os
|
||||
import argparse
|
||||
import shutil
|
||||
|
||||
from decouple import config
|
||||
from typing import List
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
from langchain_huggingface import HuggingFaceEmbeddings
|
||||
from langchain_community.vectorstores.chroma import Chroma
|
||||
from langchain_community.tools.tavily_search import TavilySearchResults
|
||||
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain.prompts import PromptTemplate, HumanMessagePromptTemplate
|
||||
|
||||
from langchain.schema import Document
|
||||
from pprint import pprint
|
||||
from langgraph.graph import END, StateGraph, START
|
||||
|
||||
from models.adaptive_rag.routing import RouteQuery
|
||||
from models.adaptive_rag.grading import (
|
||||
GradeDocuments,
|
||||
GradeHallucinations,
|
||||
GradeAnswer,
|
||||
)
|
||||
from models.adaptive_rag.query import (
|
||||
QueryRequest,
|
||||
QueryResponse,
|
||||
)
|
||||
|
||||
from models.adaptive_rag.prompts_library import (
|
||||
system_router,
|
||||
system_retriever_grader,
|
||||
system_hallucination_grader,
|
||||
system_answer_grader,
|
||||
system_question_rewriter,
|
||||
qa_prompt_template
|
||||
)
|
||||
|
||||
from .utils import ConnectionManager
|
||||
|
||||
@ -17,8 +53,11 @@ os.environ["TAVILY_API_KEY"] = config(
|
||||
"TAVILY_API_KEY", cast=str, default="tvly-dev-wXXXXXX"
|
||||
)
|
||||
|
||||
# Initialize embedding model (do this ONCE)
|
||||
embedding_model = HuggingFaceEmbeddings(model_name="paraphrase-multilingual-mpnet-base-v2")
|
||||
|
||||
# Initialize the DeepSeek chat model
|
||||
llm_chat = ChatDeepSeek(
|
||||
llm = ChatDeepSeek(
|
||||
model="deepseek-chat",
|
||||
temperature=0,
|
||||
max_tokens=None,
|
||||
@ -26,10 +65,358 @@ llm_chat = ChatDeepSeek(
|
||||
max_retries=2,
|
||||
)
|
||||
|
||||
# Load data from ChromaDB
|
||||
db_folder = "chroma_db"
|
||||
db_path = os.path.join(os.getcwd(), db_folder)
|
||||
collection_name = "rag-chroma"
|
||||
vectorstore = Chroma(persist_directory=db_path, collection_name=collection_name, embedding_function=embedding_model)
|
||||
retriever = vectorstore.as_retriever()
|
||||
|
||||
|
||||
############################ LLM functions ############################
|
||||
# Routing to vectorstore or web search
|
||||
structured_llm_router = llm.with_structured_output(RouteQuery)
|
||||
# Prompt
|
||||
route_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", system_router),
|
||||
("human", "{question}"),
|
||||
]
|
||||
)
|
||||
question_router = route_prompt | structured_llm_router
|
||||
|
||||
### Retrieval Grader
|
||||
structured_llm_grader = llm.with_structured_output(GradeDocuments)
|
||||
# Prompt
|
||||
grade_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", system_retriever_grader),
|
||||
("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
|
||||
]
|
||||
)
|
||||
retrieval_grader = grade_prompt | structured_llm_grader
|
||||
|
||||
### Generate
|
||||
# Create a PromptTemplate with the given prompt
|
||||
new_prompt_template = PromptTemplate(
|
||||
input_variables=["context", "question"],
|
||||
template=qa_prompt_template,
|
||||
)
|
||||
|
||||
# Create a new HumanMessagePromptTemplate with the new PromptTemplate
|
||||
new_human_message_prompt_template = HumanMessagePromptTemplate(
|
||||
prompt=new_prompt_template
|
||||
)
|
||||
prompt_qa = ChatPromptTemplate.from_messages([new_human_message_prompt_template])
|
||||
|
||||
# Chain
|
||||
rag_chain = prompt_qa | llm | StrOutputParser()
|
||||
|
||||
|
||||
### Hallucination Grader
|
||||
structured_llm_grader = llm.with_structured_output(GradeHallucinations)
|
||||
|
||||
# Prompt
|
||||
hallucination_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", system_hallucination_grader),
|
||||
("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
|
||||
]
|
||||
)
|
||||
|
||||
hallucination_grader = hallucination_prompt | structured_llm_grader
|
||||
|
||||
### Answer Grader
|
||||
structured_llm_grader = llm.with_structured_output(GradeAnswer)
|
||||
|
||||
# Prompt
|
||||
answer_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", system_answer_grader),
|
||||
("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
|
||||
]
|
||||
)
|
||||
answer_grader = answer_prompt | structured_llm_grader
|
||||
|
||||
### Question Re-writer
|
||||
# Prompt
|
||||
re_write_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", system_question_rewriter),
|
||||
(
|
||||
"human",
|
||||
"Here is the initial question: \n\n {question} \n Formulate an improved question.",
|
||||
),
|
||||
]
|
||||
)
|
||||
question_rewriter = re_write_prompt | llm | StrOutputParser()
|
||||
|
||||
### Search
|
||||
web_search_tool = TavilySearchResults(k=3)
|
||||
|
||||
############### Graph functions ################
|
||||
|
||||
def retrieve(state):
|
||||
"""
|
||||
Retrieve documents
|
||||
|
||||
Args:
|
||||
state (dict): The current graph state
|
||||
|
||||
Returns:
|
||||
state (dict): New key added to state, documents, that contains retrieved documents
|
||||
"""
|
||||
print("---RETRIEVE---")
|
||||
question = state["question"]
|
||||
|
||||
# Retrieval
|
||||
documents = retriever.invoke(question)
|
||||
|
||||
print(documents)
|
||||
return {"documents": documents, "question": question}
|
||||
|
||||
|
||||
def generate(state):
|
||||
"""
|
||||
Generate answer
|
||||
|
||||
Args:
|
||||
state (dict): The current graph state
|
||||
|
||||
Returns:
|
||||
state (dict): New key added to state, generation, that contains LLM generation
|
||||
"""
|
||||
print("---GENERATE---")
|
||||
question = state["question"]
|
||||
documents = state["documents"]
|
||||
|
||||
# RAG generation
|
||||
generation = rag_chain.invoke({"context": documents, "question": question})
|
||||
return {"documents": documents, "question": question, "generation": generation}
|
||||
|
||||
|
||||
def grade_documents(state):
|
||||
"""
|
||||
Determines whether the retrieved documents are relevant to the question.
|
||||
|
||||
Args:
|
||||
state (dict): The current graph state
|
||||
|
||||
Returns:
|
||||
state (dict): Updates documents key with only filtered relevant documents
|
||||
"""
|
||||
|
||||
print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
|
||||
question = state["question"]
|
||||
documents = state["documents"]
|
||||
|
||||
# Score each doc
|
||||
filtered_docs = []
|
||||
for d in documents:
|
||||
score = retrieval_grader.invoke(
|
||||
{"question": question, "document": d.page_content}
|
||||
)
|
||||
grade = score.binary_score
|
||||
if grade == "yes":
|
||||
print("---GRADE: DOCUMENT RELEVANT---")
|
||||
filtered_docs.append(d)
|
||||
else:
|
||||
print("---GRADE: DOCUMENT NOT RELEVANT---")
|
||||
continue
|
||||
return {"documents": filtered_docs, "question": question}
|
||||
|
||||
|
||||
def transform_query(state):
|
||||
"""
|
||||
Transform the query to produce a better question.
|
||||
|
||||
Args:
|
||||
state (dict): The current graph state
|
||||
|
||||
Returns:
|
||||
state (dict): Updates question key with a re-phrased question
|
||||
"""
|
||||
|
||||
print("---TRANSFORM QUERY---")
|
||||
question = state["question"]
|
||||
documents = state["documents"]
|
||||
|
||||
# Re-write question
|
||||
better_question = question_rewriter.invoke({"question": question})
|
||||
return {"documents": documents, "question": better_question}
|
||||
|
||||
|
||||
def web_search(state):
|
||||
"""
|
||||
Web search based on the re-phrased question.
|
||||
|
||||
Args:
|
||||
state (dict): The current graph state
|
||||
|
||||
Returns:
|
||||
state (dict): Updates documents key with appended web results
|
||||
"""
|
||||
|
||||
print("---WEB SEARCH---")
|
||||
question = state["question"]
|
||||
|
||||
# Web search
|
||||
docs = web_search_tool.invoke({"query": question})
|
||||
web_results = "\n".join([d["content"] for d in docs])
|
||||
web_results = Document(page_content=web_results)
|
||||
|
||||
return {"documents": web_results, "question": question}
|
||||
|
||||
|
||||
### Edges ###
|
||||
def route_question(state):
|
||||
"""
|
||||
Route question to web search or RAG.
|
||||
|
||||
Args:
|
||||
state (dict): The current graph state
|
||||
|
||||
Returns:
|
||||
str: Next node to call
|
||||
"""
|
||||
|
||||
print("---ROUTE QUESTION---")
|
||||
question = state["question"]
|
||||
source = question_router.invoke({"question": question})
|
||||
if source.datasource == "web_search":
|
||||
print("---ROUTE QUESTION TO WEB SEARCH---")
|
||||
return "web_search"
|
||||
elif source.datasource == "vectorstore":
|
||||
print("---ROUTE QUESTION TO RAG---")
|
||||
return "vectorstore"
|
||||
|
||||
|
||||
def decide_to_generate(state):
|
||||
"""
|
||||
Determines whether to generate an answer, or re-generate a question.
|
||||
|
||||
Args:
|
||||
state (dict): The current graph state
|
||||
|
||||
Returns:
|
||||
str: Binary decision for next node to call
|
||||
"""
|
||||
|
||||
print("---ASSESS GRADED DOCUMENTS---")
|
||||
state["question"]
|
||||
filtered_documents = state["documents"]
|
||||
|
||||
if not filtered_documents:
|
||||
# All documents have been filtered check_relevance
|
||||
# We will re-generate a new query
|
||||
print(
|
||||
"---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
|
||||
)
|
||||
return "transform_query"
|
||||
else:
|
||||
# We have relevant documents, so generate answer
|
||||
print("---DECISION: GENERATE---")
|
||||
return "generate"
|
||||
|
||||
|
||||
def grade_generation_v_documents_and_question(state):
|
||||
"""
|
||||
Determines whether the generation is grounded in the document and answers question.
|
||||
|
||||
Args:
|
||||
state (dict): The current graph state
|
||||
|
||||
Returns:
|
||||
str: Decision for next node to call
|
||||
"""
|
||||
|
||||
print("---CHECK HALLUCINATIONS---")
|
||||
question = state["question"]
|
||||
documents = state["documents"]
|
||||
generation = state["generation"]
|
||||
|
||||
score = hallucination_grader.invoke(
|
||||
{"documents": documents, "generation": generation}
|
||||
)
|
||||
grade = score.binary_score
|
||||
|
||||
# Check hallucination
|
||||
if grade == "yes":
|
||||
print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
|
||||
# Check question-answering
|
||||
print("---GRADE GENERATION vs QUESTION---")
|
||||
score = answer_grader.invoke({"question": question, "generation": generation})
|
||||
grade = score.binary_score
|
||||
if grade == "yes":
|
||||
print("---DECISION: GENERATION ADDRESSES QUESTION---")
|
||||
return "useful"
|
||||
else:
|
||||
print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
|
||||
return "not useful"
|
||||
else:
|
||||
pprint("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
|
||||
return "not supported"
|
||||
|
||||
|
||||
class GraphState(TypedDict):
|
||||
"""
|
||||
Represents the state of our graph.
|
||||
|
||||
Attributes:
|
||||
question: question
|
||||
generation: LLM generation
|
||||
documents: list of documents
|
||||
"""
|
||||
|
||||
question: str
|
||||
generation: str
|
||||
documents: List[str]
|
||||
|
||||
workflow = StateGraph(GraphState)
|
||||
|
||||
# Define the nodes
|
||||
workflow.add_node("web_search", web_search) # web search
|
||||
workflow.add_node("retrieve", retrieve) # retrieve
|
||||
workflow.add_node("grade_documents", grade_documents) # grade documents
|
||||
workflow.add_node("generate", generate) # generatae
|
||||
workflow.add_node("transform_query", transform_query) # transform_query
|
||||
|
||||
# Build graph
|
||||
workflow.add_conditional_edges(
|
||||
START,
|
||||
route_question,
|
||||
{
|
||||
"web_search": "web_search",
|
||||
"vectorstore": "retrieve",
|
||||
},
|
||||
)
|
||||
workflow.add_edge("web_search", "generate")
|
||||
workflow.add_edge("retrieve", "grade_documents")
|
||||
workflow.add_conditional_edges(
|
||||
"grade_documents",
|
||||
decide_to_generate,
|
||||
{
|
||||
"transform_query": "transform_query",
|
||||
"generate": "generate",
|
||||
},
|
||||
)
|
||||
workflow.add_edge("transform_query", "retrieve")
|
||||
workflow.add_conditional_edges(
|
||||
"generate",
|
||||
grade_generation_v_documents_and_question,
|
||||
{
|
||||
"not supported": "generate",
|
||||
"useful": END,
|
||||
"not useful": "transform_query",
|
||||
},
|
||||
)
|
||||
|
||||
# Compile
|
||||
app = workflow.compile()
|
||||
|
||||
# Initialize the connection manager
|
||||
manager = ConnectionManager()
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await manager.connect(websocket)
|
||||
@ -44,9 +431,65 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
and len(data_json) > 0
|
||||
and "content" in data_json[0]
|
||||
):
|
||||
async for chunk in llm_chat.astream(data_json[0]["content"]):
|
||||
inputs = {
|
||||
"question": data_json[0]["content"]
|
||||
}
|
||||
async for chunk in app.astream(inputs):
|
||||
# Determine if chunk is intermediate or final
|
||||
if isinstance(chunk, dict):
|
||||
if len(chunk) == 1:
|
||||
step_name = list(chunk.keys())[0]
|
||||
step_value = chunk[step_name]
|
||||
# Check if this step contains the final answer
|
||||
if isinstance(step_value, dict) and 'generation' in step_value:
|
||||
await manager.send_personal_message(
|
||||
json.dumps({"type": "message", "payload": chunk.content}),
|
||||
json.dumps({
|
||||
"type": "final",
|
||||
"title": "Answer",
|
||||
"payload": step_value['generation']
|
||||
}),
|
||||
websocket,
|
||||
)
|
||||
else:
|
||||
await manager.send_personal_message(
|
||||
json.dumps({
|
||||
"type": "intermediate",
|
||||
"title": step_name.replace('_', ' ').title(),
|
||||
"payload": str(step_value)
|
||||
}),
|
||||
websocket,
|
||||
)
|
||||
elif 'generation' in chunk:
|
||||
await manager.send_personal_message(
|
||||
json.dumps({
|
||||
"type": "final",
|
||||
"title": "Answer",
|
||||
"payload": chunk['generation']
|
||||
}),
|
||||
websocket,
|
||||
)
|
||||
else:
|
||||
await manager.send_personal_message(
|
||||
json.dumps({
|
||||
"type": "intermediate",
|
||||
"title": "Step",
|
||||
"payload": str(chunk)
|
||||
}),
|
||||
websocket,
|
||||
)
|
||||
else:
|
||||
# Fallback for non-dict chunks
|
||||
await manager.send_personal_message(
|
||||
json.dumps({
|
||||
"type": "intermediate",
|
||||
"title": "Step",
|
||||
"payload": str(chunk)
|
||||
}),
|
||||
websocket,
|
||||
)
|
||||
# Send a final 'done' message to signal completion
|
||||
await manager.send_personal_message(
|
||||
json.dumps({"type": "done"}),
|
||||
websocket,
|
||||
)
|
||||
else:
|
||||
@ -56,9 +499,11 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
|
||||
except json.JSONDecodeError:
|
||||
await manager.broadcast("Invalid JSON message")
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket)
|
||||
await manager.broadcast("Client disconnected")
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket)
|
||||
await manager.broadcast("Client disconnected")
|
||||
|
||||
@ -26,3 +26,13 @@ app.add_middleware(
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(
|
||||
"main:app",
|
||||
host="0.0.0.0",
|
||||
port=8004,
|
||||
reload=True
|
||||
)
|
||||
39
app/backend/models/adaptive_rag/prompts_library.py
Normal file
39
app/backend/models/adaptive_rag/prompts_library.py
Normal file
@ -0,0 +1,39 @@
|
||||
system_router = """You are an expert at routing a user question to a vectorstore or web search.
|
||||
The vectorstore contains documents related to any cancer/tumor disease. The question may be
|
||||
asked in a variety of languages, and may be phrased in a variety of ways.
|
||||
Use the vectorstore for questions on these topics. Otherwise, use web-search.
|
||||
"""
|
||||
|
||||
system_retriever_grader = """You are a grader assessing relevance of a retrieved document to a user question. \n
|
||||
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
|
||||
You must make sure to read carefully that the document contains a sentence or chunk of sentences that is exactly related but not closely related to the question subject (e.g. must be the exact disease or subject in question). \n
|
||||
The goal is to filter out erroneous retrievals. \n
|
||||
Must return a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
|
||||
|
||||
system_hallucination_grader = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
|
||||
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
|
||||
|
||||
system_answer_grader = """You are a grader assessing whether an answer addresses / resolves a question \n
|
||||
Give a binary score 'yes' or 'no'. Yes' means that the answer resolves the question."""
|
||||
|
||||
system_question_rewriter = """You a question re-writer that converts an input question to a better version that is optimized \n
|
||||
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
|
||||
|
||||
# prompt for question answering based on retrieved documents
|
||||
qa_prompt_template = """You are an expert at answering questions based on the following retrieved context.\n
|
||||
Before answering the question, you must have your own thought process what are the general scopes to cover when answering this question, step-by-step. Do not include this thought process in the answer.\n
|
||||
Then, given your thought process, you must read the provided context carefully and extract the relevant information.\n
|
||||
|
||||
If the question is about medical question, you must answer the question in a medical way and assume that the audience is a junior doctor or a medical student: \n
|
||||
1. For cancer diseases, you must include comprehensive treatment advices that encompasses multidisciplinary treatment options that included but not limited to surgery, chemotherapy, radiology, internal medicine (drugs), nutritional ratio (protein), etc. You must layout out the treatment options like what are the first-line, second-line treatment etc.\n
|
||||
2. For cancer diseases, don't consider context that is not primary tumor/cancer related, unless the question specifically mention it is secondary tumor/cancer related.\n
|
||||
3. If the question didn't state the stage of the cancer disease, you must reply with treatment options for each stage of the cancer disease, if they are availalbe in the provided context. If they are not available in the provided context, give a general one.\n
|
||||
|
||||
You must not use any information that is not present in the provided context to answer the question. Make sure to remove those information not present in the provided context.\n
|
||||
If you don't know the answer, just say that you don't know.\n
|
||||
Provide the answer in a concise and organized manner. \n
|
||||
|
||||
Question: {question} \n
|
||||
Context: {context} \n
|
||||
Answer:
|
||||
"""
|
||||
51
app/backend/tests/api/test_chatbot.py
Normal file
51
app/backend/tests/api/test_chatbot.py
Normal file
@ -0,0 +1,51 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import patch, MagicMock
|
||||
from fastapi import WebSocket
|
||||
import sys
|
||||
import types
|
||||
|
||||
# Patch langchain and other heavy dependencies for import
|
||||
sys.modules['langchain_deepseek'] = MagicMock()
|
||||
sys.modules['langchain_huggingface'] = MagicMock()
|
||||
sys.modules['langchain_community.vectorstores.chroma'] = MagicMock()
|
||||
sys.modules['langchain_community.tools.tavily_search'] = MagicMock()
|
||||
sys.modules['langchain_core.prompts'] = MagicMock()
|
||||
sys.modules['langchain_core.output_parsers'] = MagicMock()
|
||||
sys.modules['langchain.prompts'] = MagicMock()
|
||||
sys.modules['langchain.schema'] = MagicMock()
|
||||
sys.modules['langgraph.graph'] = MagicMock()
|
||||
|
||||
from api import chatbot
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from fastapi import FastAPI
|
||||
app = FastAPI()
|
||||
app.include_router(chatbot.router)
|
||||
return TestClient(app)
|
||||
|
||||
def test_router_exists():
|
||||
assert hasattr(chatbot, 'router')
|
||||
|
||||
def test_env_vars_loaded(monkeypatch):
|
||||
monkeypatch.setenv('DEEPSEEK_API_KEY', 'dummy')
|
||||
monkeypatch.setenv('TAVILY_API_KEY', 'dummy')
|
||||
# Re-import to trigger env loading
|
||||
import importlib
|
||||
importlib.reload(chatbot)
|
||||
assert True
|
||||
|
||||
def test_websocket_endpoint_accepts(monkeypatch):
|
||||
# Patch ConnectionManager
|
||||
mock_manager = MagicMock()
|
||||
monkeypatch.setattr(chatbot, 'manager', mock_manager)
|
||||
ws = MagicMock(spec=WebSocket)
|
||||
ws.receive_text = MagicMock(side_effect=[pytest.raises(StopIteration)])
|
||||
ws.accept = MagicMock()
|
||||
# Should not raise
|
||||
try:
|
||||
coro = chatbot.websocket_endpoint(ws)
|
||||
assert hasattr(coro, '__await__')
|
||||
except Exception as e:
|
||||
pytest.fail(f"websocket_endpoint raised: {e}")
|
||||
14
app/backend/tests/models/adaptive_rag/test_grading.py
Normal file
14
app/backend/tests/models/adaptive_rag/test_grading.py
Normal file
@ -0,0 +1,14 @@
|
||||
import pytest
|
||||
from models.adaptive_rag import grading
|
||||
|
||||
def test_grade_documents_class():
|
||||
doc = grading.GradeDocuments(binary_score='yes')
|
||||
assert doc.binary_score == 'yes'
|
||||
|
||||
def test_grade_hallucinations_class():
|
||||
doc = grading.GradeHallucinations(binary_score='no')
|
||||
assert doc.binary_score == 'no'
|
||||
|
||||
def test_grade_answer_class():
|
||||
doc = grading.GradeAnswer(binary_score='yes')
|
||||
assert doc.binary_score == 'yes'
|
||||
@ -0,0 +1,10 @@
|
||||
import pytest
|
||||
from models.adaptive_rag import prompts_library
|
||||
|
||||
def test_prompts_are_strings():
|
||||
assert isinstance(prompts_library.system_router, str)
|
||||
assert isinstance(prompts_library.system_retriever_grader, str)
|
||||
assert isinstance(prompts_library.system_hallucination_grader, str)
|
||||
assert isinstance(prompts_library.system_answer_grader, str)
|
||||
assert isinstance(prompts_library.system_question_rewriter, str)
|
||||
assert isinstance(prompts_library.qa_prompt_template, str)
|
||||
8
app/backend/tests/models/adaptive_rag/test_query.py
Normal file
8
app/backend/tests/models/adaptive_rag/test_query.py
Normal file
@ -0,0 +1,8 @@
|
||||
import pytest
|
||||
from models.adaptive_rag import query
|
||||
|
||||
def test_query_request_and_response():
|
||||
req = query.QueryRequest(query="What is AI?")
|
||||
assert req.query == "What is AI?"
|
||||
resp = query.QueryResponse(response="Artificial Intelligence")
|
||||
assert resp.response == "Artificial Intelligence"
|
||||
6
app/backend/tests/models/adaptive_rag/test_routing.py
Normal file
6
app/backend/tests/models/adaptive_rag/test_routing.py
Normal file
@ -0,0 +1,6 @@
|
||||
import pytest
|
||||
from models.adaptive_rag import routing
|
||||
|
||||
def test_route_query_class():
|
||||
route = routing.RouteQuery(datasource="vectorstore")
|
||||
assert route.datasource == "vectorstore"
|
||||
10
app/backend/tests/test_config_and_main.py
Normal file
10
app/backend/tests/test_config_and_main.py
Normal file
@ -0,0 +1,10 @@
|
||||
import pytest
|
||||
from importlib import import_module
|
||||
|
||||
def test_config_import():
|
||||
mod = import_module('config')
|
||||
assert mod is not None
|
||||
|
||||
def test_main_import():
|
||||
mod = import_module('main')
|
||||
assert mod is not None
|
||||
@ -0,0 +1,6 @@
|
||||
import pytest
|
||||
from importlib import import_module
|
||||
|
||||
def test_initialize_sentence_transformer_import():
|
||||
mod = import_module('utils.initialize_sentence_transformer')
|
||||
assert mod is not None
|
||||
15
app/backend/utils/initialize_sentence_transformer.py
Normal file
15
app/backend/utils/initialize_sentence_transformer.py
Normal file
@ -0,0 +1,15 @@
|
||||
from decouple import config
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import os
|
||||
|
||||
EMBEDDING_MODEL = config("EMBEDDING_MODEL", cast=str, default="paraphrase-multilingual-mpnet-base-v2")
|
||||
|
||||
# Initialize embedding model
|
||||
model = SentenceTransformer(EMBEDDING_MODEL, device="cpu")
|
||||
|
||||
# create directory if not exists
|
||||
if not os.path.exists("./transformer_model"):
|
||||
os.makedirs("./transformer_model")
|
||||
|
||||
# save the model
|
||||
model.save("./transformer_model/paraphrase-multilingual-mpnet-base-v2")
|
||||
@ -37,13 +37,13 @@ services:
|
||||
environment:
|
||||
LOG_LEVEL: "DEBUG"
|
||||
|
||||
tests:
|
||||
build:
|
||||
context: ./tests
|
||||
container_name: tests-aimingmedai
|
||||
# depends_on:
|
||||
# - backend
|
||||
# - frontend
|
||||
environment:
|
||||
FRONTEND_URL: http://frontend:80
|
||||
BACKEND_URL: http://backend:80
|
||||
# tests:
|
||||
# build:
|
||||
# context: ./tests
|
||||
# container_name: tests-aimingmedai
|
||||
# # depends_on:
|
||||
# # - backend
|
||||
# # - frontend
|
||||
# environment:
|
||||
# FRONTEND_URL: http://frontend:80
|
||||
# BACKEND_URL: http://backend:80
|
||||
|
||||
@ -12,7 +12,7 @@ ARG ENV_FILE=.env.test
|
||||
COPY ${ENV_FILE} /usr/src/app/.env
|
||||
|
||||
# Copy dependency files and install dependencies
|
||||
RUN npm install && npm i --save-dev @types/jest
|
||||
RUN npm install && npm install --save-dev @types/jest
|
||||
|
||||
EXPOSE 80
|
||||
CMD [ "npm", "run", "dev", "--", "--host", "0.0.0.0", "--port", "80" ]
|
||||
1877
app/frontend/package-lock.json
generated
1877
app/frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@ -12,9 +12,12 @@
|
||||
"test:run": "vitest run"
|
||||
},
|
||||
"dependencies": {
|
||||
"@tailwindcss/typography": "^0.5.16",
|
||||
"daisyui": "^5.0.17",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.0.0"
|
||||
"react-dom": "^19.0.0",
|
||||
"react-markdown": "^10.1.0",
|
||||
"remark-gfm": "^4.0.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/js": "^9.21.0",
|
||||
|
||||
@ -1,19 +1,31 @@
|
||||
import React, { useState, useEffect, useRef } from 'react';
|
||||
import ReactMarkdown from 'react-markdown';
|
||||
import remarkGfm from 'remark-gfm';
|
||||
|
||||
const BASE_DOMAIN_NAME_PORT = import.meta.env.REACT_APP_DOMAIN_NAME_PORT || 'localhost:8004';
|
||||
|
||||
|
||||
interface Message {
|
||||
sender: 'user' | 'bot';
|
||||
text: string;
|
||||
}
|
||||
|
||||
interface ChatTurn {
|
||||
question: string;
|
||||
intermediateMessages: { title: string; payload: string }[];
|
||||
finalAnswer: string | null;
|
||||
isLoading: boolean;
|
||||
showIntermediate: boolean;
|
||||
}
|
||||
|
||||
const App: React.FC = () => {
|
||||
const [messages, setMessages] = useState<Message[]>([]);
|
||||
const [chatTurns, setChatTurns] = useState<ChatTurn[]>([]);
|
||||
const [newMessage, setNewMessage] = useState('');
|
||||
const [socket, setSocket] = useState<WebSocket | null>(null);
|
||||
const mounted = useRef(false);
|
||||
|
||||
// Disable input/button if any job is running
|
||||
const isJobRunning = chatTurns.some(turn => turn.isLoading);
|
||||
|
||||
useEffect(() => {
|
||||
mounted.current = true;
|
||||
const ws = new WebSocket(`ws://${BASE_DOMAIN_NAME_PORT}/ws`);
|
||||
@ -24,18 +36,40 @@ const App: React.FC = () => {
|
||||
ws.onmessage = (event) => {
|
||||
try {
|
||||
const data = JSON.parse(event.data);
|
||||
if (data.type === 'message' && data.payload && mounted.current) {
|
||||
setMessages((prevMessages) => {
|
||||
const lastMessage = prevMessages[prevMessages.length - 1];
|
||||
if (lastMessage && lastMessage.sender === 'bot') {
|
||||
return [...prevMessages.slice(0, -1), { ...lastMessage, text: lastMessage.text + data.payload }];
|
||||
} else {
|
||||
return [...prevMessages, { sender: 'bot', text: data.payload }];
|
||||
setChatTurns((prevTurns) => {
|
||||
if (prevTurns.length === 0) return prevTurns;
|
||||
const lastTurn = prevTurns[prevTurns.length - 1];
|
||||
if (data.type === 'intermediate') {
|
||||
// Add intermediate message to the last turn
|
||||
const updatedTurn = {
|
||||
...lastTurn,
|
||||
intermediateMessages: [...lastTurn.intermediateMessages, { title: data.title, payload: data.payload }],
|
||||
};
|
||||
return [...prevTurns.slice(0, -1), updatedTurn];
|
||||
} else if (data.type === 'final') {
|
||||
// Set final answer for the last turn
|
||||
const updatedTurn = {
|
||||
...lastTurn,
|
||||
finalAnswer: data.payload,
|
||||
};
|
||||
return [...prevTurns.slice(0, -1), updatedTurn];
|
||||
} else if (data.type === 'done') {
|
||||
// Mark last turn as not loading
|
||||
const updatedTurn = {
|
||||
...lastTurn,
|
||||
isLoading: false,
|
||||
};
|
||||
return [...prevTurns.slice(0, -1), updatedTurn];
|
||||
} else if (data.type === 'message' && data.payload && mounted.current) {
|
||||
// legacy support, treat as final
|
||||
const updatedTurn = {
|
||||
...lastTurn,
|
||||
finalAnswer: (lastTurn.finalAnswer || '') + data.payload,
|
||||
};
|
||||
return [...prevTurns.slice(0, -1), updatedTurn];
|
||||
}
|
||||
return prevTurns;
|
||||
});
|
||||
} else {
|
||||
console.error('Unexpected message format:', data);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error parsing message:', error);
|
||||
}
|
||||
@ -54,23 +88,92 @@ const App: React.FC = () => {
|
||||
|
||||
const sendMessage = () => {
|
||||
if (newMessage.trim() !== '') {
|
||||
setChatTurns((prev) => [
|
||||
...prev,
|
||||
{
|
||||
question: newMessage,
|
||||
intermediateMessages: [],
|
||||
finalAnswer: null,
|
||||
isLoading: true,
|
||||
showIntermediate: false,
|
||||
},
|
||||
]);
|
||||
const message = [{ role: 'user', content: newMessage }];
|
||||
setMessages((prevMessages) => [...prevMessages, { sender: 'user', text: newMessage }]);
|
||||
socket?.send(JSON.stringify(message));
|
||||
setNewMessage('');
|
||||
}
|
||||
};
|
||||
|
||||
const toggleShowIntermediate = (idx: number) => {
|
||||
setChatTurns((prev) => prev.map((turn, i) => i === idx ? { ...turn, showIntermediate: !turn.showIntermediate } : turn));
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex flex-col h-screen bg-gray-100">
|
||||
<div className="p-4">
|
||||
<h1 className="text-3xl font-bold text-center text-gray-800">Simple Chatbot</h1>
|
||||
</div>
|
||||
<div className="flex-grow overflow-y-auto p-4">
|
||||
{messages.map((msg, index) => (
|
||||
<div key={index} className={`p-4 rounded-lg mb-2 ${msg.sender === 'user' ? 'bg-blue-100 text-blue-800' : 'bg-gray-200 text-gray-800'}`}>
|
||||
{msg.text}
|
||||
{chatTurns.map((turn, idx) => (
|
||||
<React.Fragment key={idx}>
|
||||
{/* User question */}
|
||||
<div className="p-4 rounded-lg mb-2 bg-blue-100 text-blue-800">{turn.question}</div>
|
||||
{/* Status box for this question */}
|
||||
{turn.intermediateMessages.length > 0 && (
|
||||
<div className="mb-4">
|
||||
<div className="bg-blue-50 border border-blue-300 rounded-lg p-3 shadow-sm flex items-center">
|
||||
{/* Spinner icon */}
|
||||
{turn.isLoading && (
|
||||
<svg className="animate-spin h-5 w-5 text-blue-500 mr-2" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
|
||||
<circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4"></circle>
|
||||
<path className="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8v8z"></path>
|
||||
</svg>
|
||||
)}
|
||||
<span className="font-semibold text-blue-700 mr-2">Working on:</span>
|
||||
{/* Key steps summary */}
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{turn.intermediateMessages.map((msg, i) => (
|
||||
<span key={i} className="bg-blue-100 text-blue-700 px-2 py-1 rounded text-xs font-medium border border-blue-200">
|
||||
{msg.title}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
<button
|
||||
className="ml-auto text-xs text-blue-600 flex items-center gap-1 px-2 py-1 rounded hover:bg-blue-100 focus:outline-none border border-transparent focus:border-blue-300 transition"
|
||||
onClick={() => toggleShowIntermediate(idx)}
|
||||
aria-expanded={turn.showIntermediate}
|
||||
title={turn.showIntermediate ? 'Hide details' : 'Show details'}
|
||||
>
|
||||
<svg
|
||||
className={`w-4 h-4 transition-transform duration-200 ${turn.showIntermediate ? 'rotate-180' : ''}`}
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
viewBox="0 0 24 24"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
>
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth="2" d="M19 9l-7 7-7-7" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
{/* Expanded details */}
|
||||
{turn.showIntermediate && (
|
||||
<div className="bg-white border border-blue-200 rounded-b-lg p-3 mt-1 text-xs max-h-64 overflow-y-auto">
|
||||
{turn.intermediateMessages.map((msg, i) => (
|
||||
<div key={i} className="mb-3">
|
||||
<div className="font-bold text-blue-700 mb-1">{msg.title}</div>
|
||||
<pre className="whitespace-pre-wrap break-words text-gray-800">{msg.payload}</pre>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
{/* Final answer for this question */}
|
||||
{turn.finalAnswer && (
|
||||
<div className="prose p-4 rounded-lg mb-2 bg-gray-200 text-gray-800">
|
||||
<ReactMarkdown remarkPlugins={[remarkGfm]}>{turn.finalAnswer}</ReactMarkdown> </div>
|
||||
)}
|
||||
</React.Fragment>
|
||||
))}
|
||||
</div>
|
||||
<div className="p-4 border-t border-gray-300">
|
||||
@ -80,8 +183,13 @@ const App: React.FC = () => {
|
||||
value={newMessage}
|
||||
onChange={(e) => setNewMessage(e.target.value)}
|
||||
className="flex-grow p-2 border border-gray-300 rounded-lg mr-2"
|
||||
disabled={isJobRunning}
|
||||
/>
|
||||
<button onClick={sendMessage} className="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded-lg">
|
||||
<button
|
||||
onClick={sendMessage}
|
||||
className="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded-lg"
|
||||
disabled={isJobRunning}
|
||||
>
|
||||
Send
|
||||
</button>
|
||||
</div>
|
||||
|
||||
@ -6,6 +6,9 @@ export default {
|
||||
theme: {
|
||||
extend: {},
|
||||
},
|
||||
plugins: [require("daisyui")],
|
||||
plugins: [
|
||||
require('@tailwindcss/typography'),
|
||||
require("daisyui"),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ export default defineConfig({
|
||||
server: {
|
||||
host: true,
|
||||
strictPort: true,
|
||||
port:
|
||||
port: 8004
|
||||
},
|
||||
test: {
|
||||
globals: true,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user