refactored adaptive RAG

This commit is contained in:
leehk 2025-03-12 17:59:25 +08:00
parent 8b68c60249
commit 486a79a2cc
8 changed files with 207 additions and 82 deletions

View File

@@ -13,4 +13,4 @@ prompt_engineering:
  run_id_chromadb: None
  chat_model_provider: gemini
  query: "如何治疗乳腺癌?"
+ query_evaluation_dataset_csv_path: "../../../../data/qa_datasets.csv"

View File

@@ -156,6 +156,7 @@ def go(config: DictConfig):
        "main",
        parameters={
            "query": config["prompt_engineering"]["query"],
+           "query_evaluation_dataset_csv_path": config["prompt_engineering"]["query_evaluation_dataset_csv_path"],
            "input_chromadb_artifact": f'runs:/{run_id}/chromadb/chroma_db.zip',
            "embedding_model": config["etl"]["embedding_model"],
            "chat_model_provider": config["prompt_engineering"]["chat_model_provider"]

View File

@@ -9,6 +9,10 @@ entry_points:
        description: Query to run
        type: string
+     query_evaluation_dataset_csv_path:
+       description: query evaluation dataset csv path
+       type: string
      input_chromadb_artifact:
        description: Fully-qualified name for the input artifact
        type: string
@@ -21,9 +25,9 @@ entry_points:
        description: Fully-qualified name for the chat model provider
        type: string
      command: >-
        python run.py --query {query} \
+       --query_evaluation_dataset_csv_path {query_evaluation_dataset_csv_path} \
        --input_chromadb_artifact {input_chromadb_artifact} \
        --embedding_model {embedding_model} \
        --chat_model_provider {chat_model_provider}
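
With the parameter declared here, the step can also be launched on its own through the MLflow projects API, mirroring how main.py passes the same parameters. A minimal sketch; the artifact URI, embedding model name and run location below are placeholders, not values from this commit:

import mlflow

# Launch the prompt_engineering step directly; main.py normally fills these
# values from the Hydra config instead of hard-coding them.
mlflow.projects.run(
    uri=".",  # directory containing this MLproject file
    entry_point="main",
    parameters={
        "query": "如何治疗乳腺癌?",
        "query_evaluation_dataset_csv_path": "../../../../data/qa_datasets.csv",
        "input_chromadb_artifact": "runs:/<run_id>/chromadb/chroma_db.zip",  # placeholder run id
        "embedding_model": "<embedding-model-name>",  # placeholder; comes from config["etl"]["embedding_model"]
        "chat_model_provider": "gemini",
    },
)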

View File

@@ -0,0 +1,32 @@
from typing import Literal, List
from pydantic import BaseModel, Field


class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""
    datasource: Literal["vectorstore", "web_search"] = Field(
        ...,
        description="Given a user question choose to route it to web search or a vectorstore.",
    )


class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""
    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )


class GradeHallucinations(BaseModel):
    """Binary score for hallucination present in generation answer."""
    binary_score: str = Field(
        description="Answer is grounded in the facts, 'yes' or 'no'"
    )


class GradeAnswer(BaseModel):
    """Binary score to assess answer addresses question."""
    binary_score: str = Field(
        description="Answer addresses the question, 'yes' or 'no'"
    )
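
These schemas are bound to the chat model via LangChain structured output in run.py. A minimal routing sketch, assuming Gemini is the provider (any of the models wired up in run.py would work) and reusing system_router from prompts_library.py:

from decouple import config
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI

from data_models import RouteQuery
from prompts_library import system_router

# Same model/key setup as evaluators.py; picked here only for brevity
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=config("GOOGLE_API_KEY", cast=str))

structured_llm_router = llm.with_structured_output(RouteQuery)
route_prompt = ChatPromptTemplate.from_messages(
    [("system", system_router), ("human", "{question}")]
)
question_router = route_prompt | structured_llm_router

# Returns a validated RouteQuery; datasource is either "vectorstore" or "web_search"
route = question_router.invoke({"question": "如何治疗乳腺癌?"})
print(route.datasource)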

View File

@@ -0,0 +1,99 @@
from decouple import config
from openevals.llm import create_llm_as_judge
from openevals.prompts import (
    CORRECTNESS_PROMPT,
    CONCISENESS_PROMPT,
    HALLUCINATION_PROMPT
)
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_deepseek import ChatDeepSeek
from langchain_community.llms.moonshot import Moonshot

GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
DEEKSEEK_API_KEY = config("DEEKSEEK_API_KEY", cast=str)
MOONSHOT_API_KEY = config("MOONSHOT_API_KEY", cast=str)

# correctness
gemini_evaluator_correctness = create_llm_as_judge(
    prompt=CORRECTNESS_PROMPT,
    judge=ChatGoogleGenerativeAI(
        model="gemini-1.5-flash",
        google_api_key=GEMINI_API_KEY,
        temperature=0.5,
    ),
)
deepseek_evaluator_correctness = create_llm_as_judge(
    prompt=CORRECTNESS_PROMPT,
    judge=ChatDeepSeek(
        model="deepseek-chat",
        temperature=0.5,
        api_key=DEEKSEEK_API_KEY
    ),
)
moonshot_evaluator_correctness = create_llm_as_judge(
    prompt=CORRECTNESS_PROMPT,
    judge=Moonshot(
        model="moonshot-v1-128k",
        temperature=0.5,
        api_key=MOONSHOT_API_KEY
    ),
)

# conciseness
gemini_evaluator_conciseness = create_llm_as_judge(
    prompt=CONCISENESS_PROMPT,
    judge=ChatGoogleGenerativeAI(
        model="gemini-1.5-flash",
        google_api_key=GEMINI_API_KEY,
        temperature=0.5,
    ),
)
deepseek_evaluator_conciseness = create_llm_as_judge(
    prompt=CONCISENESS_PROMPT,
    judge=ChatDeepSeek(
        model="deepseek-chat",
        temperature=0.5,
        api_key=DEEKSEEK_API_KEY
    ),
)
moonshot_evaluator_conciseness = create_llm_as_judge(
    prompt=CONCISENESS_PROMPT,
    judge=Moonshot(
        model="moonshot-v1-128k",
        temperature=0.5,
        api_key=MOONSHOT_API_KEY
    ),
)

# hallucination
gemini_evaluator_hallucination = create_llm_as_judge(
    prompt=HALLUCINATION_PROMPT,
    judge=ChatGoogleGenerativeAI(
        model="gemini-1.5-flash",
        google_api_key=GEMINI_API_KEY,
        temperature=0.5,
    ),
)
deepseek_evaluator_hallucination = create_llm_as_judge(
    prompt=HALLUCINATION_PROMPT,
    judge=ChatDeepSeek(
        model="deepseek-chat",
        temperature=0.5,
        api_key=DEEKSEEK_API_KEY
    ),
)
moonshot_evaluator_hallucination = create_llm_as_judge(
    prompt=HALLUCINATION_PROMPT,
    judge=Moonshot(
        model="moonshot-v1-128k",
        temperature=0.5,
        api_key=MOONSHOT_API_KEY
    ),
)
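
Each of these objects is a callable LLM-as-judge evaluator. A usage sketch, assuming the openevals convention of passing inputs, outputs and reference_outputs as keyword arguments; the answer strings below are placeholders, and the sample question is the one from config.yaml:

from evaluators import gemini_evaluator_correctness

# Placeholders; in practice these come from the RAG graph and the QA dataset
question = "如何治疗乳腺癌?"
generated_answer = "..."   # answer produced by the adaptive RAG graph
reference_answer = "..."   # gold answer from the evaluation dataset

result = gemini_evaluator_correctness(
    inputs=question,
    outputs=generated_answer,
    reference_outputs=reference_answer,
)
# openevals judges typically return a dict-like result with a score and a comment
print(result["score"], result["comment"])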

View File

@@ -0,0 +1,19 @@
system_router = """You are an expert at routing a user question to a vectorstore or web search.
The vectorstore contains documents related to any cancer/tumor disease. The question may be
asked in a variety of languages, and may be phrased in a variety of ways.
Use the vectorstore for questions on these topics. Otherwise, use web-search.
"""

system_retriever_grader = """You are a grader assessing relevance of a retrieved document to a user question. \n
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""

system_hallucination_grader = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""

system_answer_grader = """You are a grader assessing whether an answer addresses / resolves a question. \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question."""

system_question_rewriter = """You are a question re-writer that converts an input question to a better version that is optimized \n
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
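
In run.py each of these system prompts is paired with its matching schema from data_models.py. A sketch of the retrieval-grader pairing; the helper name is hypothetical (run.py builds the same chain inline inside go()):

from langchain_core.prompts import ChatPromptTemplate

from data_models import GradeDocuments
from prompts_library import system_retriever_grader


def is_relevant(llm, document_text: str, question: str) -> bool:
    """Grade one retrieved document; mirrors the retrieval grader chain in run.py."""
    structured_llm_grader = llm.with_structured_output(GradeDocuments)
    grade_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_retriever_grader),
            ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
        ]
    )
    retrieval_grader = grade_prompt | structured_llm_grader
    grade = retrieval_grader.invoke({"document": document_text, "question": question})
    # binary_score is the 'yes'/'no' field declared on GradeDocuments;
    # the graph keeps only documents graded 'yes'
    return grade.binary_score == "yes"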

View File

@@ -24,6 +24,7 @@ build_dependencies:
  - tavily-python
  - langchain_huggingface
  - pydantic
+ - openevals
# Dependencies required to run the project.
dependencies:
  - mlflow==2.8.1

View File

@@ -8,19 +8,43 @@ from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_deepseek import ChatDeepSeek
from langchain_community.llms.moonshot import Moonshot
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores.chroma import Chroma
-from typing import Literal, List
+from typing import List
from typing_extensions import TypedDict
from langchain_core.prompts import ChatPromptTemplate
-from pydantic import BaseModel, Field
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain.schema import Document
from pprint import pprint
from langgraph.graph import END, StateGraph, START
+from langsmith import Client
+from data_models import (
+    RouteQuery,
+    GradeDocuments,
+    GradeHallucinations,
+    GradeAnswer
+)
+from prompts_library import (
+    system_router,
+    system_retriever_grader,
+    system_hallucination_grader,
+    system_answer_grader,
+    system_question_rewriter
+)
+from evaluators import (
+    gemini_evaluator_correctness,
+    deepseek_evaluator_correctness,
+    moonshot_evaluator_correctness,
+    gemini_evaluator_conciseness,
+    deepseek_evaluator_conciseness,
+    moonshot_evaluator_conciseness,
+    gemini_evaluator_hallucination,
+    deepseek_evaluator_hallucination,
+    moonshot_evaluator_hallucination
+)

logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()
@@ -98,61 +122,32 @@ def go(args):
    vectorstore = Chroma(persist_directory=db_path, collection_name=collection_name, embedding_function=embedding_model)
    retriever = vectorstore.as_retriever()

-    # Data model
-    class RouteQuery(BaseModel):
-        """Route a user query to the most relevant datasource."""
-        datasource: Literal["vectorstore", "web_search"] = Field(
-            ...,
-            description="Given a user question choose to route it to web search or a vectorstore.",
-        )
+    ##########################################
+    # Routing to vectorstore or web search
    structured_llm_router = llm.with_structured_output(RouteQuery)

    # Prompt
-    system = """You are an expert at routing a user question to a vectorstore or web search.
-    The vectorstore contains documents related to any cancer/tumor disease. The question may be
-    asked in a variety of languages, and may be phrased in a variety of ways.
-    Use the vectorstore for questions on these topics. Otherwise, use web-search.
-    """
    route_prompt = ChatPromptTemplate.from_messages(
        [
-            ("system", system),
+            ("system", system_router),
            ("human", "{question}"),
        ]
    )
    question_router = route_prompt | structured_llm_router

+    ##########################################
    ### Retrieval Grader
-    # Data model
-    class GradeDocuments(BaseModel):
-        """Binary score for relevance check on retrieved documents."""
-        binary_score: str = Field(
-            description="Documents are relevant to the question, 'yes' or 'no'"
-        )
    structured_llm_grader = llm.with_structured_output(GradeDocuments)

    # Prompt
-    system = """You are a grader assessing relevance of a retrieved document to a user question. \n
-        If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
-        It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
-        Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
    grade_prompt = ChatPromptTemplate.from_messages(
        [
-            ("system", system),
+            ("system", system_retriever_grader),
            ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
        ]
    )
    retrieval_grader = grade_prompt | structured_llm_grader

+    ##########################################
    ### Generate
    from langchain import hub
    from langchain_core.output_parsers import StrOutputParser
@@ -167,76 +162,45 @@ def go(args):
    rag_chain = prompt | llm | StrOutputParser()

+    ##########################################
    ### Hallucination Grader
-    # Data model
-    class GradeHallucinations(BaseModel):
-        """Binary score for hallucination present in generation answer."""
-        binary_score: str = Field(
-            description="Answer is grounded in the facts, 'yes' or 'no'"
-        )
-    # LLM with function call
    structured_llm_grader = llm.with_structured_output(GradeHallucinations)

    # Prompt
-    system = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
-        Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
    hallucination_prompt = ChatPromptTemplate.from_messages(
        [
-            ("system", system),
+            ("system", system_hallucination_grader),
            ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
        ]
    )
    hallucination_grader = hallucination_prompt | structured_llm_grader

+    ##########################################
    ### Answer Grader
-    # Data model
-    class GradeAnswer(BaseModel):
-        """Binary score to assess answer addresses question."""
-        binary_score: str = Field(
-            description="Answer addresses the question, 'yes' or 'no'"
-        )
-    # LLM with function call
    structured_llm_grader = llm.with_structured_output(GradeAnswer)

    # Prompt
-    system = """You are a grader assessing whether an answer addresses / resolves a question \n
-        Give a binary score 'yes' or 'no'. Yes' means that the answer resolves the question."""
    answer_prompt = ChatPromptTemplate.from_messages(
        [
-            ("system", system),
+            ("system", system_answer_grader),
            ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
        ]
    )
    answer_grader = answer_prompt | structured_llm_grader

+    ##########################################
    ### Question Re-writer
-    # LLM

    # Prompt
-    system = """You a question re-writer that converts an input question to a better version that is optimized \n
-    for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
    re_write_prompt = ChatPromptTemplate.from_messages(
        [
-            ("system", system),
+            ("system", system_question_rewriter),
            (
                "human",
                "Here is the initial question: \n\n {question} \n Formulate an improved question.",
            ),
        ]
    )
    question_rewriter = re_write_prompt | llm | StrOutputParser()
@@ -372,8 +336,6 @@ def go(args):
    ### Edges ###

    def route_question(state):
        """
        Route question to web search or RAG.
@@ -504,8 +466,6 @@ def go(args):
    # Compile
    app = workflow.compile()

    # Run
    inputs = {
        "question": args.query
@@ -521,8 +481,10 @@ def go(args):
    # Final generation
    pprint(value["generation"])


if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Chain of Thought RAG")
+    parser = argparse.ArgumentParser(description="Adaptive RAG")
    parser.add_argument(
        "--query",
@@ -531,6 +493,13 @@ if __name__ == "__main__":
        required=True
    )

+    parser.add_argument(
+        "--query_evaluation_dataset_csv_path",
+        type=str,
+        help="Path to the query evaluation dataset",
+        default=None,
+    )
    parser.add_argument(
        "--input_chromadb_artifact",
        type=str,
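
The new --query_evaluation_dataset_csv_path argument defaults to None and is not yet consumed elsewhere in this diff. One possible way to feed it to the evaluators above, sketched with hypothetical column names (the actual layout of qa_datasets.csv is not shown in this commit):

import csv

from evaluators import gemini_evaluator_correctness


def evaluate_from_csv(csv_path, answer_fn):
    """Score generated answers against a QA dataset.

    csv_path comes from --query_evaluation_dataset_csv_path; answer_fn runs the
    compiled adaptive RAG graph for a single question. The 'question' and
    'answer' column names are hypothetical, not defined by this commit.
    """
    results = []
    with open(csv_path, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            generated = answer_fn(row["question"])
            results.append(
                gemini_evaluator_correctness(
                    inputs=row["question"],
                    outputs=generated,
                    reference_outputs=row["answer"],
                )
            )
    return results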