From 486a79a2ccca52d20914e5d2dada48018269e4fa Mon Sep 17 00:00:00 2001 From: leehk Date: Wed, 12 Mar 2025 17:59:25 +0800 Subject: [PATCH] refactored adaptive rag --- app/llmops/config.yaml | 2 +- app/llmops/main.py | 1 + .../src/rag_adaptive_evaluation/MLproject | 6 +- .../rag_adaptive_evaluation/data_models.py | 32 +++++ .../src/rag_adaptive_evaluation/evaluators.py | 99 ++++++++++++++ .../prompts_library.py | 19 +++ .../rag_adaptive_evaluation/python_env.yml | 1 + app/llmops/src/rag_adaptive_evaluation/run.py | 129 +++++++----------- 8 files changed, 207 insertions(+), 82 deletions(-) create mode 100644 app/llmops/src/rag_adaptive_evaluation/data_models.py create mode 100644 app/llmops/src/rag_adaptive_evaluation/evaluators.py create mode 100644 app/llmops/src/rag_adaptive_evaluation/prompts_library.py diff --git a/app/llmops/config.yaml b/app/llmops/config.yaml index 37aeb2b..5452f8c 100644 --- a/app/llmops/config.yaml +++ b/app/llmops/config.yaml @@ -13,4 +13,4 @@ prompt_engineering: run_id_chromadb: None chat_model_provider: gemini query: "如何治疗乳腺癌?" - \ No newline at end of file + query_evaluation_dataset_csv_path: "../../../../data/qa_datasets.csv" \ No newline at end of file diff --git a/app/llmops/main.py b/app/llmops/main.py index 4c04a63..ac768b4 100644 --- a/app/llmops/main.py +++ b/app/llmops/main.py @@ -156,6 +156,7 @@ def go(config: DictConfig): "main", parameters={ "query": config["prompt_engineering"]["query"], + "query_evaluation_dataset_csv_path": config["prompt_engineering"]["query_evaluation_dataset_csv_path"], "input_chromadb_artifact": f'runs:/{run_id}/chromadb/chroma_db.zip', "embedding_model": config["etl"]["embedding_model"], "chat_model_provider": config["prompt_engineering"]["chat_model_provider"] diff --git a/app/llmops/src/rag_adaptive_evaluation/MLproject b/app/llmops/src/rag_adaptive_evaluation/MLproject index 48c1dad..457116d 100644 --- a/app/llmops/src/rag_adaptive_evaluation/MLproject +++ b/app/llmops/src/rag_adaptive_evaluation/MLproject @@ -8,6 +8,10 @@ entry_points: query: description: Query to run type: string + + query_evaluation_dataset_csv_path: + description: query evaluation dataset csv path + type: string input_chromadb_artifact: description: Fully-qualified name for the input artifact @@ -20,10 +24,10 @@ entry_points: chat_model_provider: description: Fully-qualified name for the chat model provider type: string - command: >- python run.py --query {query} \ + --query_evaluation_dataset_csv_path {query_evaluation_dataset_csv_path} \ --input_chromadb_artifact {input_chromadb_artifact} \ --embedding_model {embedding_model} \ --chat_model_provider {chat_model_provider} \ No newline at end of file diff --git a/app/llmops/src/rag_adaptive_evaluation/data_models.py b/app/llmops/src/rag_adaptive_evaluation/data_models.py new file mode 100644 index 0000000..680cfbd --- /dev/null +++ b/app/llmops/src/rag_adaptive_evaluation/data_models.py @@ -0,0 +1,32 @@ +from typing import Literal, List +from pydantic import BaseModel, Field + + +class RouteQuery(BaseModel): + """Route a user query to the most relevant datasource.""" + + datasource: Literal["vectorstore", "web_search"] = Field( + ..., + description="Given a user question choose to route it to web search or a vectorstore.", + ) + +class GradeDocuments(BaseModel): + """Binary score for relevance check on retrieved documents.""" + + binary_score: str = Field( + description="Documents are relevant to the question, 'yes' or 'no'" + ) + +class GradeHallucinations(BaseModel): + """Binary score for 
hallucination present in generation answer.""" + + binary_score: str = Field( + description="Answer is grounded in the facts, 'yes' or 'no'" + ) + +class GradeAnswer(BaseModel): + """Binary score to assess answer addresses question.""" + + binary_score: str = Field( + description="Answer addresses the question, 'yes' or 'no'" + ) \ No newline at end of file diff --git a/app/llmops/src/rag_adaptive_evaluation/evaluators.py b/app/llmops/src/rag_adaptive_evaluation/evaluators.py new file mode 100644 index 0000000..17b1b90 --- /dev/null +++ b/app/llmops/src/rag_adaptive_evaluation/evaluators.py @@ -0,0 +1,99 @@ +from decouple import config +from openevals.llm import create_llm_as_judge +from openevals.prompts import ( + CORRECTNESS_PROMPT, + CONCISENESS_PROMPT, + HALLUCINATION_PROMPT +) +from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_deepseek import ChatDeepSeek +from langchain_community.llms.moonshot import Moonshot + +GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str) +DEEKSEEK_API_KEY = config("DEEKSEEK_API_KEY", cast=str) +MOONSHOT_API_KEY = config("MOONSHOT_API_KEY", cast=str) + +# correctness +gemini_evaluator_correctness = create_llm_as_judge( + prompt=CORRECTNESS_PROMPT, + judge=ChatGoogleGenerativeAI( + model="gemini-1.5-flash", + google_api_key=GEMINI_API_KEY, + temperature=0.5, + ), + ) + +deepseek_evaluator_correctness = create_llm_as_judge( + prompt=CORRECTNESS_PROMPT, + judge=ChatDeepSeek( + model="deepseek-chat", + temperature=0.5, + api_key=DEEKSEEK_API_KEY + ), + ) + +moonshot_evaluator_correctness = create_llm_as_judge( + prompt=CORRECTNESS_PROMPT, + judge=Moonshot( + model="moonshot-v1-128k", + temperature=0.5, + api_key=MOONSHOT_API_KEY + ), + ) + +# conciseness +gemini_evaluator_conciseness = create_llm_as_judge( + prompt=CONCISENESS_PROMPT, + judge=ChatGoogleGenerativeAI( + model="gemini-1.5-flash", + google_api_key=GEMINI_API_KEY, + temperature=0.5, + ), + ) + +deepseek_evaluator_conciseness = create_llm_as_judge( + prompt=CONCISENESS_PROMPT, + judge=ChatDeepSeek( + model="deepseek-chat", + temperature=0.5, + api_key=DEEKSEEK_API_KEY + ), + ) + +moonshot_evaluator_conciseness = create_llm_as_judge( + prompt=CONCISENESS_PROMPT, + judge=Moonshot( + model="moonshot-v1-128k", + temperature=0.5, + api_key=MOONSHOT_API_KEY + ), + ) + +# hallucination +gemini_evaluator_hallucination = create_llm_as_judge( + prompt=HALLUCINATION_PROMPT, + judge=ChatGoogleGenerativeAI( + model="gemini-1.5-flash", + google_api_key=GEMINI_API_KEY, + temperature=0.5, + ), + ) + +deepseek_evaluator_hallucination = create_llm_as_judge( + prompt=HALLUCINATION_PROMPT, + judge=ChatDeepSeek( + model="deepseek-chat", + temperature=0.5, + api_key=DEEKSEEK_API_KEY + ), + ) + +moonshot_evaluator_hallucination = create_llm_as_judge( + prompt=HALLUCINATION_PROMPT, + judge=Moonshot( + model="moonshot-v1-128k", + temperature=0.5, + api_key=MOONSHOT_API_KEY + ), + ) + diff --git a/app/llmops/src/rag_adaptive_evaluation/prompts_library.py b/app/llmops/src/rag_adaptive_evaluation/prompts_library.py new file mode 100644 index 0000000..fcaf564 --- /dev/null +++ b/app/llmops/src/rag_adaptive_evaluation/prompts_library.py @@ -0,0 +1,19 @@ +system_router = """You are an expert at routing a user question to a vectorstore or web search. +The vectorstore contains documents related to any cancer/tumor disease. The question may be +asked in a variety of languages, and may be phrased in a variety of ways. +Use the vectorstore for questions on these topics. Otherwise, use web-search. 
+""" + +system_retriever_grader = """You are a grader assessing relevance of a retrieved document to a user question. \n + If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n + It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n + Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.""" + +system_hallucination_grader = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n + Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts.""" + +system_answer_grader = """You are a grader assessing whether an answer addresses / resolves a question \n + Give a binary score 'yes' or 'no'. Yes' means that the answer resolves the question.""" + +system_question_rewriter = """You a question re-writer that converts an input question to a better version that is optimized \n + for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning.""" \ No newline at end of file diff --git a/app/llmops/src/rag_adaptive_evaluation/python_env.yml b/app/llmops/src/rag_adaptive_evaluation/python_env.yml index 2278969..451cdb7 100644 --- a/app/llmops/src/rag_adaptive_evaluation/python_env.yml +++ b/app/llmops/src/rag_adaptive_evaluation/python_env.yml @@ -24,6 +24,7 @@ build_dependencies: - tavily-python - langchain_huggingface - pydantic + - openevals # Dependencies required to run the project. dependencies: - mlflow==2.8.1 \ No newline at end of file diff --git a/app/llmops/src/rag_adaptive_evaluation/run.py b/app/llmops/src/rag_adaptive_evaluation/run.py index e0496c0..1fe7543 100644 --- a/app/llmops/src/rag_adaptive_evaluation/run.py +++ b/app/llmops/src/rag_adaptive_evaluation/run.py @@ -8,19 +8,43 @@ from langchain_google_genai import ChatGoogleGenerativeAI from langchain_deepseek import ChatDeepSeek from langchain_community.llms.moonshot import Moonshot from langchain_huggingface import HuggingFaceEmbeddings - from langchain_community.vectorstores.chroma import Chroma -from typing import Literal, List +from typing import List from typing_extensions import TypedDict from langchain_core.prompts import ChatPromptTemplate -from pydantic import BaseModel, Field from langchain_community.tools.tavily_search import TavilySearchResults from langchain.schema import Document from pprint import pprint from langgraph.graph import END, StateGraph, START +from langsmith import Client +from data_models import ( + RouteQuery, + GradeDocuments, + GradeHallucinations, + GradeAnswer +) +from prompts_library import ( + system_router, + system_retriever_grader, + system_hallucination_grader, + system_answer_grader, + system_question_rewriter +) + +from evaluators import ( + gemini_evaluator_correctness, + deepseek_evaluator_correctness, + moonshot_evaluator_correctness, + gemini_evaluator_conciseness, + deepseek_evaluator_conciseness, + moonshot_evaluator_conciseness, + gemini_evaluator_hallucination, + deepseek_evaluator_hallucination, + moonshot_evaluator_hallucination +) logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s") logger = logging.getLogger() @@ -98,61 +122,32 @@ def go(args): vectorstore = Chroma(persist_directory=db_path, collection_name=collection_name, embedding_function=embedding_model) retriever = vectorstore.as_retriever() - # Data model - class RouteQuery(BaseModel): - """Route a user query 
to the most relevant datasource.""" - - datasource: Literal["vectorstore", "web_search"] = Field( - ..., - description="Given a user question choose to route it to web search or a vectorstore.", - ) - + ########################################## + # Routing to vectorstore or web search structured_llm_router = llm.with_structured_output(RouteQuery) - # Prompt - system = """You are an expert at routing a user question to a vectorstore or web search. - The vectorstore contains documents related to any cancer/tumor disease. The question may be - asked in a variety of languages, and may be phrased in a variety of ways. - Use the vectorstore for questions on these topics. Otherwise, use web-search. - """ route_prompt = ChatPromptTemplate.from_messages( [ - ("system", system), + ("system", system_router), ("human", "{question}"), ] ) - question_router = route_prompt | structured_llm_router - + ########################################## ### Retrieval Grader - # Data model - class GradeDocuments(BaseModel): - """Binary score for relevance check on retrieved documents.""" - - binary_score: str = Field( - description="Documents are relevant to the question, 'yes' or 'no'" - ) - structured_llm_grader = llm.with_structured_output(GradeDocuments) - # Prompt - system = """You are a grader assessing relevance of a retrieved document to a user question. \n - If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n - It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n - Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.""" grade_prompt = ChatPromptTemplate.from_messages( [ - ("system", system), + ("system", system_retriever_grader), ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"), ] ) - retrieval_grader = grade_prompt | structured_llm_grader - + ########################################## ### Generate - from langchain import hub from langchain_core.output_parsers import StrOutputParser @@ -167,76 +162,45 @@ def go(args): rag_chain = prompt | llm | StrOutputParser() - + ########################################## ### Hallucination Grader - - # Data model - class GradeHallucinations(BaseModel): - """Binary score for hallucination present in generation answer.""" - - binary_score: str = Field( - description="Answer is grounded in the facts, 'yes' or 'no'" - ) - - - # LLM with function call structured_llm_grader = llm.with_structured_output(GradeHallucinations) # Prompt - system = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n - Give a binary score 'yes' or 'no'. 
'Yes' means that the answer is grounded in / supported by the set of facts."""
     hallucination_prompt = ChatPromptTemplate.from_messages(
         [
-            ("system", system),
+            ("system", system_hallucination_grader),
             ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
         ]
     )
     hallucination_grader = hallucination_prompt | structured_llm_grader
 
-
+    ##########################################
     ### Answer Grader
-    # Data model
-    class GradeAnswer(BaseModel):
-        """Binary score to assess answer addresses question."""
-
-        binary_score: str = Field(
-            description="Answer addresses the question, 'yes' or 'no'"
-        )
-
-
-    # LLM with function call
     structured_llm_grader = llm.with_structured_output(GradeAnswer)
 
     # Prompt
-    system = """You are a grader assessing whether an answer addresses / resolves a question \n
-        Give a binary score 'yes' or 'no'. Yes' means that the answer resolves the question."""
     answer_prompt = ChatPromptTemplate.from_messages(
         [
-            ("system", system),
+            ("system", system_answer_grader),
             ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
         ]
     )
-
     answer_grader = answer_prompt | structured_llm_grader
 
+    ##########################################
    ### Question Re-writer
-
-    # LLM
-
-    # Prompt
-    system = """You a question re-writer that converts an input question to a better version that is optimized \n
-        for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
     re_write_prompt = ChatPromptTemplate.from_messages(
         [
-            ("system", system),
+            ("system", system_question_rewriter),
             (
                 "human",
                 "Here is the initial question: \n\n {question} \n Formulate an improved question.",
             ),
         ]
     )
-
     question_rewriter = re_write_prompt | llm | StrOutputParser()
 
@@ -372,8 +336,6 @@ def go(args):
 
     ### Edges ###
 
-
-
     def route_question(state):
         """
         Route question to web search or RAG.
@@ -504,8 +466,6 @@ def go(args):
     # Compile
     app = workflow.compile()
 
-
-
     # Run
     inputs = {
         "question": args.query
@@ -521,8 +481,10 @@ def go(args):
 
     # Final generation
     pprint(value["generation"])
 
+
+
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Chain of Thought RAG")
+    parser = argparse.ArgumentParser(description="Adaptive RAG")
 
     parser.add_argument(
         "--query",
@@ -531,6 +493,13 @@ if __name__ == "__main__":
         required=True
    )
 
+    parser.add_argument(
+        "--query_evaluation_dataset_csv_path",
+        type=str,
+        help="Path to the query evaluation dataset",
+        default=None,
+    )
+
     parser.add_argument(
         "--input_chromadb_artifact",
         type=str,
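
Note on follow-up wiring: the patch threads --query_evaluation_dataset_csv_path from config.yaml through main.py, MLproject and run.py's argparse, and imports the nine openevals judges from evaluators.py, but nothing inside go() consumes them yet. The sketch below is one possible way to connect the two; it is an assumption-laden illustration, not part of the patch. Assumptions: the CSV has question and reference_answer columns (hypothetical names), go() passes in the compiled LangGraph app, and the openevals judges accept inputs/outputs/reference_outputs keyword arguments (plus extra prompt variables such as context) and return a result whose "score" field can be cast to float.

# Hypothetical helper for run.py (not in this patch): score the adaptive RAG
# graph over the evaluation CSV and log aggregate metrics to MLflow.
import mlflow
import pandas as pd

from evaluators import (
    gemini_evaluator_correctness,
    gemini_evaluator_hallucination,
)


def evaluate_rag(app, csv_path: str) -> None:
    """Run each CSV question through the compiled graph and grade the answers."""
    df = pd.read_csv(csv_path)  # assumed columns: question, reference_answer
    correctness_scores, grounded_scores = [], []

    for _, row in df.iterrows():
        question = str(row["question"])
        reference = str(row["reference_answer"])

        # The compiled LangGraph workflow returns the final state dict;
        # "generation" and "documents" are the state keys already used in run.py.
        state = app.invoke({"question": question})
        answer = state["generation"]
        context = "\n\n".join(
            getattr(doc, "page_content", str(doc)) for doc in state.get("documents", [])
        )

        correctness = gemini_evaluator_correctness(
            inputs=question,
            outputs=answer,
            reference_outputs=reference,
        )
        grounded = gemini_evaluator_hallucination(
            inputs=question,
            outputs=answer,
            context=context,
            reference_outputs=reference,
        )
        correctness_scores.append(float(correctness["score"]))
        grounded_scores.append(float(grounded["score"]))

    if not correctness_scores:
        return
    mlflow.log_metric("eval_correctness_mean", sum(correctness_scores) / len(correctness_scores))
    mlflow.log_metric("eval_grounded_mean", sum(grounded_scores) / len(grounded_scores))

If wired in, go() could call evaluate_rag(app, args.query_evaluation_dataset_csv_path) right after compiling the workflow whenever the new argument is not None, which would also give the query_evaluation_dataset_csv_path parameter added to MLproject and main.py an actual consumer.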