import json
import os
import argparse
import shutil

from decouple import config
from typing import List
from typing_extensions import TypedDict

from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from langchain_deepseek import ChatDeepSeek
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores.chroma import Chroma
from langchain_community.tools.tavily_search import TavilySearchResults

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain.prompts import PromptTemplate, HumanMessagePromptTemplate

from langchain.schema import Document
from pprint import pprint
from langgraph.graph import END, StateGraph, START

from models.adaptive_rag.routing import RouteQuery
from models.adaptive_rag.grading import (
    GradeDocuments,
    GradeHallucinations,
    GradeAnswer,
)
from models.adaptive_rag.query import (
    QueryRequest,
    QueryResponse,
)

from models.adaptive_rag.prompts_library import (
    system_router,
    system_retriever_grader,
    system_hallucination_grader,
    system_answer_grader,
    system_question_rewriter,
    qa_prompt_template,
)

from .utils import ConnectionManager

router = APIRouter()

# Load environment variables
os.environ["DEEPSEEK_API_KEY"] = config(
    "DEEPSEEK_API_KEY", cast=str, default="sk-XXXXXXXXXX"
)
os.environ["TAVILY_API_KEY"] = config(
    "TAVILY_API_KEY", cast=str, default="tvly-dev-wXXXXXX"
)

# Initialize the embedding model (do this ONCE)
embedding_model = HuggingFaceEmbeddings(
    model_name="paraphrase-multilingual-mpnet-base-v2"
)

# Initialize the DeepSeek chat model
llm = ChatDeepSeek(
    model="deepseek-chat",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
)

# Load the persisted ChromaDB collection and expose it as a retriever
db_folder = "chroma_db"
db_path = os.path.join(os.getcwd(), db_folder)
collection_name = "rag-chroma"
vectorstore = Chroma(
    persist_directory=db_path,
    collection_name=collection_name,
    embedding_function=embedding_model,
)
retriever = vectorstore.as_retriever()


############################ LLM functions ############################
# Routing to vectorstore or web search
structured_llm_router = llm.with_structured_output(RouteQuery)
# Prompt
route_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_router),
        ("human", "{question}"),
    ]
)
question_router = route_prompt | structured_llm_router
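
# Illustrative usage (hypothetical question, not executed at import time): the
# router returns a RouteQuery whose `datasource` field is "vectorstore" or
# "web_search"; route_question below branches on it.
#   question_router.invoke({"question": "What is in the knowledge base?"})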

### Retrieval Grader
# Use a distinct name so the hallucination/answer graders below don't shadow it
structured_llm_doc_grader = llm.with_structured_output(GradeDocuments)
# Prompt
grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_retriever_grader),
        ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
    ]
)
retrieval_grader = grade_prompt | structured_llm_doc_grader
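
# Illustrative usage: invoked once per document in grade_documents below; the
# returned GradeDocuments has a `binary_score` of "yes" or "no".
#   retrieval_grader.invoke({"question": q, "document": doc.page_content})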

### Generate
# Create a PromptTemplate with the given prompt
new_prompt_template = PromptTemplate(
    input_variables=["context", "question"],
    template=qa_prompt_template,
)

# Wrap the PromptTemplate in a HumanMessagePromptTemplate
new_human_message_prompt_template = HumanMessagePromptTemplate(
    prompt=new_prompt_template
)
prompt_qa = ChatPromptTemplate.from_messages([new_human_message_prompt_template])

# Chain
rag_chain = prompt_qa | llm | StrOutputParser()
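
# Illustrative usage: `context` may be a list of Documents, which the prompt
# template renders as text; StrOutputParser makes the result a plain string.
#   rag_chain.invoke({"context": docs, "question": "..."})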


### Hallucination Grader
structured_llm_hallucination_grader = llm.with_structured_output(GradeHallucinations)

# Prompt
hallucination_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_hallucination_grader),
        ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
    ]
)

hallucination_grader = hallucination_prompt | structured_llm_hallucination_grader
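
# Illustrative usage: grades whether a generation is grounded in the supplied
# documents ("yes" in `binary_score` means grounded); consumed by
# grade_generation_v_documents_and_question below.
#   hallucination_grader.invoke({"documents": docs, "generation": answer})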

### Answer Grader
structured_llm_answer_grader = llm.with_structured_output(GradeAnswer)

# Prompt
answer_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_answer_grader),
        ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
    ]
)
answer_grader = answer_prompt | structured_llm_answer_grader
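
# Illustrative usage: grades whether a generation actually addresses the
# question ("yes"/"no" in `binary_score`).
#   answer_grader.invoke({"question": q, "generation": answer})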

### Question Re-writer
# Prompt
re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_question_rewriter),
        (
            "human",
            "Here is the initial question: \n\n {question} \n Formulate an improved question.",
        ),
    ]
)
question_rewriter = re_write_prompt | llm | StrOutputParser()
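
# Illustrative usage: returns the re-phrased question as a plain string; used
# by transform_query below to retry retrieval with a better query.
#   question_rewriter.invoke({"question": "vague original question"})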

### Search
# `max_results` is the documented TavilySearchResults option for limiting
# results; the original `k=3` is not a recognized field
web_search_tool = TavilySearchResults(max_results=3)
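
# Illustrative note: the tool returns a list of dicts, each with "content" and
# "url" keys; web_search below joins the "content" fields into one Document.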


############### Graph functions ################


def retrieve(state):
    """
    Retrieve documents

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE---")
    question = state["question"]

    # Retrieval
    documents = retriever.invoke(question)

    print(documents)
    return {"documents": documents, "question": question}


def generate(state):
    """
    Generate answer

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation, that contains LLM generation
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]

    # RAG generation
    generation = rag_chain.invoke({"context": documents, "question": question})
    return {"documents": documents, "question": question, "generation": generation}


def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with only filtered relevant documents
    """
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]

    # Score each doc, keeping only the relevant ones
    filtered_docs = []
    for d in documents:
        score = retrieval_grader.invoke(
            {"question": question, "document": d.page_content}
        )
        grade = score.binary_score
        if grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
    return {"documents": filtered_docs, "question": question}


def transform_query(state):
    """
    Transform the query to produce a better question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates question key with a re-phrased question
    """
    print("---TRANSFORM QUERY---")
    question = state["question"]
    documents = state["documents"]

    # Re-write question
    better_question = question_rewriter.invoke({"question": question})
    return {"documents": documents, "question": better_question}


def web_search(state):
    """
    Web search based on the re-phrased question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Replaces the documents key with the web search results
    """
    print("---WEB SEARCH---")
    question = state["question"]

    # Web search: join the result snippets into a single Document, wrapped in a
    # list so downstream nodes always see `documents` as a list
    docs = web_search_tool.invoke({"query": question})
    web_results = "\n".join([d["content"] for d in docs])
    documents = [Document(page_content=web_results)]

    return {"documents": documents, "question": question}


### Edges ###
def route_question(state):
    """
    Route question to web search or RAG.

    Args:
        state (dict): The current graph state

    Returns:
        str: Next node to call
    """
    print("---ROUTE QUESTION---")
    question = state["question"]
    source = question_router.invoke({"question": question})
    if source.datasource == "web_search":
        print("---ROUTE QUESTION TO WEB SEARCH---")
        return "web_search"
    elif source.datasource == "vectorstore":
        print("---ROUTE QUESTION TO RAG---")
        return "vectorstore"


def decide_to_generate(state):
    """
    Determines whether to generate an answer, or re-generate a question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Binary decision for next node to call
    """
    print("---ASSESS GRADED DOCUMENTS---")
    filtered_documents = state["documents"]

    if not filtered_documents:
        # All documents were filtered out by the relevance check,
        # so re-write the query and retrieve again
        print(
            "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
        )
        return "transform_query"
    else:
        # We have relevant documents, so generate an answer
        print("---DECISION: GENERATE---")
        return "generate"


def grade_generation_v_documents_and_question(state):
    """
    Determines whether the generation is grounded in the documents and answers the question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Decision for next node to call
    """
    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]

    score = hallucination_grader.invoke(
        {"documents": documents, "generation": generation}
    )
    grade = score.binary_score

    # Check hallucination
    if grade == "yes":
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        # Check question-answering
        print("---GRADE GENERATION vs QUESTION---")
        score = answer_grader.invoke({"question": question, "generation": generation})
        grade = score.binary_score
        if grade == "yes":
            print("---DECISION: GENERATION ADDRESSES QUESTION---")
            return "useful"
        else:
            print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
            return "not useful"
    else:
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "not supported"


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: the current user question
        generation: the LLM generation
        documents: the list of retrieved documents
    """

    question: str
    generation: str
    documents: List[Document]


workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("web_search", web_search)  # web search
workflow.add_node("retrieve", retrieve)  # retrieve
workflow.add_node("grade_documents", grade_documents)  # grade documents
workflow.add_node("generate", generate)  # generate
workflow.add_node("transform_query", transform_query)  # transform query

# Build graph
workflow.add_conditional_edges(
    START,
    route_question,
    {
        "web_search": "web_search",
        "vectorstore": "retrieve",
    },
)
workflow.add_edge("web_search", "generate")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "retrieve")
workflow.add_conditional_edges(
    "generate",
    grade_generation_v_documents_and_question,
    {
        "not supported": "generate",
        "useful": END,
        "not useful": "transform_query",
    },
)

# Compile
app = workflow.compile()
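
# Illustrative usage (not executed at import time): stream the compiled graph
# for one question; each chunk maps the node that just ran to its state update.
#   for chunk in app.stream({"question": "example question"}):
#       pprint(chunk)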

# Initialize the connection manager
manager = ConnectionManager()


@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_text()

            try:
                data_json = json.loads(data)
                if (
                    isinstance(data_json, list)
                    and len(data_json) > 0
                    and "content" in data_json[0]
                ):
                    inputs = {"question": data_json[0]["content"]}
                    # app.astream yields one dict per executed node, mapping
                    # the node name to its state update; chunks carry no
                    # top-level "content" key, so the str(chunk) fallback is
                    # what actually gets sent
                    async for chunk in app.astream(inputs):
                        await manager.send_personal_message(
                            json.dumps(
                                {
                                    "type": "message",
                                    "payload": chunk.get("content", str(chunk)),
                                }
                            ),
                            websocket,
                        )
                    # Send a final 'done' message to signal completion
                    await manager.send_personal_message(
                        json.dumps({"type": "done"}),
                        websocket,
                    )
                else:
                    await manager.send_personal_message(
                        "Invalid message format", websocket
                    )

            except json.JSONDecodeError:
                await manager.broadcast("Invalid JSON message")

    except WebSocketDisconnect:
        manager.disconnect(websocket)
        await manager.broadcast("Client disconnected")
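
# Illustrative client payload (shape inferred from the endpoint above, not from
# a documented schema): a JSON list whose first element carries the question,
# e.g. [{"content": "What does the vectorstore cover?"}]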