Refactored adaptive RAG

This commit is contained in:
leehk 2025-03-12 17:59:25 +08:00
parent 8b68c60249
commit 486a79a2cc
8 changed files with 207 additions and 82 deletions

View File

@@ -13,4 +13,4 @@ prompt_engineering:
run_id_chromadb: None
chat_model_provider: gemini
query: "如何治疗乳腺癌?"
query_evaluation_dataset_csv_path: "../../../../data/qa_datasets.csv"
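
A quick way to sanity-check the new query_evaluation_dataset_csv_path entry is to load the config with OmegaConf, which is what the DictConfig-based go() entry point below ultimately receives. A minimal sketch, assuming the file is saved as config.yaml (the actual config file name is not shown in this commit):

from omegaconf import OmegaConf

# Assumption: config file name; adjust to the repo's actual Hydra config.
cfg = OmegaConf.load("config.yaml")
print(cfg.prompt_engineering.query_evaluation_dataset_csv_path)
# ../../../../data/qa_datasets.csv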

View File

@@ -156,6 +156,7 @@ def go(config: DictConfig):
"main",
parameters={
"query": config["prompt_engineering"]["query"],
"query_evaluation_dataset_csv_path": config["prompt_engineering"]["query_evaluation_dataset_csv_path"],
"input_chromadb_artifact": f'runs:/{run_id}/chromadb/chroma_db.zip',
"embedding_model": config["etl"]["embedding_model"],
"chat_model_provider": config["prompt_engineering"]["chat_model_provider"]

View File

@@ -8,6 +8,10 @@ entry_points:
query:
description: Query to run
type: string
query_evaluation_dataset_csv_path:
description: Path to the query evaluation dataset CSV
type: string
input_chromadb_artifact:
description: Fully-qualified name for the input artifact
@@ -20,10 +24,10 @@ entry_points:
chat_model_provider:
description: Fully-qualified name for the chat model provider
type: string
command: >-
python run.py --query {query}
--query_evaluation_dataset_csv_path {query_evaluation_dataset_csv_path}
--input_chromadb_artifact {input_chromadb_artifact}
--embedding_model {embedding_model}
--chat_model_provider {chat_model_provider}
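
For reference, the same entry point can be launched programmatically, mirroring the orchestration call above (assuming the entry point is named main, as that call suggests). A hedged sketch; the artifact run id and embedding model name below are placeholders, not values taken from this repo:

import mlflow

# Hedged sketch: run the prompt_engineering step's "main" entry point directly.
mlflow.projects.run(
    uri=".",  # assumption: run from the step directory containing this MLproject
    entry_point="main",
    parameters={
        "query": "如何治疗乳腺癌?",
        "query_evaluation_dataset_csv_path": "../../../../data/qa_datasets.csv",
        "input_chromadb_artifact": "runs:/<run_id>/chromadb/chroma_db.zip",  # placeholder run id
        "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",         # placeholder model
        "chat_model_provider": "gemini",
    },
)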

View File

@@ -0,0 +1,32 @@
from typing import Literal, List
from pydantic import BaseModel, Field


class RouteQuery(BaseModel):
    """Route a user query to the most relevant datasource."""

    datasource: Literal["vectorstore", "web_search"] = Field(
        ...,
        description="Given a user question, choose to route it to web search or a vectorstore.",
    )


class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )


class GradeHallucinations(BaseModel):
    """Binary score for hallucination present in generation answer."""

    binary_score: str = Field(
        description="Answer is grounded in the facts, 'yes' or 'no'"
    )


class GradeAnswer(BaseModel):
    """Binary score to assess whether the answer addresses the question."""

    binary_score: str = Field(
        description="Answer addresses the question, 'yes' or 'no'"
    )
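
These models are consumed through LangChain's structured-output support; a minimal sketch of binding the router model, assuming the same Gemini model used in evaluators.py and a GOOGLE_API_KEY in the environment:

from langchain_google_genai import ChatGoogleGenerativeAI
from data_models import RouteQuery

llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0)

# Constrain the LLM to return a RouteQuery instance rather than free text.
structured_llm_router = llm.with_structured_output(RouteQuery)

decision = structured_llm_router.invoke("What are the standard treatments for breast cancer?")
print(decision.datasource)  # "vectorstore" or "web_search"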

View File

@@ -0,0 +1,99 @@
from decouple import config
from openevals.llm import create_llm_as_judge
from openevals.prompts import (
CORRECTNESS_PROMPT,
CONCISENESS_PROMPT,
HALLUCINATION_PROMPT
)
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_deepseek import ChatDeepSeek
from langchain_community.llms.moonshot import Moonshot
GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
DEEPSEEK_API_KEY = config("DEEPSEEK_API_KEY", cast=str)
MOONSHOT_API_KEY = config("MOONSHOT_API_KEY", cast=str)
# correctness
gemini_evaluator_correctness = create_llm_as_judge(
prompt=CORRECTNESS_PROMPT,
judge=ChatGoogleGenerativeAI(
model="gemini-1.5-flash",
google_api_key=GEMINI_API_KEY,
temperature=0.5,
),
)
deepseek_evaluator_correctness = create_llm_as_judge(
prompt=CORRECTNESS_PROMPT,
judge=ChatDeepSeek(
model="deepseek-chat",
temperature=0.5,
api_key=DEEPSEEK_API_KEY
),
)
moonshot_evaluator_correctness = create_llm_as_judge(
prompt=CORRECTNESS_PROMPT,
judge=Moonshot(
model="moonshot-v1-128k",
temperature=0.5,
api_key=MOONSHOT_API_KEY
),
)
# conciseness
gemini_evaluator_conciseness = create_llm_as_judge(
prompt=CONCISENESS_PROMPT,
judge=ChatGoogleGenerativeAI(
model="gemini-1.5-flash",
google_api_key=GEMINI_API_KEY,
temperature=0.5,
),
)
deepseek_evaluator_conciseness = create_llm_as_judge(
prompt=CONCISENESS_PROMPT,
judge=ChatDeepSeek(
model="deepseek-chat",
temperature=0.5,
api_key=DEEPSEEK_API_KEY
),
)
moonshot_evaluator_conciseness = create_llm_as_judge(
prompt=CONCISENESS_PROMPT,
judge=Moonshot(
model="moonshot-v1-128k",
temperature=0.5,
api_key=MOONSHOT_API_KEY
),
)
# hallucination
gemini_evaluator_hallucination = create_llm_as_judge(
prompt=HALLUCINATION_PROMPT,
judge=ChatGoogleGenerativeAI(
model="gemini-1.5-flash",
google_api_key=GEMINI_API_KEY,
temperature=0.5,
),
)
deepseek_evaluator_hallucination = create_llm_as_judge(
prompt=HALLUCINATION_PROMPT,
judge=ChatDeepSeek(
model="deepseek-chat",
temperature=0.5,
api_key=DEEPSEEK_API_KEY
),
)
moonshot_evaluator_hallucination = create_llm_as_judge(
prompt=HALLUCINATION_PROMPT,
judge=Moonshot(
model="moonshot-v1-128k",
temperature=0.5,
api_key=MOONSHOT_API_KEY
),
)
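
Each create_llm_as_judge call returns a plain callable. A hedged sketch of scoring answers against the new evaluation CSV; the question/reference_answer column names and the rag_answer helper are assumptions, since neither the CSV schema nor the evaluation loop is shown in this commit:

import pandas as pd
from evaluators import gemini_evaluator_correctness

def rag_answer(question: str) -> str:
    # Hypothetical stand-in for invoking the adaptive-RAG graph in run.py.
    return "placeholder answer"

# Assumption: the CSV has 'question' and 'reference_answer' columns.
df = pd.read_csv("../../../../data/qa_datasets.csv")

for _, row in df.iterrows():
    answer = rag_answer(row["question"])
    result = gemini_evaluator_correctness(
        inputs=row["question"],
        outputs=answer,
        reference_outputs=row["reference_answer"],
    )
    print(result["key"], result["score"], result["comment"])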

View File

@@ -0,0 +1,19 @@
system_router = """You are an expert at routing a user question to a vectorstore or web search.
The vectorstore contains documents related to any cancer/tumor disease. The question may be
asked in a variety of languages, and may be phrased in a variety of ways.
Use the vectorstore for questions on these topics. Otherwise, use web-search.
"""
system_retriever_grader = """You are a grader assessing the relevance of a retrieved document to a user question. \n
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question."""
system_hallucination_grader = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
system_answer_grader = """You are a grader assessing whether an answer addresses / resolves a question. \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question."""
system_question_rewriter = """You are a question re-writer that converts an input question to a better version that is optimized \n
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
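
These prompts pair with the data models above exactly as they are wired in run.py; a self-contained sketch of the retrieval-grader chain (the sample document and question are illustrative):

from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from data_models import GradeDocuments
from prompts_library import system_retriever_grader

llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeDocuments)

grade_prompt = ChatPromptTemplate.from_messages([
    ("system", system_retriever_grader),
    ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
])
retrieval_grader = grade_prompt | structured_llm_grader

score = retrieval_grader.invoke({
    "document": "Tamoxifen is a common endocrine therapy for hormone-receptor-positive breast cancer.",
    "question": "How is breast cancer treated?",
})
print(score.binary_score)  # 'yes' or 'no'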

View File

@@ -24,6 +24,7 @@ build_dependencies:
- tavily-python
- langchain_huggingface
- pydantic
- openevals
# Dependencies required to run the project.
dependencies:
- mlflow==2.8.1

View File

@@ -8,19 +8,43 @@ from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_deepseek import ChatDeepSeek
from langchain_community.llms.moonshot import Moonshot
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores.chroma import Chroma
from typing import Literal, List
from typing import List
from typing_extensions import TypedDict
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain.schema import Document
from pprint import pprint
from langgraph.graph import END, StateGraph, START
from langsmith import Client
from data_models import (
RouteQuery,
GradeDocuments,
GradeHallucinations,
GradeAnswer
)
from prompts_library import (
system_router,
system_retriever_grader,
system_hallucination_grader,
system_answer_grader,
system_question_rewriter
)
from evaluators import (
gemini_evaluator_correctness,
deepseek_evaluator_correctness,
moonshot_evaluator_correctness,
gemini_evaluator_conciseness,
deepseek_evaluator_conciseness,
moonshot_evaluator_conciseness,
gemini_evaluator_hallucination,
deepseek_evaluator_hallucination,
moonshot_evaluator_hallucination
)
logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()
@@ -98,61 +122,32 @@ def go(args):
vectorstore = Chroma(persist_directory=db_path, collection_name=collection_name, embedding_function=embedding_model)
retriever = vectorstore.as_retriever()
# Data model
class RouteQuery(BaseModel):
"""Route a user query to the most relevant datasource."""
datasource: Literal["vectorstore", "web_search"] = Field(
...,
description="Given a user question choose to route it to web search or a vectorstore.",
)
##########################################
# Routing to vectorstore or web search
structured_llm_router = llm.with_structured_output(RouteQuery)
# Prompt
system = """You are an expert at routing a user question to a vectorstore or web search.
The vectorstore contains documents related to any cancer/tumor disease. The question may be
asked in a variety of languages, and may be phrased in a variety of ways.
Use the vectorstore for questions on these topics. Otherwise, use web-search.
"""
route_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("system", system_router),
("human", "{question}"),
]
)
question_router = route_prompt | structured_llm_router
##########################################
### Retrieval Grader
# Data model
class GradeDocuments(BaseModel):
"""Binary score for relevance check on retrieved documents."""
binary_score: str = Field(
description="Documents are relevant to the question, 'yes' or 'no'"
)
structured_llm_grader = llm.with_structured_output(GradeDocuments)
# Prompt
system = """You are a grader assessing relevance of a retrieved document to a user question. \n
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
grade_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("system", system_retriever_grader),
("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
]
)
retrieval_grader = grade_prompt | structured_llm_grader
##########################################
### Generate
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
@@ -167,76 +162,45 @@ def go(args):
rag_chain = prompt | llm | StrOutputParser()
##########################################
### Hallucination Grader
# Data model
class GradeHallucinations(BaseModel):
"""Binary score for hallucination present in generation answer."""
binary_score: str = Field(
description="Answer is grounded in the facts, 'yes' or 'no'"
)
# LLM with function call
structured_llm_grader = llm.with_structured_output(GradeHallucinations)
# Prompt
system = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
hallucination_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("system", system_hallucination_grader),
("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
]
)
hallucination_grader = hallucination_prompt | structured_llm_grader
##########################################
### Answer Grader
# Data model
class GradeAnswer(BaseModel):
"""Binary score to assess answer addresses question."""
binary_score: str = Field(
description="Answer addresses the question, 'yes' or 'no'"
)
# LLM with function call
structured_llm_grader = llm.with_structured_output(GradeAnswer)
# Prompt
system = """You are a grader assessing whether an answer addresses / resolves a question \n
Give a binary score 'yes' or 'no'. Yes' means that the answer resolves the question."""
answer_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("system", system_answer_grader),
("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
]
)
answer_grader = answer_prompt | structured_llm_grader
##########################################
### Question Re-writer
# LLM
# Prompt
system = """You a question re-writer that converts an input question to a better version that is optimized \n
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
re_write_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("system", system_question_rewriter),
(
"human",
"Here is the initial question: \n\n {question} \n Formulate an improved question.",
),
]
)
question_rewriter = re_write_prompt | llm | StrOutputParser()
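
The router, graders, and rewriter defined above feed a LangGraph state machine whose wiring falls outside the visible hunks. The sketch below is an assumption modelled on the standard adaptive-RAG graph: the GraphState fields, node names, and stub node functions are illustrative; only route_question, workflow.compile(), and the START/END imports actually appear in this diff, and the real graph also grades documents and generations before finishing.

from typing import List
from typing_extensions import TypedDict
from langgraph.graph import END, StateGraph, START

class GraphState(TypedDict):
    question: str
    generation: str
    documents: List[str]

# Hypothetical stub nodes; run.py defines full versions of these.
def web_search(state): return state
def retrieve(state): return state
def grade_documents(state): return state
def generate(state): return {**state, "generation": "..."}

def route_question(state):
    # run.py's version asks question_router; "vectorstore" here is a stand-in.
    return "vectorstore"

workflow = StateGraph(GraphState)
workflow.add_node("web_search", web_search)
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)

workflow.add_conditional_edges(
    START,
    route_question,
    {"web_search": "web_search", "vectorstore": "retrieve"},
)
workflow.add_edge("web_search", "generate")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_edge("grade_documents", "generate")
workflow.add_edge("generate", END)

app = workflow.compile()
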
@@ -372,8 +336,6 @@ def go(args):
### Edges ###
def route_question(state):
"""
Route question to web search or RAG.
@@ -504,8 +466,6 @@ def go(args):
# Compile
app = workflow.compile()
# Run
inputs = {
"question": args.query
@@ -521,8 +481,10 @@ def go(args):
# Final generation
pprint(value["generation"])
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Chain of Thought RAG")
parser = argparse.ArgumentParser(description="Adaptive AG")
parser.add_argument(
"--query",
@@ -531,6 +493,13 @@ if __name__ == "__main__":
required=True
)
parser.add_argument(
"--query_evaluation_dataset_csv_path",
type=str,
help="Path to the query evaluation dataset",
default=None,
)
parser.add_argument(
"--input_chromadb_artifact",
type=str,