This commit is contained in:
leehk 2025-03-05 15:07:05 +08:00
parent 7399b56fa1
commit 3a4d59c0e3
16 changed files with 592 additions and 546 deletions

View File

@@ -10,5 +10,4 @@ build_dependencies:
# Dependencies required to run the project.
dependencies:
- mlflow==2.8.1
- wandb==0.16.0
- git+https://github.com/udacity/nd0821-c2-build-model-workflow-starter.git#egg=wandb-utils&subdirectory=components

View File

@@ -5,33 +5,33 @@ This script downloads a URL to a local destination
import argparse
import logging
import os
import wandb
from wandb_utils.log_artifact import log_artifact
import mlflow
import shutil
logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()
def go(args):
zip_path = os.path.join(args.path_document_folder, f"{args.document_folder}.zip")
shutil.make_archive(zip_path.replace('.zip', ''), 'zip', args.path_document_folder, args.document_folder)
run = wandb.init(job_type="get_documents", entity='aimingmed')
run.config.update(args)
with mlflow.start_run(experiment_id=mlflow.get_experiment_by_name("development").experiment_id) as run:
logger.info(f"Uploading {args.artifact_name} to Weights & Biases")
log_artifact(
args.artifact_name,
args.artifact_type,
args.artifact_description,
zip_path,
run,
)
existing_params = mlflow.get_run(mlflow.active_run().info.run_id).data.params
if 'artifact_description' not in existing_params:
mlflow.log_param('artifact_description', args.artifact_description)
if 'artifact_type' not in existing_params:
mlflow.log_param('artifact_type', args.artifact_type)
# Log parameters to MLflow
mlflow.log_params({
"input_artifact": args.artifact_name,
})
logger.info(f"Uploading {args.artifact_name} to MLFlow")
mlflow.log_artifact(zip_path, args.artifact_name)
if __name__ == "__main__":
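Note: the second argument to mlflow.log_artifact() is the artifact_path directory, so the zip lands at <artifact_name>/<zip name> inside the run; that is why main.py later addresses it as runs:/<run_id>/documents/documents.zip. A minimal retrieval sketch for a downstream step (the helper name and the known-run_id assumption are mine):

import mlflow

def download_documents_zip(run_id: str) -> str:
    # Fetch the zip logged by get_documents; returns a local file path.
    return mlflow.artifacts.download_artifacts(
        artifact_uri=f"runs:/{run_id}/documents/documents.zip"
    )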

View File

@@ -7,8 +7,10 @@ etl:
input_artifact_name: documents
document_folder: documents
path_document_folder: "../../../../data"
run_id_documents: None
embedding_model: paraphrase-multilingual-mpnet-base-v2
prompt_engineering:
chat_model_provider: kimi
run_id_chromadb: None
chat_model_provider: moonshot
query: "怎么治疗有kras的肺癌?"

View File

@@ -9,9 +9,9 @@ from decouple import config
_steps = [
"get_documents",
"etl_chromdb_pdf",
"etl_chromdb_scanned_pdf", # the performance for scanned pdf may not be good
"chain_of_thought"
"etl_chromadb_pdf",
"etl_chromadb_scanned_pdf", # the performance for scanned pdf may not be good
"rag_cot",
]
@@ -19,9 +19,8 @@ _steps = [
@hydra.main(config_name='config')
def go(config: DictConfig):
# Setup the wandb experiment. All runs will be grouped under this name
os.environ["WANDB_PROJECT"] = config["main"]["project_name"]
os.environ["WANDB_RUN_GROUP"] = config["main"]["experiment_name"]
# Setup the MLflow experiment. All runs will be grouped under this name
mlflow.set_experiment(config["main"]["experiment_name"])
# Steps to execute
steps_par = config['main']['steps']
@@ -43,37 +42,92 @@ def go(config: DictConfig):
"artifact_description": "Raw file as downloaded"
},
)
if "etl_chromdb_pdf" in active_steps:
if "etl_chromadb_pdf" in active_steps:
if config["etl"]["run_id_documents"] == "None":
# Look for run_id that has artifact logged as documents
run_id = None
client = mlflow.tracking.MlflowClient()
for run in client.search_runs(experiment_ids=[client.get_experiment_by_name(config["main"]["experiment_name"]).experiment_id]):
for artifact in client.list_artifacts(run.info.run_id):
if artifact.path == "documents":
run_id = run.info.run_id
break
if run_id:
break
if run_id is None:
raise ValueError("No run_id found with artifact logged as documents")
else:
run_id = config["etl"]["run_id_documents"]
_ = mlflow.run(
os.path.join(hydra.utils.get_original_cwd(), "src", "etl_chromdb_pdf"),
os.path.join(hydra.utils.get_original_cwd(), "src", "etl_chromadb_pdf"),
"main",
parameters={
"input_artifact": f'{config["etl"]["input_artifact_name"]}:latest',
"output_artifact": "chromdb.zip",
"output_type": "chromdb",
"input_artifact": f'runs:/{run_id}/documents/documents.zip',
"output_artifact": "chromadb",
"output_type": "chromadb",
"output_description": "Documents in pdf to be read and stored in chromdb",
"embedding_model": config["etl"]["embedding_model"]
},
)
if "etl_chromdb_scanned_pdf" in active_steps:
if "etl_chromadb_scanned_pdf" in active_steps:
if config["etl"]["run_id_documents"] == "None":
# Look for run_id that has artifact logged as documents
run_id = None
client = mlflow.tracking.MlflowClient()
for run in client.search_runs(experiment_ids=[client.get_experiment_by_name(config["main"]["experiment_name"]).experiment_id]):
for artifact in client.list_artifacts(run.info.run_id):
if artifact.path == "documents":
run_id = run.info.run_id
break
if run_id:
break
if run_id is None:
raise ValueError("No run_id found with artifact logged as documents")
else:
run_id = config["etl"]["run_id_documents"]
_ = mlflow.run(
os.path.join(hydra.utils.get_original_cwd(), "src", "etl_chromdb_scanned_pdf"),
os.path.join(hydra.utils.get_original_cwd(), "src", "etl_chromadb_scanned_pdf"),
"main",
parameters={
"input_artifact": f'{config["etl"]["input_artifact_name"]}:latest',
"output_artifact": "chromdb.zip",
"output_type": "chromdb",
"input_artifact": f'runs:/{run_id}/documents/documents.zip',
"output_artifact": "chromadb",
"output_type": "chromadb",
"output_description": "Scanned Documents in pdf to be read and stored in chromdb",
"embedding_model": config["etl"]["embedding_model"]
},
)
if "chain_of_thought" in active_steps:
if "rag_cot" in active_steps:
if config["prompt_engineering"]["run_id_chromadb"] == "None":
# Look for run_id that has artifact logged as chromadb
run_id = None
client = mlflow.tracking.MlflowClient()
for run in client.search_runs(experiment_ids=[client.get_experiment_by_name(config["main"]["experiment_name"]).experiment_id]):
for artifact in client.list_artifacts(run.info.run_id):
if artifact.path == "chromadb":
run_id = run.info.run_id
break
if run_id:
break
if run_id is None:
raise ValueError("No run_id found with artifact logged as chromadb")
else:
run_id = config["prompt_engineering"]["run_id_chromadb"]
_ = mlflow.run(
os.path.join(hydra.utils.get_original_cwd(), "src", "chain_of_thought"),
os.path.join(hydra.utils.get_original_cwd(), "src", "rag_cot"),
"main",
parameters={
"query": config["prompt_engineering"]["query"],
"input_chromadb_artifact": "chromdb.zip:latest",
"input_chromadb_artifact": f'runs:/{run_id}/chromadb/chroma_db.zip',
"embedding_model": config["etl"]["embedding_model"],
"chat_model_provider": config["prompt_engineering"]["chat_model_provider"]
},
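The artifact-lookup loop above is pasted three times with only the artifact name changing; a sketch of a helper that could replace it (the function name is mine, the MlflowClient calls are the same ones used in the hunk):

import mlflow

def find_run_with_artifact(experiment_name: str, artifact_path: str) -> str:
    # Return the first run in the experiment that logged an artifact
    # at `artifact_path`, or raise if none exists.
    client = mlflow.tracking.MlflowClient()
    experiment_id = client.get_experiment_by_name(experiment_name).experiment_id
    for run in client.search_runs(experiment_ids=[experiment_id]):
        for artifact in client.list_artifacts(run.info.run_id):
            if artifact.path == artifact_path:
                return run.info.run_id
    raise ValueError(f"No run found with artifact logged as {artifact_path}")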

View File

@@ -1,144 +0,0 @@
import os
import logging
import argparse
import wandb
import chromadb
import shutil
from decouple import config
from langchain.prompts import PromptTemplate
from sentence_transformers import SentenceTransformer
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_deepseek import ChatDeepSeek
from langchain_community.llms.moonshot import Moonshot
logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
DEEKSEEK_API_KEY = config("DEEKSEEK_API_KEY", cast=str)
MOONSHOT_API_KEY = config("MOONSHOT_API_KEY", cast=str)
def go(args):
run = wandb.init(job_type="chain_of_thought", entity='aimingmed')
run.config.update(args)
logger.info("Downloading chromadb artifact")
artifact_chromadb_local_path = run.use_artifact(args.input_chromadb_artifact).file()
# unzip the artifact
logger.info("Unzipping the artifact")
shutil.unpack_archive(artifact_chromadb_local_path, "chroma_db")
# Load data from ChromaDB
db_folder = "chroma_db"
db_path = os.path.join(os.getcwd(), db_folder)
chroma_client = chromadb.PersistentClient(path=db_path)
collection_name = "rag_experiment"
collection = chroma_client.get_collection(name=collection_name)
# Formulate a question
question = args.query
if args.chat_model_provider == "deepseek":
# Initialize DeepSeek model
llm = ChatDeepSeek(
model="deepseek-chat",
temperature=0,
max_tokens=None,
timeout=None,
max_retries=2,
api_key=DEEKSEEK_API_KEY
)
elif args.chat_model_provider == "gemini":
# Initialize Gemini model
llm = ChatGoogleGenerativeAI(
model="gemini-1.5-flash",
google_api_key=GEMINI_API_KEY,
temperature=0,
max_retries=3
)
elif args.chat_model_provider == "moonshot":
# Initialize Moonshot model
llm = Moonshot(
model="moonshot-v1-128k",
temperature=0,
max_tokens=None,
timeout=None,
max_retries=2,
api_key=MOONSHOT_API_KEY
)
# Chain of Thought Prompt
cot_template = """Let's think step by step.
Given the following document in text: {documents_text}
Question: {question}
Reply with language that is similar to the language used with asked question.
"""
cot_prompt = PromptTemplate(template=cot_template, input_variables=["documents_text", "question"])
cot_chain = cot_prompt | llm
# Initialize embedding model (do this ONCE)
model = SentenceTransformer(args.embedding_model)
# Query (prompt)
query_embedding = model.encode(question) # Embed the query using the SAME model
# Search ChromaDB
documents_text = collection.query(query_embeddings=[query_embedding], n_results=5)
# Generate chain of thought
cot_output = cot_chain.invoke({"documents_text": documents_text, "question": question})
print("Chain of Thought: ", cot_output)
# Answer Prompt
answer_template = """Given the chain of thought: {cot}
Provide a concise answer to the question: {question}
Provide the answer with language that is similar to the question asked.
"""
answer_prompt = PromptTemplate(template=answer_template, input_variables=["cot", "question"])
answer_chain = answer_prompt | llm
# Generate answer
answer_output = answer_chain.invoke({"cot": cot_output, "question": question})
print("Answer: ", answer_output)
run.finish()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Chain of Thought RAG")
parser.add_argument(
"--query",
type=str,
help="Question to ask the model",
required=True
)
parser.add_argument(
"--input_chromadb_artifact",
type=str,
help="Fully-qualified name for the chromadb artifact",
required=True
)
parser.add_argument(
"--embedding_model",
type=str,
default="paraphrase-multilingual-mpnet-base-v2",
help="Sentence Transformer model name"
)
parser.add_argument(
"--chat_model_provider",
type=str,
default="gemini",
help="Chat model provider"
)
args = parser.parse_args()
go(args)

View File

@@ -1,4 +1,4 @@
name: etl_chromdb_pdf
name: etl_chromadb_pdf
python_env: python_env.yml
entry_points:

View File

@@ -12,5 +12,4 @@ build_dependencies:
- sentence_transformers
# Dependencies required to run the project.
dependencies:
- mlflow==2.8.1
- wandb==0.16.0
- mlflow==2.8.1

View File

@@ -0,0 +1,178 @@
#!/usr/bin/env python
"""
Download from W&B the raw dataset and apply some basic data cleaning, exporting the result to a new artifact
"""
import argparse
import logging
import os
import mlflow
import shutil
import chromadb
import io
import zipfile
from pdfminer.converter import TextConverter
from pdfminer.pdfinterp import PDFPageInterpreter
from pdfminer.pdfinterp import PDFResourceManager
from pdfminer.pdfpage import PDFPage
from langchain.schema import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def extract_chinese_text_from_pdf(pdf_path):
"""
Extracts Chinese text from a PDF file.
Args:
pdf_path (str): The path to the PDF file.
Returns:
str: The extracted Chinese text, or None if an error occurs.
"""
resource_manager = PDFResourceManager()
fake_file_handle = io.StringIO()
converter = TextConverter(resource_manager, fake_file_handle)
page_interpreter = PDFPageInterpreter(resource_manager, converter)
try:
with open(pdf_path, 'rb') as fh:
for page in PDFPage.get_pages(fh, caching=True, check_extractable=True):
page_interpreter.process_page(page)
text = fake_file_handle.getvalue()
return text
except FileNotFoundError:
print(f"Error: PDF file not found at {pdf_path}")
return None
except Exception as e:
print(f"An error occurred: {e}")
return None
finally:
converter.close()
fake_file_handle.close()
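# pdfminer.six also exposes a one-call high-level API that replaces the
# resource-manager/converter boilerplate above; an equivalent sketch,
# assuming text-based (non-scanned) PDFs behave the same:
def extract_text_simple(pdf_path):
    from pdfminer.high_level import extract_text  # part of pdfminer.six
    return extract_text(pdf_path)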
def go(args):
"""
Run the ETL for ChromaDB with PDF documents
"""
# Start an MLflow run
with mlflow.start_run(experiment_id=mlflow.get_experiment_by_name("development").experiment_id, run_name="etl_chromadb_pdf"):
existing_params = mlflow.get_run(mlflow.active_run().info.run_id).data.params
if 'output_description' not in existing_params:
mlflow.log_param('output_description', args.output_description)
# Log parameters to MLflow
mlflow.log_params({
"input_artifact": args.input_artifact,
"output_artifact": args.output_artifact,
"output_type": args.output_type,
"embedding_model": args.embedding_model
})
# Initialize embedding model (do this ONCE)
model_embedding = SentenceTransformer(args.embedding_model) # Or a multilingual model
# Create database, delete the database directory if it exists
db_folder = "chroma_db"
db_path = os.path.join(os.getcwd(), db_folder)
if os.path.exists(db_path):
shutil.rmtree(db_path)
os.makedirs(db_path)
chroma_client = chromadb.PersistentClient(path=db_path)
collection_name = "rag_experiment"
db = chroma_client.create_collection(name=collection_name)
logger.info("Downloading artifact")
artifact_local_path = mlflow.artifacts.download_artifacts(artifact_uri=args.input_artifact)
logger.info("Reading data")
# unzip the downloaded artifact
with zipfile.ZipFile(artifact_local_path, 'r') as zip_ref:
zip_ref.extractall(".")
# show the unzipped folder
documents_folder = os.path.splitext(os.path.basename(artifact_local_path))[0]
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
for root, _dir, files in os.walk(f"./{documents_folder}"):
for file in files:
if file.endswith(".pdf"):
read_text = extract_chinese_text_from_pdf(os.path.join(root, file))
document = Document(page_content=read_text)
all_splits = text_splitter.split_documents([document])
for i, split in enumerate(all_splits):
db.add(documents=[split.page_content],
metadatas=[{"filename": file}],
ids=[f'{file[:-4]}-{str(i)}'],
embeddings=[model_embedding.encode(split.page_content)]
)
logger.info("Logging artifact with mlflow")
shutil.make_archive(db_path, 'zip', db_path)
mlflow.log_artifact(db_path + '.zip', args.output_artifact)
# clean up
os.remove(db_path + '.zip')
shutil.rmtree(db_path)
shutil.rmtree(documents_folder)
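# Note: db.add() above is called once per chunk; Chroma accepts list
# batches, so each document's chunks could be inserted in one call.
# A sketch under that assumption (the helper name is mine):
def add_document_batch(db, model_embedding, filename, splits):
    texts = [s.page_content for s in splits]
    db.add(
        documents=texts,
        metadatas=[{"filename": filename}] * len(texts),
        ids=[f"{filename[:-4]}-{i}" for i in range(len(texts))],
        embeddings=model_embedding.encode(texts).tolist(),
    )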
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="A very basic data cleaning")
parser.add_argument(
"--input_artifact",
type=str,
help="Fully-qualified name for the input artifact",
required=True
)
parser.add_argument(
"--output_artifact",
type=str,
help="Name for the output artifact",
required=True
)
parser.add_argument(
"--output_type",
type=str,
help="Type for the artifact output",
required=True
)
parser.add_argument(
"--output_description",
type=str,
help="Description for the artifact",
required=True
)
parser.add_argument(
"--embedding_model",
type=str,
default="paraphrase-multilingual-mpnet-base-v2",
help="Sentence Transformer model name"
)
args = parser.parse_args()
go(args)
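For testing this step outside main.py, the MLflow Projects API the pipeline already uses can invoke it directly; a sketch (the run-id placeholder is mine, parameters mirror the main.py hunk above):

import mlflow

submitted = mlflow.run(
    "src/etl_chromadb_pdf",  # project directory containing the MLproject
    "main",
    parameters={
        "input_artifact": "runs:/<run_id>/documents/documents.zip",
        "output_artifact": "chromadb",
        "output_type": "chromadb",
        "output_description": "PDF documents to be read and stored in ChromaDB",
        "embedding_model": "paraphrase-multilingual-mpnet-base-v2",
    },
)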

View File

@@ -1,4 +1,4 @@
name: etl_chromdb_scanned_pdf
name: etl_chromadb_scanned_pdf
python_env: python_env.yml
entry_points:

View File

@@ -14,4 +14,3 @@ build_dependencies:
# Dependencies required to run the project.
dependencies:
- mlflow==2.8.1
- wandb==0.16.0

View File

@@ -0,0 +1,162 @@
#!/usr/bin/env python
"""
Download from W&B the raw dataset and apply some basic data cleaning, exporting the result to a new artifact
"""
import argparse
import logging
import os
import mlflow
import shutil
import chromadb
import zipfile
import pytesseract as pt
from pdf2image import convert_from_path
from langchain.schema import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def extract_text_from_pdf_ocr(pdf_path):
try:
images = convert_from_path(pdf_path) # Convert PDF pages to images
extracted_text = ""
for image in images:
text = pt.image_to_string(image, lang="chi_sim+eng") # chi_sim for Simplified Chinese, chi_tra for Traditional
extracted_text += text + "\n"
return extracted_text
except ImportError:
print("Error: pdf2image or pytesseract not installed. Please install them: pip install pdf2image pytesseract")
return ""
except Exception as e:
print(f"OCR failed: {e}")
return ""
def go(args):
"""
Run the ETL for ChromaDB with scanned PDF documents (via OCR)
"""
# Start an MLflow run
with mlflow.start_run(experiment_id=mlflow.get_experiment_by_name("development").experiment_id, run_name="etl_chromadb_scanned_pdf"):
existing_params = mlflow.get_run(mlflow.active_run().info.run_id).data.params
if 'output_description' not in existing_params:
mlflow.log_param('output_description', args.output_description)
# Log parameters to MLflow
mlflow.log_params({
"input_artifact": args.input_artifact,
"output_artifact": args.output_artifact,
"output_type": args.output_type,
"embedding_model": args.embedding_model
})
# Initialize embedding model
model_embedding = SentenceTransformer(args.embedding_model) # Or a multilingual model
# Create database, delete the database directory if it exists
db_folder = "chroma_db"
db_path = os.path.join(os.getcwd(), db_folder)
if os.path.exists(db_path):
shutil.rmtree(db_path)
os.makedirs(db_path)
chroma_client = chromadb.PersistentClient(path=db_path)
collection_name = "rag_experiment"
db = chroma_client.create_collection(name=collection_name)
logger.info("Downloading artifact")
artifact_local_path = mlflow.artifacts.download_artifacts(artifact_uri=args.input_artifact)
logger.info("Reading data")
# unzip the downloaded artifact
with zipfile.ZipFile(artifact_local_path, 'r') as zip_ref:
zip_ref.extractall(".")
# show the unzipped folder
documents_folder = os.path.splitext(os.path.basename(artifact_local_path))[0]
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
for root, _dir, files in os.walk(f"./{documents_folder}"):
for file in files:
if file.endswith(".pdf"):
read_text = extract_text_from_pdf_ocr(os.path.join(root, file))
document = Document(page_content=read_text)
all_splits = text_splitter.split_documents([document])
for i, split in enumerate(all_splits):
db.add(documents=[split.page_content],
metadatas=[{"filename": file}],
ids=[f'{file[:-4]}-{str(i)}'],
embeddings=[model_embedding.encode(split.page_content)]
)
logger.info("Uploading artifact to MLFlow")
shutil.make_archive(db_path, 'zip', db_path)
mlflow.log_artifact(db_path + '.zip', args.output_artifact)
# clean up
os.remove(db_path + '.zip')
shutil.rmtree(db_path)
shutil.rmtree(documents_folder)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="ETL: OCR scanned PDFs and store embeddings in ChromaDB")
parser.add_argument(
"--input_artifact",
type=str,
help="Fully-qualified name for the input artifact",
required=True
)
parser.add_argument(
"--output_artifact",
type=str,
help="Name for the output artifact",
required=True
)
parser.add_argument(
"--output_type",
type=str,
help="Type for the artifact output",
required=True
)
parser.add_argument(
"--output_description",
type=str,
help="Description for the artifact",
required=True
)
parser.add_argument(
"--embedding_model",
type=str,
default="paraphrase-multilingual-mpnet-base-v2",
help="Sentence Transformer model name"
)
args = parser.parse_args()
go(args)

View File

@@ -1,184 +0,0 @@
#!/usr/bin/env python
"""
Download from W&B the raw dataset and apply some basic data cleaning, exporting the result to a new artifact
"""
import argparse
import logging
import os
import wandb
import shutil
import chromadb
# from openai import OpenAI
import io
from pdfminer.converter import TextConverter
from pdfminer.pdfinterp import PDFPageInterpreter
from pdfminer.pdfinterp import PDFResourceManager
from pdfminer.pdfpage import PDFPage
from langchain.schema import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def extract_chinese_text_from_pdf(pdf_path):
"""
Extracts Chinese text from a PDF file.
Args:
pdf_path (str): The path to the PDF file.
Returns:
str: The extracted Chinese text, or None if an error occurs.
"""
resource_manager = PDFResourceManager()
fake_file_handle = io.StringIO()
converter = TextConverter(resource_manager, fake_file_handle)
page_interpreter = PDFPageInterpreter(resource_manager, converter)
try:
with open(pdf_path, 'rb') as fh:
for page in PDFPage.get_pages(fh, caching=True, check_extractable=True):
page_interpreter.process_page(page)
text = fake_file_handle.getvalue()
return text
except FileNotFoundError:
print(f"Error: PDF file not found at {pdf_path}")
return None
except Exception as e:
print(f"An error occurred: {e}")
return None
finally:
converter.close()
fake_file_handle.close()
def go(args):
"""
Run the etl for chromdb with scanned pdf
"""
run = wandb.init(job_type="etl_chromdb_scanned_pdf", entity='aimingmed')
run.config.update(args)
# Initialize embedding model (do this ONCE)
model_embedding = SentenceTransformer(args.embedding_model) # Or a multilingual model
# Create database, delete the database directory if it exists
db_folder = "chroma_db"
db_path = os.path.join(os.getcwd(), db_folder)
if os.path.exists(db_path):
shutil.rmtree(db_path)
os.makedirs(db_path)
chroma_client = chromadb.PersistentClient(path=db_path)
collection_name = "rag_experiment"
db = chroma_client.create_collection(name=collection_name)
logger.info("Downloading artifact")
artifact_local_path = run.use_artifact(args.input_artifact).file()
logger.info("Reading data")
# unzip the downloaded artifact
import zipfile
with zipfile.ZipFile(artifact_local_path, 'r') as zip_ref:
zip_ref.extractall(".")
os.remove(artifact_local_path)
# show the unzipped folder
documents_folder = os.path.splitext(os.path.basename(artifact_local_path))[0]
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
for root, _dir, files in os.walk(f"./{documents_folder}"):
for file in files:
if file.endswith(".pdf"):
read_text = extract_chinese_text_from_pdf(os.path.join(root, file))
document = Document(page_content=read_text)
all_splits = text_splitter.split_documents([document])
for i, split in enumerate(all_splits):
db.add(documents=[split.page_content],
metadatas=[{"filename": file}],
ids=[f'{file[:-4]}-{str(i)}'],
embeddings=[model_embedding.encode(split.page_content)]
)
# Create a new artifact
artifact = wandb.Artifact(
args.output_artifact,
type=args.output_type,
description=args.output_description
)
# zip the database folder first
shutil.make_archive(db_path, 'zip', db_path)
# Add the database to the artifact
artifact.add_file(db_path + '.zip')
# Log the artifact
run.log_artifact(artifact)
# Finish the run
run.finish()
# clean up
os.remove(db_path + '.zip')
os.remove(db_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="A very basic data cleaning")
parser.add_argument(
"--input_artifact",
type=str,
help="Fully-qualified name for the input artifact",
required=True
)
parser.add_argument(
"--output_artifact",
type=str,
help="Name for the output artifact",
required=True
)
parser.add_argument(
"--output_type",
type=str,
help="Type for the artifact output",
required=True
)
parser.add_argument(
"--output_description",
type=str,
help="Description for the artifact",
required=True
)
parser.add_argument(
"--embedding_model",
type=str,
default="paraphrase-multilingual-mpnet-base-v2",
help="Sentence Transformer model name"
)
args = parser.parse_args()
go(args)

View File

@@ -1,173 +0,0 @@
#!/usr/bin/env python
"""
Download from W&B the raw dataset and apply some basic data cleaning, exporting the result to a new artifact
"""
import argparse
import logging
import os
import wandb
import shutil
import chromadb
# from openai import OpenAI
from typing import List
import numpy as np
import pytesseract as pt
from pdf2image import convert_from_path
from langchain.schema import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def extract_text_from_pdf_ocr(pdf_path):
try:
images = convert_from_path(pdf_path) # Convert PDF pages to images
extracted_text = ""
for image in images:
text = pt.image_to_string(image, lang="chi_sim+eng") # chi_sim for Simplified Chinese, chi_tra for Traditional
extracted_text += text + "\n"
return extracted_text
except ImportError:
print("Error: pdf2image or pytesseract not installed. Please install them: pip install pdf2image pytesseract")
return ""
except Exception as e:
print(f"OCR failed: {e}")
return ""
def go(args):
"""
Run the etl for chromdb with scanned pdf
"""
run = wandb.init(job_type="etl_chromdb_scanned_pdf", entity='aimingmed')
run.config.update(args)
# Setup the Gemini client
# client = OpenAI(
# api_key=args.gemini_api_key,
# base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
# )
# def get_google_embedding(text: str) -> List[float]:
# response = client.embeddings.create(
# model="text-embedding-004",
# input=text
# )
# return response.data[0].embedding
# class GeminiEmbeddingFunction(object):
# def __init__(self, api_key: str, base_url: str, model_name: str):
# self.client = OpenAI(
# api_key=args.gemini_api_key,
# base_url=base_url
# )
# self.model_name = model_name
# def __call__(self, input: List[str]) -> List[List[float]]:
# all_embeddings = []
# for text in input:
# response = self.client.embeddings.create(input=text, model=self.model_name)
# embeddings = [record.embedding for record in response.data]
# all_embeddings.append(np.array(embeddings[0]))
# return all_embeddings
# Initialize embedding model (do this ONCE)
model_embedding = SentenceTransformer('all-mpnet-base-v2') # Or a multilingual model
# Create database, delete the database directory if it exists
db_folder = "chroma_db"
db_path = os.path.join(os.getcwd(), db_folder)
if os.path.exists(db_path):
shutil.rmtree(db_path)
os.makedirs(db_path)
chroma_client = chromadb.PersistentClient(path=db_path)
collection_name = "rag_experiment"
db = chroma_client.create_collection(name=collection_name)
logger.info("Downloading artifact")
artifact_local_path = run.use_artifact(args.input_artifact).file()
logger.info("Reading data")
# unzip the downloaded artifact
import zipfile
with zipfile.ZipFile(artifact_local_path, 'r') as zip_ref:
zip_ref.extractall(".")
os.remove(artifact_local_path)
# show the unzipped folder
documents_folder = os.path.splitext(os.path.basename(artifact_local_path))[0]
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
for root, _dir, files in os.walk(f"./{documents_folder}"):
for file in files:
if file.endswith(".pdf"):
read_text = extract_text_from_pdf_ocr(os.path.join(root, file))
document = Document(page_content=read_text)
all_splits = text_splitter.split_documents([document])
for i, split in enumerate(all_splits):
db.add(documents=[split.page_content],
metadatas=[{"filename": file}],
ids=[f'{file[:-4]}-{str(i)}'],
embeddings=[model_embedding.encode(split.page_content)]
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="A very basic data cleaning")
parser.add_argument(
"--input_artifact",
type=str,
help="Fully-qualified name for the input artifact",
required=True
)
parser.add_argument(
"--output_artifact",
type=str,
help="Name for the output artifact",
required=True
)
parser.add_argument(
"--output_type",
type=str,
help="Type for the artifact output",
required=True
)
parser.add_argument(
"--output_description",
type=str,
help="Description for the artifact",
required=True
)
parser.add_argument(
"--embedding_model",
type=str,
default="paraphrase-multilingual-mpnet-base-v2",
help="Sentence Transformer model name"
)
args = parser.parse_args()
go(args)

View File

@@ -1,4 +1,4 @@
name: chain_of_thought
name: rag_cot
python_env: python_env.yml
entry_points:

View File

@@ -14,5 +14,4 @@ build_dependencies:
- langchain-community
# Dependencies required to run the project.
dependencies:
- mlflow==2.8.1
- wandb==0.16.0
- mlflow==2.8.1

View File

@@ -0,0 +1,155 @@
import os
import logging
import argparse
import mlflow
import chromadb
import shutil
from decouple import config
from langchain.prompts import PromptTemplate
from sentence_transformers import SentenceTransformer
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_deepseek import ChatDeepSeek
from langchain_community.llms.moonshot import Moonshot
logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
GEMINI_API_KEY = config("GOOGLE_API_KEY", cast=str)
DEEPSEEK_API_KEY = config("DEEPSEEK_API_KEY", cast=str)
MOONSHOT_API_KEY = config("MOONSHOT_API_KEY", cast=str)
def go(args):
# start a new MLflow run
with mlflow.start_run(experiment_id=mlflow.get_experiment_by_name("development").experiment_id, run_name="rag_cot"):
existing_params = mlflow.get_run(mlflow.active_run().info.run_id).data.params
if 'query' not in existing_params:
mlflow.log_param('query', args.query)
# Log parameters to MLflow
mlflow.log_params({
"input_chromadb_artifact": args.input_chromadb_artifact,
"embedding_model": args.embedding_model,
"chat_model_provider": args.chat_model_provider
})
logger.info("Downloading chromadb artifact")
artifact_chromadb_local_path = mlflow.artifacts.download_artifacts(artifact_uri=args.input_chromadb_artifact)
# unzip the artifact
logger.info("Unzipping the artifact")
shutil.unpack_archive(artifact_chromadb_local_path, "chroma_db")
# Load data from ChromaDB
db_folder = "chroma_db"
db_path = os.path.join(os.getcwd(), db_folder)
chroma_client = chromadb.PersistentClient(path=db_path)
collection_name = "rag_experiment"
collection = chroma_client.get_collection(name=collection_name)
# Formulate a question
question = args.query
if args.chat_model_provider == "deepseek":
# Initialize DeepSeek model
llm = ChatDeepSeek(
model="deepseek-chat",
temperature=0,
max_tokens=None,
timeout=None,
max_retries=2,
api_key=DEEPSEEK_API_KEY
)
elif args.chat_model_provider == "gemini":
# Initialize Gemini model
llm = ChatGoogleGenerativeAI(
model="gemini-1.5-flash",
google_api_key=GEMINI_API_KEY,
temperature=0,
max_retries=3
)
elif args.chat_model_provider == "moonshot":
# Initialize Moonshot model
llm = Moonshot(
model="moonshot-v1-128k",
temperature=0,
max_tokens=None,
timeout=None,
max_retries=2,
api_key=MOONSHOT_API_KEY
)
else:
raise ValueError(f"Unsupported chat_model_provider: {args.chat_model_provider}")
# Chain of Thought Prompt
cot_template = """Let's think step by step.
Given the following document in text: {documents_text}
Question: {question}
Reply with language that is similar to the language used with asked question.
"""
cot_prompt = PromptTemplate(template=cot_template, input_variables=["documents_text", "question"])
cot_chain = cot_prompt | llm
# Initialize embedding model (do this ONCE)
model = SentenceTransformer(args.embedding_model)
# Query (prompt)
query_embedding = model.encode(question) # Embed the query using the SAME model
# Search ChromaDB
documents_text = collection.query(query_embeddings=[query_embedding], n_results=5)
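# Note: collection.query() returns a dict with "ids", "documents",
# "metadatas" and "distances"; the full dict is what gets interpolated
# into the prompt below. A sketch that would keep only the retrieved text:
#   documents_text = "\n\n".join(documents_text["documents"][0])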
# Generate chain of thought
cot_output = cot_chain.invoke({"documents_text": documents_text, "question": question})
print("Chain of Thought: ", cot_output)
# Answer Prompt
answer_template = """Given the chain of thought: {cot}
Provide a concise answer to the question: {question}
Provide the answer with language that is similar to the question asked.
"""
answer_prompt = PromptTemplate(template=answer_template, input_variables=["cot", "question"])
answer_chain = answer_prompt | llm
# Generate answer
answer_output = answer_chain.invoke({"cot": cot_output, "question": question})
print("Answer: ", answer_output)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Chain of Thought RAG")
parser.add_argument(
"--query",
type=str,
help="Question to ask the model",
required=True
)
parser.add_argument(
"--input_chromadb_artifact",
type=str,
help="Fully-qualified name for the chromadb artifact",
required=True
)
parser.add_argument(
"--embedding_model",
type=str,
default="paraphrase-multilingual-mpnet-base-v2",
help="Sentence Transformer model name"
)
parser.add_argument(
"--chat_model_provider",
type=str,
default="gemini",
choices=["deepseek", "gemini", "moonshot"],
help="Chat model provider"
)
args = parser.parse_args()
go(args)