Mirror of https://github.com/aimingmed/aimingmed-ai.git (synced 2026-02-08 00:03:15 +08:00)

Commit 86a2c1a055 (parent b6ca6ac677): update for now

The commit splits the former prompt_engineering config section into separate rag, testing, and evaluation sections, updates the pipeline driver and the rag_adaptive_evaluation MLproject entry point to match, and replaces the hard-coded LangSmith sample dataset in run.py with one built from the evaluation CSV, with configurable evaluator chat models.
Pipeline config (Hydra YAML):

```diff
@@ -9,8 +9,15 @@ etl:
   path_document_folder: "../../../../data"
   run_id_documents: None
   embedding_model: paraphrase-multilingual-mpnet-base-v2
-prompt_engineering:
+rag:
   run_id_chromadb: None
   chat_model_provider: gemini
+testing:
   query: "如何治疗乳腺癌?"
-  query_evaluation_dataset_csv_path: "../../../../data/qa_datasets.csv"
+evaluation:
+  evaluation_dataset_csv_path: "../../../../data/qa_datasets.csv"
+  evaluation_dataset_column_question: question
+  evaluation_dataset_column_answer: answer
+  ls_chat_model_provider:
+    - gemini
+    - moonshot
```
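After this change, pipeline code reads the renamed sections rag, testing, and evaluation instead of prompt_engineering; the default test query is Chinese for "How should breast cancer be treated?". A minimal sketch of how the new keys resolve, assuming OmegaConf (which Hydra uses underneath) and the YAML saved locally as config.yaml; the file path is illustrative:

```python
from omegaconf import OmegaConf

# Illustrative only: load the pipeline config the way Hydra would expose it.
config = OmegaConf.load("config.yaml")  # hypothetical local path

# Keys formerly under prompt_engineering now live in three sections:
print(config["rag"]["chat_model_provider"])            # "gemini"
print(config["testing"]["query"])                      # the default test query
print(config["evaluation"]["ls_chat_model_provider"])  # ["gemini", "moonshot"]

# Note: run_id_chromadb holds the YAML *string* "None", not Python None,
# which is why the pipeline driver compares against the literal string:
assert config["rag"]["run_id_chromadb"] == "None"
```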
Pipeline driver (Hydra + MLflow):

```diff
@@ -104,7 +104,7 @@ def go(config: DictConfig):
         )
     if "rag_cot_evaluation" in active_steps:

-        if config["prompt_engineering"]["run_id_chromadb"] == "None":
+        if config["rag"]["run_id_chromadb"] == "None":
             # Look for run_id that has artifact logged as documents
             run_id = None
             client = mlflow.tracking.MlflowClient()
```
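The loop that actually locates that run sits between this hunk and the next and is not shown in the diff. A plausible sketch of such a lookup, assuming the standard MLflow tracking API; the experiment name and the "chromadb" artifact path are assumptions, not taken from the repo:

```python
import mlflow

# Illustrative sketch: find the most recent run that logged a chromadb artifact.
client = mlflow.tracking.MlflowClient()
experiment = client.get_experiment_by_name("Default")  # assumed experiment name
run_id = None
for run in client.search_runs([experiment.experiment_id],
                              order_by=["attributes.start_time DESC"]):
    artifact_paths = [a.path for a in client.list_artifacts(run.info.run_id)]
    if "chromadb" in artifact_paths:
        run_id = run.info.run_id
        break
if run_id is None:
    raise ValueError("No run_id found with artifact logged as documents")
```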
```diff
@@ -119,22 +119,22 @@ def go(config: DictConfig):
             if run_id is None:
                 raise ValueError("No run_id found with artifact logged as documents")
         else:
-            run_id = config["prompt_engineering"]["run_id_chromadb"]
+            run_id = config["rag"]["run_id_chromadb"]

         _ = mlflow.run(
             os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot_evaluation"),
             "main",
             parameters={
-                "query": config["prompt_engineering"]["query"],
+                "query": config["testing"]["query"],
                 "input_chromadb_artifact": f'runs:/{run_id}/chromadb/chroma_db.zip',
                 "embedding_model": config["etl"]["embedding_model"],
-                "chat_model_provider": config["prompt_engineering"]["chat_model_provider"]
+                "chat_model_provider": config["rag"]["chat_model_provider"]
             },
         )

     if "rag_adaptive_evaluation" in active_steps:

-        if config["prompt_engineering"]["run_id_chromadb"] == "None":
+        if config["rag"]["run_id_chromadb"] == "None":
             # Look for run_id that has artifact logged as documents
             run_id = None
             client = mlflow.tracking.MlflowClient()
```
```diff
@@ -149,17 +149,20 @@ def go(config: DictConfig):
             if run_id is None:
                 raise ValueError("No run_id found with artifact logged as documents")
         else:
-            run_id = config["prompt_engineering"]["run_id_chromadb"]
+            run_id = config["rag"]["run_id_chromadb"]

         _ = mlflow.run(
             os.path.join(hydra.utils.get_original_cwd(), "src", "rag_adaptive_evaluation"),
             "main",
             parameters={
-                "query": config["prompt_engineering"]["query"],
-                "query_evaluation_dataset_csv_path": config["prompt_engineering"]["query_evaluation_dataset_csv_path"],
+                "query": config["testing"]["query"],
+                "evaluation_dataset_csv_path": config["evaluation"]["evaluation_dataset_csv_path"],
+                "evaluation_dataset_column_question": config["evaluation"]["evaluation_dataset_column_question"],
+                "evaluation_dataset_column_answer": config["evaluation"]["evaluation_dataset_column_answer"],
                 "input_chromadb_artifact": f'runs:/{run_id}/chromadb/chroma_db.zip',
                 "embedding_model": config["etl"]["embedding_model"],
-                "chat_model_provider": config["prompt_engineering"]["chat_model_provider"]
+                "chat_model_provider": config["rag"]["chat_model_provider"],
+                "ls_chat_model_evaluator": ','.join(config["evaluation"]["ls_chat_model_provider"]) if config["evaluation"]["ls_chat_model_provider"] is not None else 'None',
             },
         )

```
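MLflow project parameters are passed as strings, so the driver flattens the evaluator list into a comma-joined string with the literal 'None' as an "unset" sentinel, and run.py reverses the encoding. A minimal self-contained sketch of that round trip (the function names are illustrative):

```python
def encode_evaluators(providers):
    # Driver side: list -> comma-joined string, 'None' when unset.
    return ','.join(providers) if providers is not None else 'None'

def decode_evaluators(raw):
    # run.py side: string -> list, or None for the sentinel.
    return None if raw == 'None' else raw.split(',')

assert encode_evaluators(["gemini", "moonshot"]) == "gemini,moonshot"
assert decode_evaluators("gemini,moonshot") == ["gemini", "moonshot"]
assert decode_evaluators(encode_evaluators(None)) is None
```

One caveat: an empty list encodes to the empty string, which decodes back to [''] rather than [], so the sentinel only covers the None case.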
```diff
@@ -169,10 +172,10 @@ def go(config: DictConfig):
             os.path.join(hydra.utils.get_original_cwd(), "components", "test_rag_cot"),
             "main",
             parameters={
-                "query": config["prompt_engineering"]["query"],
+                "query": config["testing"]["query"],
                 "input_chromadb_local": os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot_evaluation", "chroma_db"),
                 "embedding_model": config["etl"]["embedding_model"],
-                "chat_model_provider": config["prompt_engineering"]["chat_model_provider"]
+                "chat_model_provider": config["rag"]["chat_model_provider"]
             },
         )

```
MLproject of the rag_adaptive_evaluation step:

```diff
@@ -9,10 +9,18 @@ entry_points:
         description: Query to run
         type: string

-      query_evaluation_dataset_csv_path:
+      evaluation_dataset_csv_path:
         description: query evaluation dataset csv path
         type: string

+      evaluation_dataset_column_question:
+        description: query evaluation dataset column question
+        type: string
+
+      evaluation_dataset_column_answer:
+        description: query evaluation dataset column groundtruth
+        type: string
+
       input_chromadb_artifact:
         description: Fully-qualified name for the input artifact
         type: string
```
```diff
@@ -24,10 +32,18 @@ entry_points:
       chat_model_provider:
         description: Fully-qualified name for the chat model provider
         type: string

+      ls_chat_model_evaluator:
+        description: list of chat model providers for evaluation
+        type: string
+
     command: >-
       python run.py --query {query} \
-        --query_evaluation_dataset_csv_path {query_evaluation_dataset_csv_path} \
+        --evaluation_dataset_csv_path {evaluation_dataset_csv_path} \
+        --evaluation_dataset_column_question {evaluation_dataset_column_question} \
+        --evaluation_dataset_column_answer {evaluation_dataset_column_answer} \
         --input_chromadb_artifact {input_chromadb_artifact} \
         --embedding_model {embedding_model} \
-        --chat_model_provider {chat_model_provider}
+        --chat_model_provider {chat_model_provider} \
+        --ls_chat_model_evaluator {ls_chat_model_evaluator}
```
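Each parameter added here has to stay consistent in three places: the parameters block, the command template, and run.py's argparse flags. A hypothetical guard one could keep in a test suite; the file path, the expected set, and the helper name are all assumptions, not code from the repo:

```python
import yaml  # PyYAML, assumed available

EXPECTED_PARAMS = {
    "query",
    "evaluation_dataset_csv_path",
    "evaluation_dataset_column_question",
    "evaluation_dataset_column_answer",
    "input_chromadb_artifact",
    "embedding_model",
    "chat_model_provider",
    "ls_chat_model_evaluator",
}

def check_mlproject(path="src/rag_adaptive_evaluation/MLproject"):
    # Compare declared parameters against the expected set, and make sure
    # every declared parameter is actually interpolated into the command.
    with open(path) as fh:
        spec = yaml.safe_load(fh)
    main_ep = spec["entry_points"]["main"]
    declared = set(main_ep["parameters"])
    command = main_ep["command"]
    missing = {p for p in declared if "{" + p + "}" not in command}
    return declared == EXPECTED_PARAMS and not missing
```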
run.py of the rag_adaptive_evaluation step:

```diff
@@ -474,61 +474,60 @@ def go(args):
         return {"response": value["generation"]}

 def go_evaluation(args):
-    if args.query_evaluation_dataset_csv_path:
-        # import pandas as pd
-        # from tqdm import tqdm
+    if args.evaluation_dataset_csv_path:
+        import pandas as pd

-        # df = pd.read_csv(args.query_evaluation_dataset_csv_path)
+        df = pd.read_csv(args.evaluation_dataset_csv_path)
+        dataset_name = os.path.basename(args.evaluation_dataset_csv_path).split('.')[0]
+
+        # df contains columns of question and answer
+        examples = df[[args.evaluation_dataset_column_question, args.evaluation_dataset_column_answer]].values.tolist()
+        inputs = [{"question": input_prompt} for input_prompt, _ in examples]
+        outputs = [{"answer": output_answer} for _, output_answer in examples]
+
+        # Programmatically create a dataset in LangSmith
         client = Client()
-        # # Create inputs and reference outputs
-        # examples = [
-        #     (
-        #         "Which country is Mount Kilimanjaro located in?",
-        #         "Mount Kilimanjaro is located in Tanzania.",
-        #     ),
-        #     (
-        #         "What is Earth's lowest point?",
-        #         "Earth's lowest point is The Dead Sea.",
-        #     ),
-        # ]
-
-        # inputs = [{"question": input_prompt} for input_prompt, _ in examples]
-        # outputs = [{"answer": output_answer} for _, output_answer in examples]
-
-        # # Programmatically create a dataset in LangSmith
-        # dataset = client.create_dataset(
-        #     dataset_name = "Sample dataset",
-        #     description = "A sample dataset in LangSmith."
-        # )
-
-        # # Add examples to the dataset
-        # client.create_examples(inputs=inputs, outputs=outputs, dataset_id=dataset.id)
+        dataset = client.create_dataset(
+            dataset_name = dataset_name,
+            description = "A sample dataset in LangSmith."
+        )
+
+        # Add examples to the dataset
+        client.create_examples(inputs=inputs, outputs=outputs, dataset_id=dataset.id)
+
+        args.ls_chat_model_evaluator = None if args.ls_chat_model_evaluator == 'None' else args.ls_chat_model_evaluator.split(',')

         def target(inputs: dict) -> dict:
             new_args = argparse.Namespace(**vars(args))
             new_args.query = inputs["question"]
             return go(new_args)

+        ls_evaluators = []
+        if args.ls_chat_model_evaluator:
+            for evaluator in args.ls_chat_model_evaluator:
+                if evaluator == 'moonshot':
+                    ls_evaluators.append(moonshot_evaluator_correctness)
+                    ls_evaluators.append(moonshot_evaluator_faithfulness)
+                elif evaluator == 'deepseek':
+                    ls_evaluators.append(deepseek_evaluator_correctness)
+                    ls_evaluators.append(deepseek_evaluator_faithfulness)
+                elif evaluator == 'gemini':
+                    ls_evaluators.append(gemini_evaluator_correctness)
+                    ls_evaluators.append(gemini_evaluator_faithfulness)

         # After running the evaluation, a link will be provided to view the results in langsmith
         experiment_results = client.evaluate(
             target,
             data = "Sample dataset",
-            evaluators = [
-                moonshot_evaluator_correctness,
-                deepseek_evaluator_correctness,
-                gemini_evaluator_correctness,
-                gemini_evaluator_faithfulness,
-                deepseek_evaluator_faithfulness,
-                moonshot_evaluator_faithfulness
-                # can add multiple evaluators here
-            ],
+            evaluators = ls_evaluators,
             experiment_prefix = "first-eval-in-langsmith",
             max_concurrency = 1,
         )


 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Adaptive AG")
```
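One inconsistency survives this hunk: the dataset is now created under dataset_name (derived from the CSV filename, so qa_datasets for the default config), but client.evaluate still receives the old hard-coded data = "Sample dataset", so the experiment would target a dataset this script no longer creates. A sketch of the fix, using the names already defined in go_evaluation above; the has_dataset guard is a suggested hardening (a second run against the same CSV would otherwise fail on create_dataset) and is not part of the commit:

```python
# Reuse the dataset if it already exists; create and populate it otherwise.
if not client.has_dataset(dataset_name=dataset_name):
    dataset = client.create_dataset(
        dataset_name=dataset_name,
        description="QA evaluation dataset built from the evaluation CSV.",
    )
    client.create_examples(inputs=inputs, outputs=outputs, dataset_id=dataset.id)

experiment_results = client.evaluate(
    target,
    data=dataset_name,  # was the stale "Sample dataset"
    evaluators=ls_evaluators,
    experiment_prefix="first-eval-in-langsmith",
    max_concurrency=1,
)
```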
```diff
@@ -541,12 +540,26 @@ if __name__ == "__main__":
     )

     parser.add_argument(
-        "--query_evaluation_dataset_csv_path",
+        "--evaluation_dataset_csv_path",
         type=str,
         help="Path to the query evaluation dataset",
         default=None,
     )

+    parser.add_argument(
+        "--evaluation_dataset_column_question",
+        type=str,
+        help="Column name for the questions in the evaluation dataset",
+        default="question",
+    )
+
+    parser.add_argument(
+        "--evaluation_dataset_column_answer",
+        type=str,
+        help="Column name for the groundtruth answers in the evaluation dataset",
+        default="groundtruth",
+    )
+
     parser.add_argument(
         "--input_chromadb_artifact",
         type=str,
```
```diff
@@ -568,7 +581,14 @@ if __name__ == "__main__":
         help="Chat model provider"
     )

+    parser.add_argument(
+        "--ls_chat_model_evaluator",
+        type=str,
+        help="list of Chat model providers for evaluation",
+        required=False,
+        default="None"
+    )
+
     args = parser.parse_args()

-    # go(args)
     go_evaluation(args)
```
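Note the defaults: the answer column defaults to "groundtruth" here, while the config and MLproject pass "answer" explicitly, so the argparse default only matters when run.py is invoked directly. A self-contained mirror of the new flags, showing what an argument-less invocation yields:

```python
import argparse

# Minimal mirror of the flags added in this commit (not the full parser).
parser = argparse.ArgumentParser(description="Adaptive AG")
parser.add_argument("--evaluation_dataset_csv_path", type=str, default=None)
parser.add_argument("--evaluation_dataset_column_question", type=str, default="question")
parser.add_argument("--evaluation_dataset_column_answer", type=str, default="groundtruth")
parser.add_argument("--ls_chat_model_evaluator", type=str, required=False, default="None")

args = parser.parse_args([])
assert args.evaluation_dataset_column_answer == "groundtruth"  # config passes "answer" instead
assert args.ls_chat_model_evaluator == "None"  # string sentinel decoded in go_evaluation
```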