mirror of
https://github.com/aimingmed/aimingmed-ai.git
synced 2026-02-06 15:26:55 +08:00
162 lines
5.0 KiB
Python
162 lines
5.0 KiB
Python
#!/usr/bin/env python
|
|
"""
|
|
Download from W&B the raw dataset and apply some basic data cleaning, exporting the result to a new artifact
|
|
"""
|
|
import argparse
import logging
import os
import shutil
import zipfile
from typing import List

import chromadb
import mlflow
import numpy as np
import pytesseract as pt
from langchain.schema import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pdf2image import convert_from_path
from sentence_transformers import SentenceTransformer

# from openai import OpenAI
|
|
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
|
|
logger = logging.getLogger()
|
|
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
def extract_text_from_pdf_ocr(pdf_path):
    """OCR every page of a scanned PDF and return the concatenated text.

    Each page is rendered to an image with pdf2image and run through
    Tesseract with both Simplified Chinese and English language packs.

    Args:
        pdf_path: Path to the PDF file to process.

    Returns:
        The OCR'd text of all pages, each page followed by a newline,
        or an empty string if conversion/OCR fails for any reason.
    """
    try:
        # Render PDF pages to images (requires poppler installed on the system).
        images = convert_from_path(pdf_path)
        # chi_sim = Simplified Chinese; use chi_tra instead for Traditional.
        page_texts = [pt.image_to_string(image, lang="chi_sim+eng") for image in images]
        # join() instead of repeated += — avoids quadratic string building.
        return "".join(text + "\n" for text in page_texts)
    except ImportError:
        # Dependencies can fail lazily at call time (e.g. missing backend).
        logging.error(
            "pdf2image or pytesseract not installed. "
            "Please install them: pip install pdf2image pytesseract"
        )
        return ""
    except Exception as e:
        # Best-effort OCR: log and return empty text rather than abort the ETL.
        logging.error("OCR failed: %s", e)
        return ""
|
|
|
|
|
|
|
|
def go(args):
    """Run the ETL that builds a ChromaDB vector store from scanned PDFs.

    Downloads the zipped PDF artifact from MLflow, OCRs each PDF, splits the
    text into overlapping chunks, embeds every chunk with a SentenceTransformer
    model and stores it in a fresh persistent ChromaDB collection, which is
    then zipped and logged back to MLflow as the output artifact.

    Args:
        args: Parsed CLI namespace with ``input_artifact``, ``output_artifact``,
            ``output_type``, ``output_description`` and ``embedding_model``
            attributes.
    """
    # Start an MLflow run under the "development" experiment.
    with mlflow.start_run(
        experiment_id=mlflow.get_experiment_by_name("development").experiment_id,
        run_name="etl_chromdb_pdf",
    ):
        # Only log the description once; re-logging an existing param raises.
        existing_params = mlflow.get_run(mlflow.active_run().info.run_id).data.params
        if 'output_description' not in existing_params:
            mlflow.log_param('output_description', args.output_description)

        # Log the remaining run parameters to MLflow.
        mlflow.log_params({
            "input_artifact": args.input_artifact,
            "output_artifact": args.output_artifact,
            "output_type": args.output_type,
            "embedding_model": args.embedding_model,
        })

        # Initialize the (possibly multilingual) embedding model.
        model_embedding = SentenceTransformer(args.embedding_model)

        # Re-create the database directory from scratch so stale data from a
        # previous run cannot leak into the new collection.
        db_folder = "chroma_db"
        db_path = os.path.join(os.getcwd(), db_folder)
        if os.path.exists(db_path):
            shutil.rmtree(db_path)
        os.makedirs(db_path)

        chroma_client = chromadb.PersistentClient(path=db_path)
        collection_name = "rag_experiment"
        db = chroma_client.create_collection(name=collection_name)

        logger.info("Downloading artifact")
        artifact_local_path = mlflow.artifacts.download_artifacts(artifact_uri=args.input_artifact)

        logger.info("Reading data")

        # Unzip the downloaded artifact into the working directory.
        with zipfile.ZipFile(artifact_local_path, 'r') as zip_ref:
            zip_ref.extractall(".")

        # The extracted folder is named after the zip file (without extension).
        documents_folder = os.path.splitext(os.path.basename(artifact_local_path))[0]

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

        # OCR every PDF, chunk the text, embed each chunk and add it to the
        # collection with a stable id of the form "<pdf name>-<chunk index>".
        for root, _dirs, files in os.walk(f"./{documents_folder}"):
            for file in files:
                if not file.endswith(".pdf"):
                    continue
                read_text = extract_text_from_pdf_ocr(os.path.join(root, file))
                document = Document(page_content=read_text)
                all_splits = text_splitter.split_documents([document])

                for i, split in enumerate(all_splits):
                    db.add(
                        documents=[split.page_content],
                        metadatas=[{"filename": file}],
                        ids=[f'{file[:-4]}-{i}'],
                        embeddings=[model_embedding.encode(split.page_content)],
                    )

        logger.info("Uploading artifact to MLFlow")
        shutil.make_archive(db_path, 'zip', db_path)
        mlflow.log_artifact(db_path + '.zip', args.output_artifact)

        # Clean up local scratch files so repeated runs start fresh.
        os.remove(db_path + '.zip')
        shutil.rmtree(db_path)
        shutil.rmtree(documents_folder)
|
|
|
|
if __name__ == "__main__":
|
|
|
|
parser = argparse.ArgumentParser(description="A very basic data cleaning")
|
|
|
|
parser.add_argument(
|
|
"--input_artifact",
|
|
type=str,
|
|
help="Fully-qualified name for the input artifact",
|
|
required=True
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--output_artifact",
|
|
type=str,
|
|
help="Name for the output artifact",
|
|
required=True
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--output_type",
|
|
type=str,
|
|
help="Type for the artifact output",
|
|
required=True
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--output_description",
|
|
type=str,
|
|
help="Description for the artifact",
|
|
required=True
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--embedding_model",
|
|
type=str,
|
|
default="paraphrase-multilingual-mpnet-base-v2",
|
|
help="Sentence Transformer model name"
|
|
)
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
go(args) |