This commit is contained in:
leehk 2025-04-24 12:58:46 +08:00
parent aaaf0f4242
commit 1ec8df8cec
7 changed files with 1368 additions and 317 deletions

View File

@@ -25,19 +25,29 @@ COPY . $APP_HOME
# install python dependencies
RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple pipenv && rm -rf ~/.cache/pip
RUN pipenv install --deploy --dev
RUN pipenv install --index https://download.pytorch.org/whl/cpu torch
# chown all the files to the app user
# Create cache directory and set permissions
RUN mkdir -p /home/app/.cache/huggingface
RUN chown -R app:app /home/app/.cache/huggingface
RUN chown -R app:app $APP_HOME
# change to the app user
USER app
# pytest
RUN pipenv run pytest tests --disable-warnings
# Run python to initialize download of SentenceTransformer model
RUN pipenv run python utils/initialize_sentence_transformer.py
# # pytest
# RUN pipenv run pytest tests --disable-warnings
# expose the port the app runs on
EXPOSE 80
# run uvicorn
CMD ["pipenv", "run", "uvicorn", "main:app", "--reload", "--workers", "1", "--host", "0.0.0.0", "--port", "80"]

View File

@@ -7,15 +7,19 @@ name = "pypi"
fastapi = "==0.115.9"
starlette = "==0.45.3"
uvicorn = {version = "==0.26.0", extras = ["standard"]}
pydantic-settings = "==2.1.0"
pydantic-settings = "*"
gunicorn = "==21.0.1"
python-decouple = "==3.8"
pyyaml = "==6.0.1"
docker = "*"
chromadb = "*"
sentence-transformers = "*"
langchain = "*"
langchain-deepseek = "*"
docker = "==6.1.3"
chromadb = "==0.6.3"
sentence-transformers = "==3.4.1"
langchain = "==0.3.20"
langgraph = "==0.3.5"
langchain-community = "==0.3.19"
tavily-python = "==0.5.1"
langchain_huggingface = "==0.1.2"
langchain-deepseek = "==0.1.2"
[dev-packages]
httpx = "==0.26.0"

app/backend/Pipfile.lock (generated, 1127 lines changed)

File diff suppressed because it is too large

View File

@@ -1,9 +1,45 @@
import json
import os
import argparse
import shutil
from decouple import config
from typing import List
from typing_extensions import TypedDict
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from langchain_deepseek import ChatDeepSeek
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores.chroma import Chroma
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain.prompts import PromptTemplate, HumanMessagePromptTemplate
from langchain.schema import Document
from pprint import pprint
from langgraph.graph import END, StateGraph, START
from models.adaptive_rag.routing import RouteQuery
from models.adaptive_rag.grading import (
GradeDocuments,
GradeHallucinations,
GradeAnswer,
)
from models.adaptive_rag.query import (
QueryRequest,
QueryResponse,
)
from models.adaptive_rag.prompts_library import (
system_router,
system_retriever_grader,
system_hallucination_grader,
system_answer_grader,
system_question_rewriter,
qa_prompt_template
)
from .utils import ConnectionManager
@@ -17,8 +53,11 @@ os.environ["TAVILY_API_KEY"] = config(
"TAVILY_API_KEY", cast=str, default="tvly-dev-wXXXXXX"
)
# Initialize embedding model (do this ONCE)
embedding_model = HuggingFaceEmbeddings(model_name="paraphrase-multilingual-mpnet-base-v2")
# Initialize the DeepSeek chat model
llm_chat = ChatDeepSeek(
llm = ChatDeepSeek(
model="deepseek-chat",
temperature=0,
max_tokens=None,
@@ -26,10 +65,358 @@ llm_chat = ChatDeepSeek(
max_retries=2,
)
# Load data from ChromaDB
db_folder = "chroma_db"
db_path = os.path.join(os.getcwd(), db_folder)
collection_name = "rag-chroma"
vectorstore = Chroma(persist_directory=db_path, collection_name=collection_name, embedding_function=embedding_model)
retriever = vectorstore.as_retriever()
############################ LLM functions ############################
# Routing to vectorstore or web search
structured_llm_router = llm.with_structured_output(RouteQuery)
# Prompt
route_prompt = ChatPromptTemplate.from_messages(
[
("system", system_router),
("human", "{question}"),
]
)
question_router = route_prompt | structured_llm_router
### Retrieval Grader
structured_llm_grader = llm.with_structured_output(GradeDocuments)
# Prompt
grade_prompt = ChatPromptTemplate.from_messages(
[
("system", system_retriever_grader),
("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
]
)
retrieval_grader = grade_prompt | structured_llm_grader
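# For reference, a minimal sketch of the structured-output schemas assumed above.
# The actual models live in models/adaptive_rag/routing.py and grading.py, which are
# not part of this diff; the field names below are only inferred from how
# route_question() reads source.datasource and grade_documents() reads
# score.binary_score, so treat this as a hypothetical illustration.
#
# from typing import Literal
# from pydantic import BaseModel, Field
#
# class RouteQuery(BaseModel):
#     datasource: Literal["vectorstore", "web_search"] = Field(
#         description="Route the question to the vectorstore or to web search."
#     )
#
# class GradeDocuments(BaseModel):
#     binary_score: Literal["yes", "no"] = Field(
#         description="Whether the retrieved document is relevant to the question."
#     )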
### Generate
# Create a PromptTemplate with the given prompt
new_prompt_template = PromptTemplate(
input_variables=["context", "question"],
template=qa_prompt_template,
)
# Create a new HumanMessagePromptTemplate with the new PromptTemplate
new_human_message_prompt_template = HumanMessagePromptTemplate(
prompt=new_prompt_template
)
prompt_qa = ChatPromptTemplate.from_messages([new_human_message_prompt_template])
# Chain
rag_chain = prompt_qa | llm | StrOutputParser()
### Hallucination Grader
structured_llm_grader = llm.with_structured_output(GradeHallucinations)
# Prompt
hallucination_prompt = ChatPromptTemplate.from_messages(
[
("system", system_hallucination_grader),
("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
]
)
hallucination_grader = hallucination_prompt | structured_llm_grader
### Answer Grader
structured_llm_grader = llm.with_structured_output(GradeAnswer)
# Prompt
answer_prompt = ChatPromptTemplate.from_messages(
[
("system", system_answer_grader),
("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
]
)
answer_grader = answer_prompt | structured_llm_grader
### Question Re-writer
# Prompt
re_write_prompt = ChatPromptTemplate.from_messages(
[
("system", system_question_rewriter),
(
"human",
"Here is the initial question: \n\n {question} \n Formulate an improved question.",
),
]
)
question_rewriter = re_write_prompt | llm | StrOutputParser()
### Search
web_search_tool = TavilySearchResults(k=3)
############### Graph functions ################
def retrieve(state):
"""
Retrieve documents
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, documents, that contains retrieved documents
"""
print("---RETRIEVE---")
question = state["question"]
# Retrieval
documents = retriever.invoke(question)
print(documents)
return {"documents": documents, "question": question}
def generate(state):
"""
Generate answer
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, generation, that contains LLM generation
"""
print("---GENERATE---")
question = state["question"]
documents = state["documents"]
# RAG generation
generation = rag_chain.invoke({"context": documents, "question": question})
return {"documents": documents, "question": question, "generation": generation}
def grade_documents(state):
"""
Determines whether the retrieved documents are relevant to the question.
Args:
state (dict): The current graph state
Returns:
state (dict): Updates documents key with only filtered relevant documents
"""
print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
question = state["question"]
documents = state["documents"]
# Score each doc
filtered_docs = []
for d in documents:
score = retrieval_grader.invoke(
{"question": question, "document": d.page_content}
)
grade = score.binary_score
if grade == "yes":
print("---GRADE: DOCUMENT RELEVANT---")
filtered_docs.append(d)
else:
print("---GRADE: DOCUMENT NOT RELEVANT---")
continue
return {"documents": filtered_docs, "question": question}
def transform_query(state):
"""
Transform the query to produce a better question.
Args:
state (dict): The current graph state
Returns:
state (dict): Updates question key with a re-phrased question
"""
print("---TRANSFORM QUERY---")
question = state["question"]
documents = state["documents"]
# Re-write question
better_question = question_rewriter.invoke({"question": question})
return {"documents": documents, "question": better_question}
def web_search(state):
"""
Web search based on the re-phrased question.
Args:
state (dict): The current graph state
Returns:
state (dict): Updates documents key with appended web results
"""
print("---WEB SEARCH---")
question = state["question"]
# Web search
docs = web_search_tool.invoke({"query": question})
web_results = "\n".join([d["content"] for d in docs])
web_results = Document(page_content=web_results)
return {"documents": web_results, "question": question}
### Edges ###
def route_question(state):
"""
Route question to web search or RAG.
Args:
state (dict): The current graph state
Returns:
str: Next node to call
"""
print("---ROUTE QUESTION---")
question = state["question"]
source = question_router.invoke({"question": question})
if source.datasource == "web_search":
print("---ROUTE QUESTION TO WEB SEARCH---")
return "web_search"
elif source.datasource == "vectorstore":
print("---ROUTE QUESTION TO RAG---")
return "vectorstore"
def decide_to_generate(state):
"""
Determines whether to generate an answer, or re-generate a question.
Args:
state (dict): The current graph state
Returns:
str: Binary decision for next node to call
"""
print("---ASSESS GRADED DOCUMENTS---")
state["question"]
filtered_documents = state["documents"]
if not filtered_documents:
# All documents were filtered out by the relevance check
# We will re-generate a new query
print(
"---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
)
return "transform_query"
else:
# We have relevant documents, so generate answer
print("---DECISION: GENERATE---")
return "generate"
def grade_generation_v_documents_and_question(state):
"""
Determines whether the generation is grounded in the documents and answers the question.
Args:
state (dict): The current graph state
Returns:
str: Decision for next node to call
"""
print("---CHECK HALLUCINATIONS---")
question = state["question"]
documents = state["documents"]
generation = state["generation"]
score = hallucination_grader.invoke(
{"documents": documents, "generation": generation}
)
grade = score.binary_score
# Check hallucination
if grade == "yes":
print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
# Check question-answering
print("---GRADE GENERATION vs QUESTION---")
score = answer_grader.invoke({"question": question, "generation": generation})
grade = score.binary_score
if grade == "yes":
print("---DECISION: GENERATION ADDRESSES QUESTION---")
return "useful"
else:
print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
return "not useful"
else:
pprint("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
return "not supported"
class GraphState(TypedDict):
"""
Represents the state of our graph.
Attributes:
question: question
generation: LLM generation
documents: list of documents
"""
question: str
generation: str
documents: List[str]
workflow = StateGraph(GraphState)
# Define the nodes
workflow.add_node("web_search", web_search) # web search
workflow.add_node("retrieve", retrieve) # retrieve
workflow.add_node("grade_documents", grade_documents) # grade documents
workflow.add_node("generate", generate) # generatae
workflow.add_node("transform_query", transform_query) # transform_query
# Build graph
workflow.add_conditional_edges(
START,
route_question,
{
"web_search": "web_search",
"vectorstore": "retrieve",
},
)
workflow.add_edge("web_search", "generate")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
"grade_documents",
decide_to_generate,
{
"transform_query": "transform_query",
"generate": "generate",
},
)
workflow.add_edge("transform_query", "retrieve")
workflow.add_conditional_edges(
"generate",
grade_generation_v_documents_and_question,
{
"not supported": "generate",
"useful": END,
"not useful": "transform_query",
},
)
# Compile
app = workflow.compile()
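# A hedged smoke-test sketch (not part of this commit): the compiled graph can be
# exercised directly from the command line. It assumes the chroma_db folder exists
# and valid DeepSeek / Tavily API keys are configured; the question text is
# illustrative only.
if __name__ == "__main__":
    sample_inputs = {"question": "What are the first-line treatment options for early-stage breast cancer?"}
    final_state = None
    for step in app.stream(sample_inputs):
        for node_name, node_state in step.items():
            print(f"--- finished node: {node_name} ---")
            final_state = node_state
    if final_state is not None:
        # The last node to run is normally "generate", whose update carries the answer.
        print(final_state.get("generation", "<no generation produced>"))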
# Initialize the connection manager
manager = ConnectionManager()
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await manager.connect(websocket)
@@ -44,7 +431,7 @@ async def websocket_endpoint(websocket: WebSocket):
and len(data_json) > 0
and "content" in data_json[0]
):
async for chunk in llm_chat.astream(data_json[0]["content"]):
async for chunk in llm.astream(data_json[0]["content"]):
await manager.send_personal_message(
json.dumps({"type": "message", "payload": chunk.content}),
websocket,
@@ -56,9 +443,11 @@ async def websocket_endpoint(websocket: WebSocket):
except json.JSONDecodeError:
await manager.broadcast("Invalid JSON message")
except WebSocketDisconnect:
manager.disconnect(websocket)
await manager.broadcast("Client disconnected")
except WebSocketDisconnect:
manager.disconnect(websocket)
await manager.broadcast("Client disconnected")
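# A hedged test-style sketch of talking to this endpoint (assumptions: the router is
# mounted without a prefix so the path is "/ws", and "main:app" from the uvicorn CMD
# is importable; the message shape mirrors data_json[0]["content"] above). Kept
# commented out here since it belongs in a test module rather than this router.
# from fastapi.testclient import TestClient
# from main import app as fastapi_app
#
# def demo_ws_roundtrip():
#     client = TestClient(fastapi_app)
#     with client.websocket_connect("/ws") as ws:
#         ws.send_text(json.dumps([{"content": "What is adjuvant chemotherapy?"}]))
#         reply = json.loads(ws.receive_text())
#         assert reply["type"] == "message"
#         print(reply["payload"])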

View File

@@ -0,0 +1,98 @@
system_router = """You are an expert at routing a user question to a vectorstore or web search.
The vectorstore contains documents related to any cancer/tumor disease. The question may be
asked in a variety of languages, and may be phrased in a variety of ways.
Use the vectorstore for questions on these topics. Otherwise, use web-search.
"""
system_retriever_grader = """You are a grader assessing relevance of a retrieved document to a user question. \n
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
You must read carefully and make sure that the document contains a sentence or chunk of sentences that is exactly related to the question subject, not merely loosely related (e.g. it must be the exact disease or subject in question). \n
The goal is to filter out erroneous retrievals. \n
Return a binary score 'yes' or 'no' to indicate whether the document is relevant to the question."""
system_hallucination_grader = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
system_answer_grader = """You are a grader assessing whether an answer addresses / resolves a question \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question."""
system_question_rewriter = """You a question re-writer that converts an input question to a better version that is optimized \n
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
# prompt for question answering based on retrieved documents
qa_prompt_template = """You are an expert at answering questions based on the following retrieved context.\n
Before answering the question, you must think through, step by step, the general scope to cover when answering it. Do not include this thought process in the answer.\n
Then, given your thought process, you must read the provided context carefully and extract the relevant information.\n
If the question is a medical question, you must answer it in a medical manner and assume that the audience is a junior doctor or a medical student: \n
1. For cancer diseases, you must include comprehensive treatment advice that encompasses multidisciplinary treatment options, including but not limited to surgery, chemotherapy, radiology, internal medicine (drugs), nutrition (e.g. protein intake), etc. You must lay out the treatment options, e.g. first-line, second-line treatment, etc.\n
2. For cancer diseases, do not consider context that is not primary tumor/cancer related, unless the question specifically mentions that it is secondary tumor/cancer related.\n
3. If the question does not state the stage of the cancer disease, you must reply with treatment options for each stage of the cancer disease, if they are available in the provided context. If they are not available in the provided context, give a general answer.\n
You must not use any information that is not present in the provided context to answer the question. Make sure to remove any information that is not present in the provided context.\n
If you don't know the answer, just say that you don't know.\n
Provide the answer in a concise and organized manner. \n
Question: {question} \n
Context: {context} \n
Answer:
"""
# Evaluation
CORRECTNESS_PROMPT = """You are an impartial judge. Evaluate Student Answer against Ground Truth for conceptual similarity and correctness.
You may also be given additional information that was used by the model to generate the output.
Your task is to determine a numerical score called correctness based on the Student Answer and Ground Truth.
A definition of correctness and a grading rubric are provided below.
You must use the grading rubric to determine your score.
Metric definition:
Correctness assesses the degree to which a provided Student Answer aligns with factual accuracy, completeness, logical
consistency, and precise terminology of the Ground Truth. It evaluates the intrinsic validity of the Student Answer, independent of any
external context. A higher score indicates a higher adherence to factual accuracy, completeness, logical consistency,
and precise terminology of the Ground Truth.
Grading rubric:
Correctness: Below are the details for different scores:
- 1: Major factual errors, highly incomplete, illogical, and uses incorrect terminology.
- 2: Significant factual errors, incomplete, noticeable logical flaws, and frequent terminology errors.
- 3: Minor factual errors, somewhat incomplete, minor logical inconsistencies, and occasional terminology errors.
- 4: Few to no factual errors, mostly complete, strong logical consistency, and accurate terminology.
- 5: Accurate, complete, logically consistent, and uses precise terminology.
Reminder:
- Carefully read the Student Answer and Ground Truth
- Check for factual accuracy and completeness of Student Answer compared to the Ground Truth
- Focus on correctness of information rather than style or verbosity
- The goal is to evaluate factual correctness and completeness of the Student Answer.
- Please provide your answer as a single number between 1 and 5. No 'Score:' prefix or other text is allowed.
"""
FAITHFULNESS_PROMPT = """You are an impartial judge. Evaluate output against context for faithfulness.
You may also be given additional information that was used by the model to generate the Output.
Your task is to determine a numerical score called faithfulness based on the output and context.
A definition of faithfulness and a grading rubric are provided below.
You must use the grading rubric to determine your score.
Metric definition:
Faithfulness is only evaluated with the provided output and context. Faithfulness assesses how much of the
provided output is factually consistent with the provided context. A higher score indicates that a higher proportion of
claims present in the output can be derived from the provided context. Faithfulness does not consider how much extra
information from the context is not present in the output.
Grading rubric:
Faithfulness: Below are the details for different scores:
- Score 1: None of the claims in the output can be inferred from the provided context.
- Score 2: Some of the claims in the output can be inferred from the provided context, but the majority of the output is missing from, inconsistent with, or contradictory to the provided context.
- Score 3: Half or more of the claims in the output can be inferred from the provided context.
- Score 4: Most of the claims in the output can be inferred from the provided context, with very little information that is not directly supported by the provided context.
- Score 5: All of the claims in the output are directly supported by the provided context, demonstrating high faithfulness to the provided context.
Reminder:
- Carefully read the output and context
- Focus on the information instead of the writing style or verbosity.
- Please provide your answer as a single number between 1 and 5, according to the grading rubric above. No 'Score:' prefix or other text is allowed.
"""

View File

@@ -0,0 +1,15 @@
from decouple import config
from sentence_transformers import SentenceTransformer
import os
EMBEDDING_MODEL = config("EMBEDDING_MODEL", cast=str, default="paraphrase-multilingual-mpnet-base-v2")
# Initialize embedding model
model = SentenceTransformer(EMBEDDING_MODEL)
# create directory if not exists
if not os.path.exists("./transformer_model"):
os.makedirs("./transformer_model")
# save the model
model.save("./transformer_model/paraphrase-multilingual-mpnet-base-v2")
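# A hedged follow-up sketch: the locally saved copy can be reloaded later without
# hitting the network (the path below assumes the default EMBEDDING_MODEL and matches
# the model.save() call above).
# local_model = SentenceTransformer("./transformer_model/paraphrase-multilingual-mpnet-base-v2")
# sample_vectors = local_model.encode(["What is adjuvant chemotherapy?"])
# print(sample_vectors.shape)  # (1, 768) for this mpnet-base model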

View File

@@ -37,13 +37,13 @@ services:
environment:
LOG_LEVEL: "DEBUG"
tests:
build:
context: ./tests
container_name: tests-aimingmedai
# depends_on:
# - backend
# - frontend
environment:
FRONTEND_URL: http://frontend:80
BACKEND_URL: http://backend:80
# tests:
# build:
# context: ./tests
# container_name: tests-aimingmedai
# # depends_on:
# # - backend
# # - frontend
# environment:
# FRONTEND_URL: http://frontend:80
# BACKEND_URL: http://backend:80