mirror of
https://github.com/aimingmed/aimingmed-ai.git
synced 2026-01-29 02:23:33 +08:00
136 lines
4.6 KiB
Python
136 lines
4.6 KiB
Python
#!/usr/bin/env python
|
|
import argparse
|
|
import logging
|
|
import os
|
|
import wandb
|
|
import shutil
|
|
import chromadb
|
|
|
|
# Configure the root logger once for the whole script: INFO and above,
# each record prefixed with a timestamp.
logging.basicConfig(
    format="%(asctime)-15s %(message)s",
    level=logging.INFO,
)

# The script logs through the root logger (no name is passed).
logger = logging.getLogger()
|
|
|
|
def _as_plain_list(values):
    """Return *values* as a plain Python list, treating ``None`` as empty."""
    if values is None:
        return []
    # Array-like results (e.g. numpy arrays) expose ``tolist``; anything else
    # is copied into a new list so the caller may extend it freely.
    to_list = getattr(values, "tolist", None)
    return to_list() if callable(to_list) else list(values)


def combine_chromadb(chromadb_pdf_path, chromadb_scanned_pdf_path, output_path, batch_size=5000):
    """
    Combine two persisted ChromaDB stores into a single new ChromaDB store.

    The "rag_experiment" collection is read from each source store and its
    ids, documents, metadatas and embeddings are written, PDF records first,
    into a newly created "rag_experiment" collection at ``output_path``.

    Args:
        chromadb_pdf_path: Directory of the ChromaDB built from PDFs.
        chromadb_scanned_pdf_path: Directory of the ChromaDB built from
            scanned PDFs.
        output_path: Directory in which the combined ChromaDB is created.
        batch_size: Maximum number of records passed to a single ``add``
            call. Defaults to 5000.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1, or if the
            "rag_experiment" collection is missing from a source store.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    collection_name = "rag_experiment"
    # (store path, name of the pipeline step expected to have produced it)
    sources = (
        (chromadb_pdf_path, "etl_chromdb_pdf"),
        (chromadb_scanned_pdf_path, "etl_chromdb_scanned_pdf"),
    )

    fields = ("ids", "documents", "metadatas", "embeddings")
    merged = {field: [] for field in fields}

    for source_path, step_name in sources:
        client = chromadb.PersistentClient(path=source_path)
        try:
            collection = client.get_collection(name=collection_name)
        except ValueError as e:
            raise ValueError(f"Collection '{collection_name}' not found in ChromaDB at '{source_path}'. Ensure the {step_name} step was run successfully.") from e

        data = collection.get(include=["documents", "metadatas", "embeddings"])
        # Normalise every field to a list before merging: embeddings may come
        # back as a numpy array, for which ``+`` would add element-wise
        # rather than concatenate.
        for field in fields:
            merged[field].extend(_as_plain_list(data[field]))

    # NOTE(review): ids are copied unchanged, so this assumes the two source
    # collections never share an id - confirm against the upstream ETL steps.
    combined_client = chromadb.PersistentClient(path=output_path)
    combined_collection = combined_client.create_collection(name=collection_name)

    total = len(merged["ids"])
    if total == 0:
        logger.warning("Both source collections are empty; the combined ChromaDB has no records")

    # Insert in bounded slices; an empty result simply performs no insert.
    for start in range(0, total, batch_size):
        stop = start + batch_size
        combined_collection.add(
            documents=merged["documents"][start:stop],
            metadatas=merged["metadatas"][start:stop],
            ids=merged["ids"][start:stop],
            embeddings=merged["embeddings"][start:stop],
        )

    logger.info(f"Combined ChromaDB created at {output_path}")
|
|
|
|
|
|
def _resolve_chromadb_dir(artifact_file):
    """Return the directory holding the ChromaDB files of a downloaded artifact."""
    if os.path.isdir(artifact_file):
        root = artifact_file
    else:
        # The artifact is a single archive file: unpack it next to the
        # download, replacing any extraction left over from an earlier run.
        root = os.path.splitext(artifact_file)[0] + "_extracted"
        if os.path.exists(root):
            shutil.rmtree(root)
        shutil.unpack_archive(artifact_file, root)

    # NOTE(review): the archive layout is not visible from this file. Accept
    # both a nested "chroma_db" folder and database files at the archive root.
    nested = os.path.join(root, "chroma_db")
    return nested if os.path.isdir(nested) else root


def go(args):
    """
    Run the combine chromadb component.

    Downloads the two input ChromaDB artifacts, merges them into a fresh
    "combined_chromadb" directory under the current working directory, and
    logs a zip of that directory as a new wandb artifact.

    Args:
        args: Parsed command-line arguments providing
            ``chromadb_pdf_artifact``, ``chromadb_scanned_pdf_artifact``,
            ``output_artifact``, ``output_type`` and ``output_description``.

    Raises:
        shutil.ReadError: If a downloaded artifact file is not an archive
            that ``shutil.unpack_archive`` can unpack.
    """

    run = wandb.init(job_type="combine_chromadb", entity='aimingmed')
    run.config.update(args)

    # Download the ChromaDB artifacts. ``file()`` yields the path of the one
    # file in the artifact, so it is unpacked before being opened as a database.
    logger.info("Downloading chromadb_pdf artifact")
    chromadb_pdf_artifact = run.use_artifact(args.chromadb_pdf_artifact).file()
    chromadb_pdf_path = _resolve_chromadb_dir(chromadb_pdf_artifact)

    logger.info("Downloading chromadb_scanned_pdf artifact")
    chromadb_scanned_pdf_artifact = run.use_artifact(args.chromadb_scanned_pdf_artifact).file()
    chromadb_scanned_pdf_path = _resolve_chromadb_dir(chromadb_scanned_pdf_artifact)

    # Create the output directory, discarding results of any previous run
    output_folder = "combined_chromadb"
    output_path = os.path.join(os.getcwd(), output_folder)
    if os.path.exists(output_path):
        shutil.rmtree(output_path)
    os.makedirs(output_path)

    # Combine the ChromaDB instances
    combine_chromadb(chromadb_pdf_path, chromadb_scanned_pdf_path, output_path)

    # Create a new artifact
    artifact = wandb.Artifact(
        args.output_artifact,
        type=args.output_type,
        description=args.output_description
    )

    archive_path = output_path + '.zip'
    try:
        # Zip the database folder first
        shutil.make_archive(output_path, 'zip', output_path)

        # Add the database to the artifact and log it
        artifact.add_file(archive_path)
        run.log_artifact(artifact)

        # Finish the run before the zip is deleted
        run.finish()
    finally:
        # clean up - remove zip, even when logging the artifact failed
        if os.path.exists(archive_path):
            os.remove(archive_path)
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Combine two ChromaDB instances into one.")

    # Every option is a mandatory string; only the flag and its help text
    # differ, so the options are declared as a table and registered in order.
    required_string_options = (
        ("--chromadb_pdf_artifact", "Fully-qualified name for the ChromaDB PDF artifact"),
        ("--chromadb_scanned_pdf_artifact", "Fully-qualified name for the ChromaDB Scanned PDF artifact"),
        ("--output_artifact", "Name for the output artifact"),
        ("--output_type", "Type for the output artifact"),
        ("--output_description", "Description for the output artifact"),
    )
    for flag, help_text in required_string_options:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    # Parse the command line and hand the namespace to the component.
    go(parser.parse_args())