Merge pull request #19 from aimingmed/feature/langsmith-evaluation

Feature/langsmith evaluation
This commit is contained in:
Hong Kai LEE 2025-04-01 10:47:09 +08:00 committed by GitHub
commit 465c24546d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 483 additions and 153 deletions

.gitignore
View File

@@ -208,4 +208,5 @@ data/*
**/*.zip
**/llm-examples/*
**/*.ipynb_checkpoints
**/*.ipynb
**/transformer_model/*

View File

@@ -9,8 +9,16 @@ etl:
path_document_folder: "../../../../data"
run_id_documents: None
embedding_model: paraphrase-multilingual-mpnet-base-v2
prompt_engineering:
rag:
run_id_chromadb: None
chat_model_provider: gemini
chat_model_provider: deepseek
testing:
query: "如何治疗乳腺癌?"
evaluation:
evaluation_dataset_csv_path: "../../../../data/qa_dataset_20240321a.csv"
evaluation_dataset_column_question: question
evaluation_dataset_column_answer: answer
ls_chat_model_provider:
- gemini
- deepseek
- moonshot
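
For orientation, a minimal sketch (not part of the commit) of reading the new evaluation block once this file is loaded with OmegaConf; the standalone load path is an assumption, and the provider list is joined into one string because MLflow project parameters must be strings:

from omegaconf import OmegaConf

# hypothetical standalone load; in the pipeline Hydra injects the config
config = OmegaConf.load("config.yaml")
eval_cfg = config["evaluation"]
csv_path = eval_cfg["evaluation_dataset_csv_path"]
# flatten the evaluator list the same way main.py does before passing it on
providers = ",".join(eval_cfg["ls_chat_model_provider"])  # "gemini,deepseek,moonshot"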

View File

@@ -9,7 +9,7 @@ _steps = [
"etl_chromadb_pdf",
"etl_chromadb_scanned_pdf", # the performance for scanned pdf may not be good
"rag_cot_evaluation",
"adaptive_rag_evaluation",
"rag_adaptive_evaluation",
"test_rag_cot"
]
@@ -104,7 +104,7 @@ def go(config: DictConfig):
)
if "rag_cot_evaluation" in active_steps:
if config["prompt_engineering"]["run_id_chromadb"] == "None":
if config["rag"]["run_id_chromadb"] == "None":
# Look for run_id that has artifact logged as documents
run_id = None
client = mlflow.tracking.MlflowClient()
@@ -119,22 +119,22 @@ def go(config: DictConfig):
if run_id is None:
raise ValueError("No run_id found with artifact logged as documents")
else:
run_id = config["prompt_engineering"]["run_id_chromadb"]
run_id = config["rag"]["run_id_chromadb"]
_ = mlflow.run(
os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot_evaluation"),
"main",
parameters={
"query": config["prompt_engineering"]["query"],
"query": config["testing"]["query"],
"input_chromadb_artifact": f'runs:/{run_id}/chromadb/chroma_db.zip',
"embedding_model": config["etl"]["embedding_model"],
"chat_model_provider": config["prompt_engineering"]["chat_model_provider"]
"chat_model_provider": config["rag"]["chat_model_provider"]
},
)
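
The lookup elided above searches earlier MLflow runs for one that logged the ChromaDB artifact; a hedged sketch of that idea (the experiment id and artifact path are assumptions, not the commit's actual code):

import mlflow

client = mlflow.tracking.MlflowClient()
run_id = None
# keep the first run that logged a chromadb artifact
for run in client.search_runs(experiment_ids=["0"]):
    if any(a.path == "chromadb" for a in client.list_artifacts(run.info.run_id)):
        run_id = run.info.run_id
        break
if run_id is None:
    raise ValueError("No run_id found with artifact logged as documents")
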
if "adaptive_rag_evaluation" in active_steps:
if "rag_adaptive_evaluation" in active_steps:
if config["prompt_engineering"]["run_id_chromadb"] == "None":
if config["rag"]["run_id_chromadb"] == "None":
# Look for run_id that has artifact logged as documents
run_id = None
client = mlflow.tracking.MlflowClient()
@@ -149,16 +149,20 @@ def go(config: DictConfig):
if run_id is None:
raise ValueError("No run_id found with artifact logged as documents")
else:
run_id = config["prompt_engineering"]["run_id_chromadb"]
run_id = config["rag"]["run_id_chromadb"]
_ = mlflow.run(
os.path.join(hydra.utils.get_original_cwd(), "src", "adaptive_rag_evaluation"),
os.path.join(hydra.utils.get_original_cwd(), "src", "rag_adaptive_evaluation"),
"main",
parameters={
"query": config["prompt_engineering"]["query"],
"query": config["testing"]["query"],
"evaluation_dataset_csv_path": config["evaluation"]["evaluation_dataset_csv_path"],
"evaluation_dataset_column_question": config["evaluation"]["evaluation_dataset_column_question"],
"evaluation_dataset_column_answer": config["evaluation"]["evaluation_dataset_column_answer"],
"input_chromadb_artifact": f'runs:/{run_id}/chromadb/chroma_db.zip',
"embedding_model": config["etl"]["embedding_model"],
"chat_model_provider": config["prompt_engineering"]["chat_model_provider"]
"chat_model_provider": config["rag"]["chat_model_provider"],
"ls_chat_model_evaluator": ','.join(config["evaluation"]["ls_chat_model_provider"]) if config["evaluation"]["ls_chat_model_provider"] is not None else 'None',
},
)
@@ -168,10 +172,10 @@ def go(config: DictConfig):
os.path.join(hydra.utils.get_original_cwd(), "components", "test_rag_cot"),
"main",
parameters={
"query": config["prompt_engineering"]["query"],
"query": config["testing"]["query"],
"input_chromadb_local": os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot_evaluation", "chroma_db"),
"embedding_model": config["etl"]["embedding_model"],
"chat_model_provider": config["prompt_engineering"]["chat_model_provider"]
"chat_model_provider": config["rag"]["chat_model_provider"]
},
)

View File

@@ -1,29 +0,0 @@
name: adaptive_rag_evaluation
python_env: python_env.yml
entry_points:
main:
parameters:
query:
description: Query to run
type: string
input_chromadb_artifact:
description: Fully-qualified name for the input artifact
type: string
embedding_model:
description: Fully-qualified name for the embedding model
type: string
chat_model_provider:
description: Fully-qualified name for the chat model provider
type: string
command: >-
python run.py --query {query} \
--input_chromadb_artifact {input_chromadb_artifact} \
--embedding_model {embedding_model} \
--chat_model_provider {chat_model_provider}

View File

@@ -105,7 +105,7 @@ def go(args):
documents_folder = os.path.splitext(os.path.basename(artifact_local_path))[0]
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=1000, chunk_overlap=500
chunk_size=15000, chunk_overlap=7500
)
ls_docs = []
@@ -113,7 +113,7 @@ def go(args):
for file in files:
if file.endswith(".pdf"):
read_text = extract_chinese_text_from_pdf(os.path.join(root, file))
document = Document(metadata={"file": file}, page_content=read_text)
document = Document(metadata={"file": f"{documents_folder}/{file}"}, page_content=read_text)
ls_docs.append(document)
doc_splits = text_splitter.split_documents(ls_docs)
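
For context, a minimal sketch (hypothetical sample document) of the enlarged splitter settings; chunk_size and chunk_overlap are counted in tiktoken tokens here, and the metadata now carries the parent folder so chunks from different document sets stay distinguishable:

from langchain.schema import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=15000, chunk_overlap=7500
)
doc = Document(metadata={"file": "guidelines/example.pdf"}, page_content="……长篇中文指南文本……")
chunks = text_splitter.split_documents([doc])
print(len(chunks), chunks[0].metadata)
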
@@ -138,7 +138,7 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="A very basic data cleaning")
parser = argparse.ArgumentParser(description="ETL for ChromaDB with readable PDF")
parser.add_argument(
"--input_artifact",

View File

@@ -0,0 +1,49 @@
name: rag_adaptive_evaluation
python_env: python_env.yml
entry_points:
main:
parameters:
query:
description: Query to run
type: string
evaluation_dataset_csv_path:
description: query evaluation dataset csv path
type: string
evaluation_dataset_column_question:
description: query evaluation dataset column question
type: string
evaluation_dataset_column_answer:
description: query evaluation dataset column groundtruth
type: string
input_chromadb_artifact:
description: Fully-qualified name for the input artifact
type: string
embedding_model:
description: Fully-qualified name for the embedding model
type: string
chat_model_provider:
description: Fully-qualified name for the chat model provider
type: string
ls_chat_model_evaluator:
description: list of chat model providers for evaluation
type: string
command: >-
python run.py --query {query} \
--evaluation_dataset_csv_path {evaluation_dataset_csv_path} \
--evaluation_dataset_column_question {evaluation_dataset_column_question} \
--evaluation_dataset_column_answer {evaluation_dataset_column_answer} \
--input_chromadb_artifact {input_chromadb_artifact} \
--embedding_model {embedding_model} \
--chat_model_provider {chat_model_provider} \
--ls_chat_model_evaluator {ls_chat_model_evaluator}
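
To exercise this entry point on its own, outside main.py, something along these lines should work; the project path, CSV file name, and run-id placeholder are assumptions:

import mlflow

mlflow.run(
    "src/rag_adaptive_evaluation",
    entry_point="main",
    parameters={
        "query": "如何治疗乳腺癌?",
        "evaluation_dataset_csv_path": "../../../../data/qa_dataset.csv",  # hypothetical file
        "evaluation_dataset_column_question": "question",
        "evaluation_dataset_column_answer": "answer",
        "input_chromadb_artifact": "runs:/<run_id>/chromadb/chroma_db.zip",
        "embedding_model": "paraphrase-multilingual-mpnet-base-v2",
        "chat_model_provider": "deepseek",
        "ls_chat_model_evaluator": "gemini,deepseek,moonshot",
    },
)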

View File

@@ -0,0 +1,32 @@
from typing import Literal
from pydantic import BaseModel, Field
class RouteQuery(BaseModel):
"""Route a user query to the most relevant datasource."""
datasource: Literal["vectorstore", "web_search"] = Field(
...,
description="Given a user question choose to route it to web search or a vectorstore.",
)
class GradeDocuments(BaseModel):
"""Binary score for relevance check on retrieved documents."""
binary_score: str = Field(
description="Documents are relevant to the question, 'yes' or 'no'"
)
class GradeHallucinations(BaseModel):
"""Binary score for hallucination present in generation answer."""
binary_score: str = Field(
description="Answer is grounded in the facts, 'yes' or 'no'"
)
class GradeAnswer(BaseModel):
"""Binary score to assess answer addresses question."""
binary_score: str = Field(
description="Answer addresses the question, 'yes' or 'no'"
)
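
These models are later passed to llm.with_structured_output(...); as a quick pydantic-only illustration of what the Literal constraint enforces (no LLM call involved):

from pydantic import ValidationError

route = RouteQuery(datasource="vectorstore")
print(route.datasource)  # "vectorstore"

try:
    RouteQuery(datasource="database")  # not an allowed Literal value
except ValidationError:
    print("rejected: datasource must be 'vectorstore' or 'web_search'")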

View File

@@ -0,0 +1,141 @@
import os
from decouple import config
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_deepseek import ChatDeepSeek
from langchain_community.llms.moonshot import Moonshot
from pydantic import BaseModel, Field
from prompts_library import CORRECTNESS_PROMPT, FAITHFULNESS_PROMPT
os.environ["GOOGLE_API_KEY"] = config("GOOGLE_API_KEY", cast=str)
os.environ["DEEPSEEK_API_KEY"] = config("DEEPSEEK_API_KEY", cast=str)
os.environ["MOONSHOT_API_KEY"] = config("MOONSHOT_API_KEY", cast=str)
# Define output schema for the evaluation
class CorrectnessGrade(BaseModel):
score: int = Field(description="Numerical score (1-5) indicating the correctness of the response.")
class FaithfulnessGrade(BaseModel):
score: int = Field(description="Numerical score (1-5) indicating the faithfulness of the response.")
# Evaluators
def gemini_evaluator_correctness(outputs: dict, reference_outputs: dict) -> CorrectnessGrade:
llm = ChatGoogleGenerativeAI(
model="gemini-1.5-flash",
temperature=0.5,
)
messages = [
{"role": "system", "content": CORRECTNESS_PROMPT},
{"role": "user", "content": f"""Ground Truth answer: {reference_outputs["answer"]};
Student's Answer: {outputs['response']}
"""}
]
response = llm.invoke(messages)
return CorrectnessGrade(score=int(response.content)).score
def deepseek_evaluator_correctness(outputs: dict, reference_outputs: dict) -> CorrectnessGrade:
llm = ChatDeepSeek(
model="deepseek-chat",
temperature=0.5,
)
messages = [
{"role": "system", "content": CORRECTNESS_PROMPT},
{"role": "user", "content": f"""Ground Truth answer: {reference_outputs["answer"]};
Student's Answer: {outputs['response']}
"""}
]
response = llm.invoke(messages)
return CorrectnessGrade(score=int(response.content)).score
def moonshot_evaluator_correctness(outputs: dict, reference_outputs: dict) -> CorrectnessGrade:
llm = Moonshot(
model="moonshot-v1-128k",
temperature=0.5,
)
messages = [
{"role": "system", "content": CORRECTNESS_PROMPT},
{"role": "user", "content": f"""Ground Truth answer: {reference_outputs["answer"]};
Student's Answer: {outputs['response']}
"""}
]
response = llm.invoke(messages)
try:
return CorrectnessGrade(score=int(response)).score
except ValueError:
score_str = response.split(":")[1].strip()
return CorrectnessGrade(score=int(score_str)).score
def gemini_evaluator_faithfulness(outputs: dict, reference_outputs: dict) -> FaithfulnessGrade:
llm = ChatGoogleGenerativeAI(
model="gemini-1.5-pro",
temperature=0.5,
)
messages = [
{"role": "system", "content": FAITHFULNESS_PROMPT},
{"role": "user", "content": f"""Context: {reference_outputs["answer"]};
Output: {outputs['response']}
"""}
]
response = llm.invoke(messages)
return FaithfulnessGrade(score=int(response.content)).score
def deepseek_evaluator_faithfulness(outputs: dict, reference_outputs: dict) -> FaithfulnessGrade:
llm = ChatDeepSeek(
model="deepseek-chat",
temperature=0.5,
)
messages = [
{"role": "system", "content": FAITHFULNESS_PROMPT},
{"role": "user", "content": f"""Context: {reference_outputs["answer"]};
Output: {outputs['response']}
"""}
]
response = llm.invoke(messages)
return FaithfulnessGrade(score=int(response.content)).score
def moonshot_evaluator_faithfulness(outputs: dict, reference_outputs: dict) -> FaithfulnessGrade:
llm = Moonshot(
model="moonshot-v1-128k",
temperature=0.5,
)
messages = [
{"role": "system", "content": FAITHFULNESS_PROMPT},
{"role": "user", "content": f"""Context: {reference_outputs["answer"]};
Output: {outputs['response']}
"""}
]
response = llm.invoke(messages)
try:
return FaithfulnessGrade(score=int(response)).score
except ValueError:
score_str = response.split(":")[1].strip()
return FaithfulnessGrade(score=int(score_str)).score
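
Each evaluator follows the (outputs, reference_outputs) -> score convention that LangSmith's evaluate() passes through; a hedged usage sketch with placeholder strings (a valid DEEPSEEK_API_KEY is required):

# hypothetical example; in the real run these dicts come from the target
# function's output and the ground-truth column of the dataset
outputs = {"response": "乳腺癌的治疗包括手术、化疗和放疗。"}
reference_outputs = {"answer": "乳腺癌常见治疗方式为手术、放疗、化疗及内分泌治疗。"}

score = deepseek_evaluator_correctness(outputs, reference_outputs)
print(score)  # an integer between 1 and 5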

View File

@@ -0,0 +1,78 @@
system_router = """You are an expert at routing a user question to a vectorstore or web search.
The vectorstore contains documents related to any cancer/tumor disease. The question may be
asked in a variety of languages, and may be phrased in a variety of ways.
Use the vectorstore for questions on these topics. Otherwise, use web-search.
"""
system_retriever_grader = """You are a grader assessing relevance of a retrieved document to a user question. \n
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
system_hallucination_grader = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
system_answer_grader = """You are a grader assessing whether an answer addresses / resolves a question \n
Give a binary score 'yes' or 'no'. Yes' means that the answer resolves the question."""
system_question_rewriter = """You a question re-writer that converts an input question to a better version that is optimized \n
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
# Evaluation
CORRECTNESS_PROMPT = """You are an impartial judge. Evaluate Student Answer against Ground Truth for conceptual similarity and correctness.
You may also be given additional information that was used by the model to generate the output.
Your task is to determine a numerical score called correctness based on the Student Answer and Ground Truth.
A definition of correctness and a grading rubric are provided below.
You must use the grading rubric to determine your score.
Metric definition:
Correctness assesses the degree to which a provided Student Answer aligns with factual accuracy, completeness, logical
consistency, and precise terminology of the Ground Truth. It evaluates the intrinsic validity of the Student Answer , independent of any
external context. A higher score indicates a higher adherence to factual accuracy, completeness, logical consistency,
and precise terminology of the Ground Truth.
Grading rubric:
Correctness: Below are the details for different scores:
- 1: Major factual errors, highly incomplete, illogical, and uses incorrect terminology.
- 2: Significant factual errors, incomplete, noticeable logical flaws, and frequent terminology errors.
- 3: Minor factual errors, somewhat incomplete, minor logical inconsistencies, and occasional terminology errors.
- 4: Few to no factual errors, mostly complete, strong logical consistency, and accurate terminology.
- 5: Accurate, complete, logically consistent, and uses precise terminology.
Reminder:
- Carefully read the Student Answer and Ground Truth
- Check for factual accuracy and completeness of Student Answer compared to the Ground Truth
- Focus on correctness of information rather than style or verbosity
- The goal is to evaluate factual correctness and completeness of the Student Answer.
- Please provide your answer score only with the numerical number between 1 and 5. No score: or other text is allowed.
"""
FAITHFULNESS_PROMPT = """You are an impartial judge. Evaluate output against context for faithfulness.
You may also be given additional information that was used by the model to generate the Output.
Your task is to determine a numerical score called faithfulness based on the output and context.
A definition of faithfulness and a grading rubric are provided below.
You must use the grading rubric to determine your score.
Metric definition:
Faithfulness is only evaluated with the provided output and context. Faithfulness assesses how much of the
provided output is factually consistent with the provided context. A higher score indicates that a higher proportion of
claims present in the output can be derived from the provided context. Faithfulness does not consider how much extra
information from the context is not present in the output.
Grading rubric:
Faithfulness: Below are the details for different scores:
- Score 1: None of the claims in the output can be inferred from the provided context.
- Score 2: Some of the claims in the output can be inferred from the provided context, but the majority of the output is missing from, inconsistent with, or contradictory to the provided context.
- Score 3: Half or more of the claims in the output can be inferred from the provided context.
- Score 4: Most of the claims in the output can be inferred from the provided context, with very little information that is not directly supported by the provided context.
- Score 5: All of the claims in the output are directly supported by the provided context, demonstrating high faithfulness to the provided context.
Reminder:
- Carefully read the output and context
- Focus on the information instead of the writing style or verbosity.
- Please provide your answer score only with the numerical number between 1 and 5, according to the grading rubric above. No score: or other text is allowed.
"""

View File

@@ -3,41 +3,61 @@ import logging
import argparse
import mlflow
import shutil
import langsmith
from decouple import config
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_deepseek import ChatDeepSeek
from langchain_community.llms.moonshot import Moonshot
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores.chroma import Chroma
from typing import Literal, List
from typing import List
from typing_extensions import TypedDict
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain.schema import Document
from pprint import pprint
from langgraph.graph import END, StateGraph, START
from langsmith import Client
from data_models import (
RouteQuery,
GradeDocuments,
GradeHallucinations,
GradeAnswer
)
from prompts_library import (
system_router,
system_retriever_grader,
system_hallucination_grader,
system_answer_grader,
system_question_rewriter
)
from evaluators import (
gemini_evaluator_correctness,
deepseek_evaluator_correctness,
moonshot_evaluator_correctness,
gemini_evaluator_faithfulness,
deepseek_evaluator_faithfulness,
moonshot_evaluator_faithfulness
)
logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()
GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
DEEKSEEK_API_KEY = config("DEEKSEEK_API_KEY", cast=str)
MOONSHOT_API_KEY = config("MOONSHOT_API_KEY", cast=str)
TAVILY_API_KEY = config("TAVILY_API_KEY", cast=str)
LANGSMITH_API_KEY = config("LANGSMITH_API_KEY", cast=str)
LANGSMITH_TRACING = config("LANGSMITH_TRACING", cast=str)
LANGSMITH_PROJECT = config("LANGSMITH_PROJECT", cast=str)
os.environ["TAVILY_API_KEY"] = TAVILY_API_KEY
os.environ["GOOGLE_API_KEY"] = config("GOOGLE_API_KEY", cast=str)
os.environ["DEEPSEEK_API_KEY"] = config("DEEPSEEK_API_KEY", cast=str)
os.environ["MOONSHOT_API_KEY"] = config("MOONSHOT_API_KEY", cast=str)
os.environ["TAVILY_API_KEY"] = config("TAVILY_API_KEY", cast=str)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["LANGSMITH_API_KEY"] = LANGSMITH_API_KEY
os.environ["LANGSMITH_TRACING"] = LANGSMITH_TRACING
os.environ["LANGSMITH_API_KEY"] = config("LANGSMITH_API_KEY", cast=str)
os.environ["LANGSMITH_TRACING"] = config("LANGSMITH_TRACING", cast=str)
os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGSMITH_PROJECT"] = LANGSMITH_PROJECT
os.environ["LANGSMITH_PROJECT"] = config("LANGSMITH_PROJECT", cast=str)
def go(args):
@@ -71,12 +91,10 @@ def go(args):
max_tokens=None,
timeout=None,
max_retries=2,
api_key=DEEKSEEK_API_KEY
)
elif args.chat_model_provider == 'gemini':
llm = ChatGoogleGenerativeAI(
model="gemini-1.5-flash",
google_api_key=GEMINI_API_KEY,
temperature=0,
max_retries=3,
streaming=True
@@ -88,7 +106,6 @@ def go(args):
max_tokens=None,
timeout=None,
max_retries=2,
api_key=MOONSHOT_API_KEY
)
# Load data from ChromaDB
@@ -98,61 +115,32 @@ def go(args):
vectorstore = Chroma(persist_directory=db_path, collection_name=collection_name, embedding_function=embedding_model)
retriever = vectorstore.as_retriever()
# Data model
class RouteQuery(BaseModel):
"""Route a user query to the most relevant datasource."""
datasource: Literal["vectorstore", "web_search"] = Field(
...,
description="Given a user question choose to route it to web search or a vectorstore.",
)
##########################################
# Routing to vectorstore or web search
structured_llm_router = llm.with_structured_output(RouteQuery)
# Prompt
system = """You are an expert at routing a user question to a vectorstore or web search.
The vectorstore contains documents related to any cancer/tumor disease. The question may be
asked in a variety of languages, and may be phrased in a variety of ways.
Use the vectorstore for questions on these topics. Otherwise, use web-search.
"""
route_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("system", system_router),
("human", "{question}"),
]
)
question_router = route_prompt | structured_llm_router
##########################################
### Retrieval Grader
# Data model
class GradeDocuments(BaseModel):
"""Binary score for relevance check on retrieved documents."""
binary_score: str = Field(
description="Documents are relevant to the question, 'yes' or 'no'"
)
structured_llm_grader = llm.with_structured_output(GradeDocuments)
# Prompt
system = """You are a grader assessing relevance of a retrieved document to a user question. \n
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
grade_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("system", system_retriever_grader),
("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
]
)
retrieval_grader = grade_prompt | structured_llm_grader
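
Because of with_structured_output, these chains return the pydantic objects from data_models.py rather than free text; a short usage sketch (assumes the configured provider supports structured output):

decision = question_router.invoke({"question": "如何治疗乳腺癌?"})
print(decision.datasource)  # "vectorstore" for an oncology question, otherwise "web_search"

grade = retrieval_grader.invoke(
    {"document": "……检索到的文档片段……", "question": "如何治疗乳腺癌?"}
)
print(grade.binary_score)  # "yes" or "no"
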
##########################################
### Generate
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
@@ -167,76 +155,45 @@ def go(args):
rag_chain = prompt | llm | StrOutputParser()
##########################################
### Hallucination Grader
# Data model
class GradeHallucinations(BaseModel):
"""Binary score for hallucination present in generation answer."""
binary_score: str = Field(
description="Answer is grounded in the facts, 'yes' or 'no'"
)
# LLM with function call
structured_llm_grader = llm.with_structured_output(GradeHallucinations)
# Prompt
system = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
hallucination_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("system", system_hallucination_grader),
("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
]
)
hallucination_grader = hallucination_prompt | structured_llm_grader
##########################################
### Answer Grader
# Data model
class GradeAnswer(BaseModel):
"""Binary score to assess answer addresses question."""
binary_score: str = Field(
description="Answer addresses the question, 'yes' or 'no'"
)
# LLM with function call
structured_llm_grader = llm.with_structured_output(GradeAnswer)
# Prompt
system = """You are a grader assessing whether an answer addresses / resolves a question \n
Give a binary score 'yes' or 'no'. Yes' means that the answer resolves the question."""
answer_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("system", system_answer_grader),
("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
]
)
answer_grader = answer_prompt | structured_llm_grader
##########################################
### Question Re-writer
# LLM
# Prompt
system = """You a question re-writer that converts an input question to a better version that is optimized \n
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning."""
re_write_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("system", system_question_rewriter),
(
"human",
"Here is the initial question: \n\n {question} \n Formulate an improved question.",
),
]
)
question_rewriter = re_write_prompt | llm | StrOutputParser()
@@ -372,8 +329,6 @@ def go(args):
### Edges ###
def route_question(state):
"""
Route question to web search or RAG.
@@ -504,8 +459,6 @@ def go(args):
# Compile
app = workflow.compile()
# Run
inputs = {
"question": args.query
@@ -519,10 +472,71 @@ def go(args):
pprint("\n---\n")
# Final generation
pprint(value["generation"])
print(value["generation"])
return {"response": value["generation"]}
def go_evaluation(args):
if args.evaluation_dataset_csv_path:
import pandas as pd
df = pd.read_csv(args.evaluation_dataset_csv_path)
dataset_name = os.path.basename(args.evaluation_dataset_csv_path).split('.')[0]
# df contains columns of question and answer
examples = df[[args.evaluation_dataset_column_question, args.evaluation_dataset_column_answer]].values.tolist()
inputs = [{"question": input_prompt} for input_prompt, _ in examples]
outputs = [{"answer": output_answer} for _, output_answer in examples]
# Programmatically create a dataset in LangSmith
client = Client()
try:
# Create a dataset
dataset = client.create_dataset(
dataset_name = dataset_name,
description = "An evaluation dataset in LangSmith."
)
# Add examples to the dataset
client.create_examples(inputs=inputs, outputs=outputs, dataset_id=dataset.id)
except langsmith.utils.LangSmithConflictError:
pass
args.ls_chat_model_evaluator = None if args.ls_chat_model_evaluator == 'None' else args.ls_chat_model_evaluator.split(',')
def target(inputs: dict) -> dict:
new_args = argparse.Namespace(**vars(args))
new_args.query = inputs["question"]
return go(new_args)
ls_evaluators = []
if args.ls_chat_model_evaluator:
for evaluator in args.ls_chat_model_evaluator:
if evaluator == 'moonshot':
ls_evaluators.append(moonshot_evaluator_correctness)
ls_evaluators.append(moonshot_evaluator_faithfulness)
elif evaluator == 'deepseek':
ls_evaluators.append(deepseek_evaluator_correctness)
ls_evaluators.append(deepseek_evaluator_faithfulness)
elif evaluator == 'gemini':
ls_evaluators.append(gemini_evaluator_correctness)
ls_evaluators.append(gemini_evaluator_faithfulness)
# After running the evaluation, a link will be provided to view the results in langsmith
experiment_results = client.evaluate(
target,
data = dataset_name,
evaluators = ls_evaluators,
experiment_prefix = "first-eval-in-langsmith",
max_concurrency = 1,
)
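
client.evaluate drives the loop: for each example in the LangSmith dataset it calls target() to get {"response": ...} and then hands that output, together with the stored ground-truth answer, to every selected evaluator. A hedged single-example dry run of that contract (hypothetical strings, executed inside go_evaluation where target and ls_evaluators exist):

example_inputs = {"question": "如何治疗乳腺癌?"}
reference_outputs = {"answer": "参考答案文本"}  # value from the ground-truth column

outputs = target(example_inputs)  # runs the adaptive RAG graph once
for evaluator in ls_evaluators:
    print(evaluator.__name__, evaluator(outputs, reference_outputs))
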
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Chain of Thought RAG")
parser = argparse.ArgumentParser(description="Adaptive AG")
parser.add_argument(
"--query",
@@ -531,6 +545,27 @@ if __name__ == "__main__":
required=True
)
parser.add_argument(
"--evaluation_dataset_csv_path",
type=str,
help="Path to the query evaluation dataset",
default=None,
)
parser.add_argument(
"--evaluation_dataset_column_question",
type=str,
help="Column name for the questions in the evaluation dataset",
default="question",
)
parser.add_argument(
"--evaluation_dataset_column_answer",
type=str,
help="Column name for the groundtruth answers in the evaluation dataset",
default="groundtruth",
)
parser.add_argument(
"--input_chromadb_artifact",
type=str,
@@ -552,6 +587,14 @@ if __name__ == "__main__":
help="Chat model provider"
)
parser.add_argument(
"--ls_chat_model_evaluator",
type=str,
help="list of Chat model providers for evaluation",
required=False,
default="None"
)
args = parser.parse_args()
go(args)
go_evaluation(args)

View File

@@ -14,10 +14,14 @@ from langchain_community.llms.moonshot import Moonshot
logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()
os.environ["GOOGLE_API_KEY"] = config("GOOGLE_API_KEY", cast=str)
os.environ["DEEPSEEK_API_KEY"] = config("DEEPSEEK_API_KEY", cast=str)
os.environ["MOONSHOT_API_KEY"] = config("MOONSHOT_API_KEY", cast=str)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
DEEKSEEK_API_KEY = config("DEEKSEEK_API_KEY", cast=str)
MOONSHOT_API_KEY = config("MOONSHOT_API_KEY", cast=str)
os.environ["LANGSMITH_API_KEY"] = config("LANGSMITH_API_KEY", cast=str)
os.environ["LANGSMITH_TRACING"] = config("LANGSMITH_TRACING", cast=str)
os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGSMITH_PROJECT"] = config("LANGSMITH_PROJECT", cast=str)
def go(args):
@@ -60,14 +64,12 @@ def go(args):
max_tokens=None,
timeout=None,
max_retries=2,
api_key=DEEKSEEK_API_KEY
)
elif args.chat_model_provider == "gemini":
# Initialize Gemini model
llm = ChatGoogleGenerativeAI(
model="gemini-1.5-flash",
google_api_key=GEMINI_API_KEY,
temperature=0,
max_retries=3
)
@@ -80,7 +82,6 @@ def go(args):
max_tokens=None,
timeout=None,
max_retries=2,
api_key=MOONSHOT_API_KEY
)
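
With the explicit api_key arguments removed, the chat models fall back to the environment variables exported at the top of the file (GOOGLE_API_KEY, DEEPSEEK_API_KEY, MOONSHOT_API_KEY); a minimal sketch of that pattern, assuming a .env file readable by python-decouple:

import os
from decouple import config
from langchain_deepseek import ChatDeepSeek

# export the key once; ChatDeepSeek then reads DEEPSEEK_API_KEY from the environment
os.environ["DEEPSEEK_API_KEY"] = config("DEEPSEEK_API_KEY", cast=str)
llm = ChatDeepSeek(model="deepseek-chat", temperature=0, max_retries=2)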

View File

@@ -4,4 +4,6 @@ from sentence_transformers import SentenceTransformer
EMBEDDING_MODEL = config("EMBEDDING_MODEL", cast=str, default="paraphrase-multilingual-mpnet-base-v2")
# Initialize embedding model
model = SentenceTransformer(EMBEDDING_MODEL)
model.save("./transformer_model/paraphrase-multilingual-mpnet-base-v2")
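
Saving the encoder under ./transformer_model/ (now ignored by the .gitignore rule above) lets later runs load it from disk instead of re-downloading; a small sketch of the local load, assuming the save path shown here:

from sentence_transformers import SentenceTransformer

# load the locally saved copy rather than pulling from the Hugging Face hub
model = SentenceTransformer("./transformer_model/paraphrase-multilingual-mpnet-base-v2")
embeddings = model.encode(["如何治疗乳腺癌?"])
print(embeddings.shape)  # (1, 768) for this MPNet variant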