diff --git a/app/llmops/src/etl_chromdb_scanned_pdf/run.py b/app/llmops/src/etl_chromdb_scanned_pdf/run.py index 6ddc082..edbb1fc 100644 --- a/app/llmops/src/etl_chromdb_scanned_pdf/run.py +++ b/app/llmops/src/etl_chromdb_scanned_pdf/run.py @@ -81,17 +81,10 @@ def go(args): # all_embeddings.append(np.array(embeddings[0])) # return all_embeddings - class SentenceTransformerEmbeddingFunction(object): - def __init__(self, model_name: str): - self.model = SentenceTransformer(model_name) - def __call__(self, input: List[str]) -> List[List[float]]: - embeddings = self.model.encode(input) - return embeddings.tolist() + # Initialize embedding model (do this ONCE) + model_embedding = SentenceTransformer('all-mpnet-base-v2') # Or a multilingual model - # Define embedding function - model_name = 'paraphrase-multilingual-mpnet-base-v2' - embedding_function = SentenceTransformerEmbeddingFunction(model_name) # Create database, delete the database directory if it exists db_folder = "chroma_db" @@ -102,7 +95,7 @@ def go(args): chroma_client = chromadb.PersistentClient(path=db_path) collection_name = "rag_experiment" - db = chroma_client.create_collection(name=collection_name, embedding_function=embedding_function) + db = chroma_client.create_collection(name=collection_name) logger.info("Downloading artifact") @@ -131,8 +124,9 @@ def go(args): for i, split in enumerate(all_splits): db.add(documents=[split.page_content], metadatas=[{"filename": file}], - ids=[file[:-4] + str(i)]) - + ids=[f'{file[:-4]}-{str(i)}'], + embeddings=[model_embedding.encode(split.page_content)] + ) if __name__ == "__main__":