mirror of
https://github.com/aimingmed/aimingmed-ai.git
synced 2026-01-19 13:23:23 +08:00
Merge pull request #19 from aimingmed/feature/langsmith-evaluation
Feature/langsmith evaluation
This commit is contained in:
commit
465c24546d
3
.gitignore
vendored
3
.gitignore
vendored
@ -208,4 +208,5 @@ data/*
|
||||
**/*.zip
|
||||
**/llm-examples/*
|
||||
**/*.ipynb_checkpoints
|
||||
**/*.ipynb
|
||||
**/*.ipynb
|
||||
**/transformer_model/*
|
||||
@ -9,8 +9,16 @@ etl:
|
||||
path_document_folder: "../../../../data"
|
||||
run_id_documents: None
|
||||
embedding_model: paraphrase-multilingual-mpnet-base-v2
|
||||
prompt_engineering:
|
||||
rag:
|
||||
run_id_chromadb: None
|
||||
chat_model_provider: gemini
|
||||
chat_model_provider: deepseek
|
||||
testing:
|
||||
query: "如何治疗乳腺癌?"
|
||||
|
||||
evaluation:
|
||||
evaluation_dataset_csv_path: "../../../../data/qa_dataset_20240321a.csv"
|
||||
evaluation_dataset_column_question: question
|
||||
evaluation_dataset_column_answer: answer
|
||||
ls_chat_model_provider:
|
||||
- gemini
|
||||
- deepseek
|
||||
- moonshot
|
||||
|
||||
@ -9,7 +9,7 @@ _steps = [
|
||||
"etl_chromadb_pdf",
|
||||
"etl_chromadb_scanned_pdf", # the performance for scanned pdf may not be good
|
||||
"rag_cot_evaluation",
|
||||
"adaptive_rag_evaluation",
|
||||
"rag_adaptive_evaluation",
|
||||
"test_rag_cot"
|
||||
]
|
||||
|
||||
@ -104,7 +104,7 @@ def go(config: DictConfig):
|
||||
)
|
||||
if "rag_cot_evaluation" in active_steps:
|
||||
|
||||
if config["prompt_engineering"]["run_id_chromadb"] == "None":
|
||||
if config["rag"]["run_id_chromadb"] == "None":
|
||||
# Look for run_id that has artifact logged as documents
|
||||
run_id = None
|
||||
client = mlflow.tracking.MlflowClient()
|
||||
@ -119,22 +119,22 @@ def go(config: DictConfig):
|
||||
if run_id is None:
|
||||
raise ValueError("No run_id found with artifact logged as documents")
|
||||
else:
|
||||
run_id = config["prompt_engineering"]["run_id_chromadb"]
|
||||
run_id = config["rag"]["run_id_chromadb"]
|
||||
|
||||
_ = mlflow.run(
|
||||
os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot_evaluation"),
|
||||
"main",
|
||||
parameters={
|
||||
"query": config["prompt_engineering"]["query"],
|
||||
"query": config["testing"]["query"],
|
||||
"input_chromadb_artifact": f'runs:/{run_id}/chromadb/chroma_db.zip',
|
||||
"embedding_model": config["etl"]["embedding_model"],
|
||||
"chat_model_provider": config["prompt_engineering"]["chat_model_provider"]
|
||||
"chat_model_provider": config["rag"]["chat_model_provider"]
|
||||
},
|
||||
)
|
||||
|
||||
if "adaptive_rag_evaluation" in active_steps:
|
||||
if "rag_adaptive_evaluation" in active_steps:
|
||||
|
||||
if config["prompt_engineering"]["run_id_chromadb"] == "None":
|
||||
if config["rag"]["run_id_chromadb"] == "None":
|
||||
# Look for run_id that has artifact logged as documents
|
||||
run_id = None
|
||||
client = mlflow.tracking.MlflowClient()
|
||||
@ -149,16 +149,20 @@ def go(config: DictConfig):
|
||||
if run_id is None:
|
||||
raise ValueError("No run_id found with artifact logged as documents")
|
||||
else:
|
||||
run_id = config["prompt_engineering"]["run_id_chromadb"]
|
||||
run_id = config["rag"]["run_id_chromadb"]
|
||||
|
||||
_ = mlflow.run(
|
||||
os.path.join(hydra.utils.get_original_cwd(), "src", "adaptive_rag_evaluation"),
|
||||
os.path.join(hydra.utils.get_original_cwd(), "src", "rag_adaptive_evaluation"),
|
||||
"main",
|
||||
parameters={
|
||||
"query": config["prompt_engineering"]["query"],
|
||||
"query": config["testing"]["query"],
|
||||
"evaluation_dataset_csv_path": config["evaluation"]["evaluation_dataset_csv_path"],
|
||||
"evaluation_dataset_column_question": config["evaluation"]["evaluation_dataset_column_question"],
|
||||
"evaluation_dataset_column_answer": config["evaluation"]["evaluation_dataset_column_answer"],
|
||||
"input_chromadb_artifact": f'runs:/{run_id}/chromadb/chroma_db.zip',
|
||||
"embedding_model": config["etl"]["embedding_model"],
|
||||
"chat_model_provider": config["prompt_engineering"]["chat_model_provider"]
|
||||
"chat_model_provider": config["rag"]["chat_model_provider"],
|
||||
"ls_chat_model_evaluator": ','.join(config["evaluation"]["ls_chat_model_provider"]) if config["evaluation"]["ls_chat_model_provider"] is not None else 'None',
|
||||
},
|
||||
)
|
||||
|
||||
@ -168,10 +172,10 @@ def go(config: DictConfig):
|
||||
os.path.join(hydra.utils.get_original_cwd(), "components", "test_rag_cot"),
|
||||
"main",
|
||||
parameters={
|
||||
"query": config["prompt_engineering"]["query"],
|
||||
"query": config["testing"]["query"],
|
||||
"input_chromadb_local": os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot_evaluation", "chroma_db"),
|
||||
"embedding_model": config["etl"]["embedding_model"],
|
||||
"chat_model_provider": config["prompt_engineering"]["chat_model_provider"]
|
||||
"chat_model_provider": config["rag"]["chat_model_provider"]
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ -1,29 +0,0 @@
|
||||
name: adaptive_rag_evaluation
|
||||
python_env: python_env.yml
|
||||
|
||||
entry_points:
|
||||
main:
|
||||
parameters:
|
||||
|
||||
query:
|
||||
description: Query to run
|
||||
type: string
|
||||
|
||||
input_chromadb_artifact:
|
||||
description: Fully-qualified name for the input artifact
|
||||
type: string
|
||||
|
||||
embedding_model:
|
||||
description: Fully-qualified name for the embedding model
|
||||
type: string
|
||||
|
||||
chat_model_provider:
|
||||
description: Fully-qualified name for the chat model provider
|
||||
type: string
|
||||
|
||||
|
||||
command: >-
|
||||
python run.py --query {query} \
|
||||
--input_chromadb_artifact {input_chromadb_artifact} \
|
||||
--embedding_model {embedding_model} \
|
||||
--chat_model_provider {chat_model_provider}
|
||||
@ -105,7 +105,7 @@ def go(args):
|
||||
documents_folder = os.path.splitext(os.path.basename(artifact_local_path))[0]
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
|
||||
chunk_size=1000, chunk_overlap=500
|
||||
chunk_size=15000, chunk_overlap=7500
|
||||
)
|
||||
|
||||
ls_docs = []
|
||||
@ -113,7 +113,7 @@ def go(args):
|
||||
for file in files:
|
||||
if file.endswith(".pdf"):
|
||||
read_text = extract_chinese_text_from_pdf(os.path.join(root, file))
|
||||
document = Document(metadata={"file": file}, page_content=read_text)
|
||||
document = Document(metadata={"file": f"{documents_folder}/{file}"}, page_content=read_text)
|
||||
ls_docs.append(document)
|
||||
|
||||
doc_splits = text_splitter.split_documents(ls_docs)
|
||||
@ -138,7 +138,7 @@ def go(args):
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser(description="A very basic data cleaning")
|
||||
parser = argparse.ArgumentParser(description="ETL for ChromaDB with readable PDF")
|
||||
|
||||
parser.add_argument(
|
||||
"--input_artifact",
|
||||
|
||||
49
app/llmops/src/rag_adaptive_evaluation/MLproject
Normal file
49
app/llmops/src/rag_adaptive_evaluation/MLproject
Normal file
@ -0,0 +1,49 @@
|
||||
name: rag_adaptive_evaluation
|
||||
python_env: python_env.yml
|
||||
|
||||
entry_points:
|
||||
main:
|
||||
parameters:
|
||||
|
||||
query:
|
||||
description: Query to run
|
||||
type: string
|
||||
|
||||
evaluation_dataset_csv_path:
|
||||
description: query evaluation dataset csv path
|
||||
type: string
|
||||
|
||||
evaluation_dataset_column_question:
|
||||
description: query evaluation dataset column question
|
||||
type: string
|
||||
|
||||
evaluation_dataset_column_answer:
|
||||
description: query evaluation dataset column groundtruth
|
||||
type: string
|
||||
|
||||
input_chromadb_artifact:
|
||||
description: Fully-qualified name for the input artifact
|
||||
type: string
|
||||
|
||||
embedding_model:
|
||||
description: Fully-qualified name for the embedding model
|
||||
type: string
|
||||
|
||||
chat_model_provider:
|
||||
description: Fully-qualified name for the chat model provider
|
||||
type: string
|
||||
|
||||
ls_chat_model_evaluator:
|
||||
description: list of chat model providers for evaluation
|
||||
type: string
|
||||
|
||||
|
||||
command: >-
|
||||
python run.py --query {query} \
|
||||
--evaluation_dataset_csv_path {evaluation_dataset_csv_path} \
|
||||
--evaluation_dataset_column_question {evaluation_dataset_column_question} \
|
||||
--evaluation_dataset_column_answer {evaluation_dataset_column_answer} \
|
||||
--input_chromadb_artifact {input_chromadb_artifact} \
|
||||
--embedding_model {embedding_model} \
|
||||
--chat_model_provider {chat_model_provider} \
|
||||
--ls_chat_model_evaluator {ls_chat_model_evaluator}
|
||||
32
app/llmops/src/rag_adaptive_evaluation/data_models.py
Normal file
32
app/llmops/src/rag_adaptive_evaluation/data_models.py
Normal file
@ -0,0 +1,32 @@
|
||||
from typing import Literal
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RouteQuery(BaseModel):
|
||||
"""Route a user query to the most relevant datasource."""
|
||||
|
||||
datasource: Literal["vectorstore", "web_search"] = Field(
|
||||
...,
|
||||
description="Given a user question choose to route it to web search or a vectorstore.",
|
||||
)
|
||||
|
||||
class GradeDocuments(BaseModel):
|
||||
"""Binary score for relevance check on retrieved documents."""
|
||||
|
||||
binary_score: str = Field(
|
||||
description="Documents are relevant to the question, 'yes' or 'no'"
|
||||
)
|
||||
|
||||
class GradeHallucinations(BaseModel):
|
||||
"""Binary score for hallucination present in generation answer."""
|
||||
|
||||
binary_score: str = Field(
|
||||
description="Answer is grounded in the facts, 'yes' or 'no'"
|
||||
)
|
||||
|
||||
class GradeAnswer(BaseModel):
|
||||
"""Binary score to assess answer addresses question."""
|
||||
|
||||
binary_score: str = Field(
|
||||
description="Answer addresses the question, 'yes' or 'no'"
|
||||
)
|
||||
141
app/llmops/src/rag_adaptive_evaluation/evaluators.py
Normal file
141
app/llmops/src/rag_adaptive_evaluation/evaluators.py
Normal file
@ -0,0 +1,141 @@
|
||||
import os
|
||||
from decouple import config
|
||||
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
from langchain_community.llms.moonshot import Moonshot
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from prompts_library import CORRECTNESS_PROMPT, FAITHFULNESS_PROMPT
|
||||
|
||||
os.environ["GOOGLE_API_KEY"] = config("GOOGLE_API_KEY", cast=str)
|
||||
os.environ["DEEPSEEK_API_KEY"] = config("DEEPSEEK_API_KEY", cast=str)
|
||||
os.environ["MOONSHOT_API_KEY"] = config("MOONSHOT_API_KEY", cast=str)
|
||||
|
||||
|
||||
# Define output schema for the evaluation
|
||||
class CorrectnessGrade(BaseModel):
|
||||
score: int = Field(description="Numerical score (1-5) indicating the correctness of the response.")
|
||||
|
||||
class FaithfulnessGrade(BaseModel):
|
||||
score: int = Field(description="Numerical score (1-5) indicating the faithfulness of the response.")
|
||||
|
||||
|
||||
|
||||
# Evaluators
|
||||
def gemini_evaluator_correctness(outputs: dict, reference_outputs: dict) -> CorrectnessGrade:
|
||||
llm = ChatGoogleGenerativeAI(
|
||||
model="gemini-1.5-flash",
|
||||
temperature=0.5,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": CORRECTNESS_PROMPT},
|
||||
{"role": "user", "content": f"""Ground Truth answer: {reference_outputs["answer"]};
|
||||
Student's Answer: {outputs['response']}
|
||||
"""}
|
||||
]
|
||||
|
||||
response = llm.invoke(messages)
|
||||
|
||||
return CorrectnessGrade(score=int(response.content)).score
|
||||
|
||||
|
||||
def deepseek_evaluator_correctness(outputs: dict, reference_outputs: dict) -> CorrectnessGrade:
|
||||
llm = ChatDeepSeek(
|
||||
model="deepseek-chat",
|
||||
temperature=0.5,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": CORRECTNESS_PROMPT},
|
||||
{"role": "user", "content": f"""Ground Truth answer: {reference_outputs["answer"]};
|
||||
Student's Answer: {outputs['response']}
|
||||
"""}
|
||||
]
|
||||
|
||||
response = llm.invoke(messages)
|
||||
|
||||
return CorrectnessGrade(score=int(response.content)).score
|
||||
|
||||
|
||||
def moonshot_evaluator_correctness(outputs: dict, reference_outputs: dict) -> CorrectnessGrade:
|
||||
llm = Moonshot(
|
||||
model="moonshot-v1-128k",
|
||||
temperature=0.5,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": CORRECTNESS_PROMPT},
|
||||
{"role": "user", "content": f"""Ground Truth answer: {reference_outputs["answer"]};
|
||||
Student's Answer: {outputs['response']}
|
||||
"""}
|
||||
]
|
||||
|
||||
response = llm.invoke(messages)
|
||||
|
||||
try:
|
||||
return CorrectnessGrade(score=int(response)).score
|
||||
except ValueError:
|
||||
score_str = response.split(":")[1].strip()
|
||||
return CorrectnessGrade(score=int(score_str)).score
|
||||
|
||||
|
||||
def gemini_evaluator_faithfulness(outputs: dict, reference_outputs: dict) -> FaithfulnessGrade:
|
||||
llm = ChatGoogleGenerativeAI(
|
||||
model="gemini-1.5-pro",
|
||||
temperature=0.5,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": FAITHFULNESS_PROMPT},
|
||||
{"role": "user", "content": f"""Context: {reference_outputs["answer"]};
|
||||
Output: {outputs['response']}
|
||||
"""}
|
||||
]
|
||||
|
||||
response = llm.invoke(messages)
|
||||
|
||||
return FaithfulnessGrade(score=int(response.content)).score
|
||||
|
||||
|
||||
def deepseek_evaluator_faithfulness(outputs: dict, reference_outputs: dict) -> FaithfulnessGrade:
|
||||
llm = ChatDeepSeek(
|
||||
model="deepseek-chat",
|
||||
temperature=0.5,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": FAITHFULNESS_PROMPT},
|
||||
{"role": "user", "content": f"""Context: {reference_outputs["answer"]};
|
||||
Output: {outputs['response']}
|
||||
"""}
|
||||
]
|
||||
|
||||
response = llm.invoke(messages)
|
||||
|
||||
return FaithfulnessGrade(score=int(response.content)).score
|
||||
|
||||
|
||||
def moonshot_evaluator_faithfulness(outputs: dict, reference_outputs: dict) -> FaithfulnessGrade:
|
||||
llm = Moonshot(
|
||||
model="moonshot-v1-128k",
|
||||
temperature=0.5,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": FAITHFULNESS_PROMPT},
|
||||
{"role": "user", "content": f"""Context: {reference_outputs["answer"]};
|
||||
Output: {outputs['response']}
|
||||
"""}
|
||||
]
|
||||
|
||||
response = llm.invoke(messages)
|
||||
|
||||
try:
|
||||
return FaithfulnessGrade(score=int(response)).score
|
||||
except ValueError:
|
||||
score_str = response.split(":")[1].strip()
|
||||
return FaithfulnessGrade(score=int(score_str)).score
|
||||
|
||||
78
app/llmops/src/rag_adaptive_evaluation/prompts_library.py
Normal file
78
app/llmops/src/rag_adaptive_evaluation/prompts_library.py
Normal file
@ -0,0 +1,78 @@
|
||||
system_router = """You are an expert at routing a user question to a vectorstore or web search.
|
||||
The vectorstore contains documents related to any cancer/tumor disease. The question may be
|
||||
asked in a variety of languages, and may be phrased in a variety of ways.
|
||||
Use the vectorstore for questions on these topics. Otherwise, use web-search.
|
||||
"""
|
||||
|
||||
system_retriever_grader = """You are a grader assessing relevance of a retrieved document to a user question. \n
|
||||
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
|
||||
It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
|
||||
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
|
||||
|
||||
system_hallucination_grader = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
|
||||
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
|
||||
|
||||
system_answer_grader = """You are a grader assessing whether an answer addresses / resolves a question \n
|
||||
Give a binary score 'yes' or 'no'. Yes' means that the answer resolves the question."""
|
||||
|
||||
system_question_rewriter = """You a question re-writer that converts an input question to a better version that is optimized \n
|
||||
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
|
||||
|
||||
|
||||
# Evaluation
|
||||
CORRECTNESS_PROMPT = """You are an impartial judge. Evaluate Student Answer against Ground Truth for conceptual similarity and correctness.
|
||||
You may also be given additional information that was used by the model to generate the output.
|
||||
|
||||
Your task is to determine a numerical score called correctness based on the Student Answer and Ground Truth.
|
||||
A definition of correctness and a grading rubric are provided below.
|
||||
You must use the grading rubric to determine your score.
|
||||
|
||||
Metric definition:
|
||||
Correctness assesses the degree to which a provided Student Answer aligns with factual accuracy, completeness, logical
|
||||
consistency, and precise terminology of the Ground Truth. It evaluates the intrinsic validity of the Student Answer , independent of any
|
||||
external context. A higher score indicates a higher adherence to factual accuracy, completeness, logical consistency,
|
||||
and precise terminology of the Ground Truth.
|
||||
|
||||
Grading rubric:
|
||||
Correctness: Below are the details for different scores:
|
||||
- 1: Major factual errors, highly incomplete, illogical, and uses incorrect terminology.
|
||||
- 2: Significant factual errors, incomplete, noticeable logical flaws, and frequent terminology errors.
|
||||
- 3: Minor factual errors, somewhat incomplete, minor logical inconsistencies, and occasional terminology errors.
|
||||
- 4: Few to no factual errors, mostly complete, strong logical consistency, and accurate terminology.
|
||||
- 5: Accurate, complete, logically consistent, and uses precise terminology.
|
||||
|
||||
Reminder:
|
||||
- Carefully read the Student Answer and Ground Truth
|
||||
- Check for factual accuracy and completeness of Student Answer compared to the Ground Truth
|
||||
- Focus on correctness of information rather than style or verbosity
|
||||
- The goal is to evaluate factual correctness and completeness of the Student Answer.
|
||||
- Please provide your answer score only with the numerical number between 1 and 5. No score: or other text is allowed.
|
||||
|
||||
"""
|
||||
|
||||
FAITHFULNESS_PROMPT = """You are an impartial judge. Evaluate output against context for faithfulness.
|
||||
You may also be given additional information that was used by the model to generate the Output.
|
||||
|
||||
Your task is to determine a numerical score called faithfulness based on the output and context.
|
||||
A definition of faithfulness and a grading rubric are provided below.
|
||||
You must use the grading rubric to determine your score.
|
||||
|
||||
Metric definition:
|
||||
Faithfulness is only evaluated with the provided output and context. Faithfulness assesses how much of the
|
||||
provided output is factually consistent with the provided context. A higher score indicates that a higher proportion of
|
||||
claims present in the output can be derived from the provided context. Faithfulness does not consider how much extra
|
||||
information from the context is not present in the output.
|
||||
|
||||
Grading rubric:
|
||||
Faithfulness: Below are the details for different scores:
|
||||
- Score 1: None of the claims in the output can be inferred from the provided context.
|
||||
- Score 2: Some of the claims in the output can be inferred from the provided context, but the majority of the output is missing from, inconsistent with, or contradictory to the provided context.
|
||||
- Score 3: Half or more of the claims in the output can be inferred from the provided context.
|
||||
- Score 4: Most of the claims in the output can be inferred from the provided context, with very little information that is not directly supported by the provided context.
|
||||
- Score 5: All of the claims in the output are directly supported by the provided context, demonstrating high faithfulness to the provided context.
|
||||
|
||||
Reminder:
|
||||
- Carefully read the output and context
|
||||
- Focus on the information instead of the writing style or verbosity.
|
||||
- Please provide your answer score only with the numerical number between 1 and 5, according to the grading rubric above. No score: or other text is allowed.
|
||||
"""
|
||||
@ -3,41 +3,61 @@ import logging
|
||||
import argparse
|
||||
import mlflow
|
||||
import shutil
|
||||
import langsmith
|
||||
|
||||
from decouple import config
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
from langchain_community.llms.moonshot import Moonshot
|
||||
from langchain_huggingface import HuggingFaceEmbeddings
|
||||
|
||||
from langchain_community.vectorstores.chroma import Chroma
|
||||
|
||||
from typing import Literal, List
|
||||
|
||||
from typing import List
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain_community.tools.tavily_search import TavilySearchResults
|
||||
from langchain.schema import Document
|
||||
from pprint import pprint
|
||||
from langgraph.graph import END, StateGraph, START
|
||||
from langsmith import Client
|
||||
|
||||
from data_models import (
|
||||
RouteQuery,
|
||||
GradeDocuments,
|
||||
GradeHallucinations,
|
||||
GradeAnswer
|
||||
)
|
||||
from prompts_library import (
|
||||
system_router,
|
||||
system_retriever_grader,
|
||||
system_hallucination_grader,
|
||||
system_answer_grader,
|
||||
system_question_rewriter
|
||||
)
|
||||
|
||||
from evaluators import (
|
||||
gemini_evaluator_correctness,
|
||||
deepseek_evaluator_correctness,
|
||||
moonshot_evaluator_correctness,
|
||||
gemini_evaluator_faithfulness,
|
||||
deepseek_evaluator_faithfulness,
|
||||
moonshot_evaluator_faithfulness
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
|
||||
logger = logging.getLogger()
|
||||
|
||||
GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
|
||||
DEEKSEEK_API_KEY = config("DEEKSEEK_API_KEY", cast=str)
|
||||
MOONSHOT_API_KEY = config("MOONSHOT_API_KEY", cast=str)
|
||||
TAVILY_API_KEY = config("TAVILY_API_KEY", cast=str)
|
||||
LANGSMITH_API_KEY = config("LANGSMITH_API_KEY", cast=str)
|
||||
LANGSMITH_TRACING = config("LANGSMITH_TRACING", cast=str)
|
||||
LANGSMITH_PROJECT = config("LANGSMITH_PROJECT", cast=str)
|
||||
os.environ["TAVILY_API_KEY"] = TAVILY_API_KEY
|
||||
os.environ["GOOGLE_API_KEY"] = config("GOOGLE_API_KEY", cast=str)
|
||||
os.environ["DEEPSEEK_API_KEY"] = config("DEEPSEEK_API_KEY", cast=str)
|
||||
os.environ["MOONSHOT_API_KEY"] = config("MOONSHOT_API_KEY", cast=str)
|
||||
os.environ["TAVILY_API_KEY"] = config("TAVILY_API_KEY", cast=str)
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["LANGSMITH_API_KEY"] = LANGSMITH_API_KEY
|
||||
os.environ["LANGSMITH_TRACING"] = LANGSMITH_TRACING
|
||||
os.environ["LANGSMITH_API_KEY"] = config("LANGSMITH_API_KEY", cast=str)
|
||||
os.environ["LANGSMITH_TRACING"] = config("LANGSMITH_TRACING", cast=str)
|
||||
os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
|
||||
os.environ["LANGSMITH_PROJECT"] = LANGSMITH_PROJECT
|
||||
os.environ["LANGSMITH_PROJECT"] = config("LANGSMITH_PROJECT", cast=str)
|
||||
|
||||
def go(args):
|
||||
|
||||
@ -71,12 +91,10 @@ def go(args):
|
||||
max_tokens=None,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
api_key=DEEKSEEK_API_KEY
|
||||
)
|
||||
elif args.chat_model_provider == 'gemini':
|
||||
llm = ChatGoogleGenerativeAI(
|
||||
model="gemini-1.5-flash",
|
||||
google_api_key=GEMINI_API_KEY,
|
||||
temperature=0,
|
||||
max_retries=3,
|
||||
streaming=True
|
||||
@ -88,7 +106,6 @@ def go(args):
|
||||
max_tokens=None,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
api_key=MOONSHOT_API_KEY
|
||||
)
|
||||
|
||||
# Load data from ChromaDB
|
||||
@ -98,61 +115,32 @@ def go(args):
|
||||
vectorstore = Chroma(persist_directory=db_path, collection_name=collection_name, embedding_function=embedding_model)
|
||||
retriever = vectorstore.as_retriever()
|
||||
|
||||
# Data model
|
||||
class RouteQuery(BaseModel):
|
||||
"""Route a user query to the most relevant datasource."""
|
||||
|
||||
datasource: Literal["vectorstore", "web_search"] = Field(
|
||||
...,
|
||||
description="Given a user question choose to route it to web search or a vectorstore.",
|
||||
)
|
||||
|
||||
##########################################
|
||||
# Routing to vectorstore or web search
|
||||
structured_llm_router = llm.with_structured_output(RouteQuery)
|
||||
|
||||
# Prompt
|
||||
system = """You are an expert at routing a user question to a vectorstore or web search.
|
||||
The vectorstore contains documents related to any cancer/tumor disease. The question may be
|
||||
asked in a variety of languages, and may be phrased in a variety of ways.
|
||||
Use the vectorstore for questions on these topics. Otherwise, use web-search.
|
||||
"""
|
||||
route_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", system),
|
||||
("system", system_router),
|
||||
("human", "{question}"),
|
||||
]
|
||||
)
|
||||
|
||||
question_router = route_prompt | structured_llm_router
|
||||
|
||||
|
||||
##########################################
|
||||
### Retrieval Grader
|
||||
# Data model
|
||||
class GradeDocuments(BaseModel):
|
||||
"""Binary score for relevance check on retrieved documents."""
|
||||
|
||||
binary_score: str = Field(
|
||||
description="Documents are relevant to the question, 'yes' or 'no'"
|
||||
)
|
||||
|
||||
structured_llm_grader = llm.with_structured_output(GradeDocuments)
|
||||
|
||||
# Prompt
|
||||
system = """You are a grader assessing relevance of a retrieved document to a user question. \n
|
||||
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
|
||||
It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
|
||||
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
|
||||
grade_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", system),
|
||||
("system", system_retriever_grader),
|
||||
("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
|
||||
]
|
||||
)
|
||||
|
||||
retrieval_grader = grade_prompt | structured_llm_grader
|
||||
|
||||
|
||||
##########################################
|
||||
### Generate
|
||||
|
||||
from langchain import hub
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
|
||||
@ -167,76 +155,45 @@ def go(args):
|
||||
rag_chain = prompt | llm | StrOutputParser()
|
||||
|
||||
|
||||
|
||||
##########################################
|
||||
### Hallucination Grader
|
||||
|
||||
# Data model
|
||||
class GradeHallucinations(BaseModel):
|
||||
"""Binary score for hallucination present in generation answer."""
|
||||
|
||||
binary_score: str = Field(
|
||||
description="Answer is grounded in the facts, 'yes' or 'no'"
|
||||
)
|
||||
|
||||
|
||||
# LLM with function call
|
||||
structured_llm_grader = llm.with_structured_output(GradeHallucinations)
|
||||
|
||||
# Prompt
|
||||
system = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
|
||||
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
|
||||
hallucination_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", system),
|
||||
("system", system_hallucination_grader),
|
||||
("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
|
||||
]
|
||||
)
|
||||
|
||||
hallucination_grader = hallucination_prompt | structured_llm_grader
|
||||
|
||||
|
||||
##########################################
|
||||
### Answer Grader
|
||||
# Data model
|
||||
class GradeAnswer(BaseModel):
|
||||
"""Binary score to assess answer addresses question."""
|
||||
|
||||
binary_score: str = Field(
|
||||
description="Answer addresses the question, 'yes' or 'no'"
|
||||
)
|
||||
|
||||
|
||||
# LLM with function call
|
||||
structured_llm_grader = llm.with_structured_output(GradeAnswer)
|
||||
|
||||
# Prompt
|
||||
system = """You are a grader assessing whether an answer addresses / resolves a question \n
|
||||
Give a binary score 'yes' or 'no'. Yes' means that the answer resolves the question."""
|
||||
answer_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", system),
|
||||
("system", system_answer_grader),
|
||||
("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
|
||||
]
|
||||
)
|
||||
|
||||
answer_grader = answer_prompt | structured_llm_grader
|
||||
|
||||
##########################################
|
||||
### Question Re-writer
|
||||
|
||||
# LLM
|
||||
|
||||
# Prompt
|
||||
system = """You a question re-writer that converts an input question to a better version that is optimized \n
|
||||
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
|
||||
re_write_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", system),
|
||||
("system", system_question_rewriter),
|
||||
(
|
||||
"human",
|
||||
"Here is the initial question: \n\n {question} \n Formulate an improved question.",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
question_rewriter = re_write_prompt | llm | StrOutputParser()
|
||||
|
||||
|
||||
@ -372,8 +329,6 @@ def go(args):
|
||||
|
||||
|
||||
### Edges ###
|
||||
|
||||
|
||||
def route_question(state):
|
||||
"""
|
||||
Route question to web search or RAG.
|
||||
@ -504,8 +459,6 @@ def go(args):
|
||||
# Compile
|
||||
app = workflow.compile()
|
||||
|
||||
|
||||
|
||||
# Run
|
||||
inputs = {
|
||||
"question": args.query
|
||||
@ -519,10 +472,71 @@ def go(args):
|
||||
pprint("\n---\n")
|
||||
|
||||
# Final generation
|
||||
pprint(value["generation"])
|
||||
print(value["generation"])
|
||||
|
||||
return {"response": value["generation"]}
|
||||
|
||||
def go_evaluation(args):
|
||||
if args.evaluation_dataset_csv_path:
|
||||
|
||||
import pandas as pd
|
||||
|
||||
df = pd.read_csv(args.evaluation_dataset_csv_path)
|
||||
dataset_name = os.path.basename(args.evaluation_dataset_csv_path).split('.')[0]
|
||||
|
||||
# df contains columns of question and answer
|
||||
examples = df[[args.evaluation_dataset_column_question, args.evaluation_dataset_column_answer]].values.tolist()
|
||||
inputs = [{"question": input_prompt} for input_prompt, _ in examples]
|
||||
outputs = [{"answer": output_answer} for _, output_answer in examples]
|
||||
|
||||
# Programmatically create a dataset in LangSmith
|
||||
client = Client()
|
||||
|
||||
try:
|
||||
# Create a dataset
|
||||
dataset = client.create_dataset(
|
||||
dataset_name = dataset_name,
|
||||
description = "An evaluation dataset in LangSmith."
|
||||
)
|
||||
# Add examples to the dataset
|
||||
client.create_examples(inputs=inputs, outputs=outputs, dataset_id=dataset.id)
|
||||
except langsmith.utils.LangSmithConflictError:
|
||||
pass
|
||||
|
||||
|
||||
args.ls_chat_model_evaluator = None if args.ls_chat_model_evaluator == 'None' else args.ls_chat_model_evaluator.split(',')
|
||||
|
||||
def target(inputs: dict) -> dict:
|
||||
new_args = argparse.Namespace(**vars(args))
|
||||
new_args.query = inputs["question"]
|
||||
return go(new_args)
|
||||
|
||||
ls_evaluators = []
|
||||
if args.ls_chat_model_evaluator:
|
||||
for evaluator in args.ls_chat_model_evaluator:
|
||||
if evaluator == 'moonshot':
|
||||
ls_evaluators.append(moonshot_evaluator_correctness)
|
||||
ls_evaluators.append(moonshot_evaluator_faithfulness)
|
||||
elif evaluator == 'deepseek':
|
||||
ls_evaluators.append(deepseek_evaluator_correctness)
|
||||
ls_evaluators.append(deepseek_evaluator_faithfulness)
|
||||
elif evaluator == 'gemini':
|
||||
ls_evaluators.append(gemini_evaluator_correctness)
|
||||
ls_evaluators.append(gemini_evaluator_faithfulness)
|
||||
|
||||
# After running the evaluation, a link will be provided to view the results in langsmith
|
||||
experiment_results = client.evaluate(
|
||||
target,
|
||||
data = dataset_name,
|
||||
evaluators = ls_evaluators,
|
||||
experiment_prefix = "first-eval-in-langsmith",
|
||||
max_concurrency = 1,
|
||||
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Chain of Thought RAG")
|
||||
parser = argparse.ArgumentParser(description="Adaptive AG")
|
||||
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
@ -531,6 +545,27 @@ if __name__ == "__main__":
|
||||
required=True
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--evaluation_dataset_csv_path",
|
||||
type=str,
|
||||
help="Path to the query evaluation dataset",
|
||||
default=None,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--evaluation_dataset_column_question",
|
||||
type=str,
|
||||
help="Column name for the questions in the evaluation dataset",
|
||||
default="question",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--evaluation_dataset_column_answer",
|
||||
type=str,
|
||||
help="Column name for the groundtruth answers in the evaluation dataset",
|
||||
default="groundtruth",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--input_chromadb_artifact",
|
||||
type=str,
|
||||
@ -552,6 +587,14 @@ if __name__ == "__main__":
|
||||
help="Chat model provider"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ls_chat_model_evaluator",
|
||||
type=str,
|
||||
help="list of Chat model providers for evaluation",
|
||||
required=False,
|
||||
default="None"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
go(args)
|
||||
|
||||
go_evaluation(args)
|
||||
@ -14,10 +14,14 @@ from langchain_community.llms.moonshot import Moonshot
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
|
||||
logger = logging.getLogger()
|
||||
|
||||
os.environ["GOOGLE_API_KEY"] = config("GOOGLE_API_KEY", cast=str)
|
||||
os.environ["DEEPSEEK_API_KEY"] = config("DEEPSEEK_API_KEY", cast=str)
|
||||
os.environ["MOONSHOT_API_KEY"] = config("MOONSHOT_API_KEY", cast=str)
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
|
||||
DEEKSEEK_API_KEY = config("DEEKSEEK_API_KEY", cast=str)
|
||||
MOONSHOT_API_KEY = config("MOONSHOT_API_KEY", cast=str)
|
||||
os.environ["LANGSMITH_API_KEY"] = config("LANGSMITH_API_KEY", cast=str)
|
||||
os.environ["LANGSMITH_TRACING"] = config("LANGSMITH_TRACING", cast=str)
|
||||
os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
|
||||
os.environ["LANGSMITH_PROJECT"] = config("LANGSMITH_PROJECT", cast=str)
|
||||
|
||||
def go(args):
|
||||
|
||||
@ -60,14 +64,12 @@ def go(args):
|
||||
max_tokens=None,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
api_key=DEEKSEEK_API_KEY
|
||||
)
|
||||
|
||||
elif args.chat_model_provider == "gemini":
|
||||
# Initialize Gemini model
|
||||
llm = ChatGoogleGenerativeAI(
|
||||
model="gemini-1.5-flash",
|
||||
google_api_key=GEMINI_API_KEY,
|
||||
temperature=0,
|
||||
max_retries=3
|
||||
)
|
||||
@ -80,7 +82,6 @@ def go(args):
|
||||
max_tokens=None,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
api_key=MOONSHOT_API_KEY
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -4,4 +4,6 @@ from sentence_transformers import SentenceTransformer
|
||||
EMBEDDING_MODEL = config("EMBEDDING_MODEL", cast=str, default="paraphrase-multilingual-mpnet-base-v2")
|
||||
|
||||
# Initialize embedding model
|
||||
model = SentenceTransformer(EMBEDDING_MODEL)
|
||||
model = SentenceTransformer(EMBEDDING_MODEL)
|
||||
|
||||
model.save("./transformer_model/paraphrase-multilingual-mpnet-base-v2")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user