From 82435ebdbd6acb684513972de172593645da0f0d Mon Sep 17 00:00:00 2001
From: leehk
Date: Sat, 22 Feb 2025 15:35:05 +0800
Subject: [PATCH] update

---
 app/llmops/config.yaml                         |  1 +
 app/llmops/main.py                             |  2 +-
 .../src/etl_chromdb_scanned_pdf/MLproject      |  6 +-
 .../etl_chromdb_scanned_pdf/python_env.yml     |  1 +
 app/llmops/src/etl_chromdb_scanned_pdf/run.py  | 74 ++++++++++---------
 5 files changed, 46 insertions(+), 38 deletions(-)

diff --git a/app/llmops/config.yaml b/app/llmops/config.yaml
index b69ac2d..4881487 100644
--- a/app/llmops/config.yaml
+++ b/app/llmops/config.yaml
@@ -6,4 +6,5 @@ main:
 etl:
   document_folder: "documents"
   path_document_folder: "../../../../data"
+  embedding_model: "paraphrase-multilingual-mpnet-base-v2"
diff --git a/app/llmops/main.py b/app/llmops/main.py
index a0b97f8..bac5fd2 100644
--- a/app/llmops/main.py
+++ b/app/llmops/main.py
@@ -64,7 +64,7 @@ def go(config: DictConfig):
                 "output_artifact": "chromdb.zip",
                 "output_type": "chromdb",
                 "output_description": "Scanned Documents in pdf to be read amd stored in chromdb",
-                "gemini_api_key": GEMINI_API_KEY,
+                "embedding_model": config["etl"]["embedding_model"]
             },
         )
diff --git a/app/llmops/src/etl_chromdb_scanned_pdf/MLproject b/app/llmops/src/etl_chromdb_scanned_pdf/MLproject
index 2ca4ab1..bd47191 100644
--- a/app/llmops/src/etl_chromdb_scanned_pdf/MLproject
+++ b/app/llmops/src/etl_chromdb_scanned_pdf/MLproject
@@ -21,8 +21,8 @@ entry_points:
         description: Description for the artifact
         type: string

-      gemini_api_key:
-        description: API key for Gemini
+      embedding_model:
+        description: Fully-qualified name for the embedding model
         type: string

@@ -31,4 +31,4 @@ entry_points:
         --output_artifact {output_artifact} \
         --output_type {output_type} \
         --output_description {output_description} \
-        --gemini_api_key {gemini_api_key}
\ No newline at end of file
+        --embedding_model {embedding_model}
\ No newline at end of file
diff --git a/app/llmops/src/etl_chromdb_scanned_pdf/python_env.yml b/app/llmops/src/etl_chromdb_scanned_pdf/python_env.yml
index 5d043c3..0045f53 100644
--- a/app/llmops/src/etl_chromdb_scanned_pdf/python_env.yml
+++ b/app/llmops/src/etl_chromdb_scanned_pdf/python_env.yml
@@ -10,6 +10,7 @@ build_dependencies:
   - pytesseract
   - pdf2image
   - langchain
+  - sentence_transformers
 # Dependencies required to run the project.
 dependencies:
   - mlflow==2.8.1
diff --git a/app/llmops/src/etl_chromdb_scanned_pdf/run.py b/app/llmops/src/etl_chromdb_scanned_pdf/run.py
index 76160bc..6ddc082 100644
--- a/app/llmops/src/etl_chromdb_scanned_pdf/run.py
+++ b/app/llmops/src/etl_chromdb_scanned_pdf/run.py
@@ -9,7 +9,7 @@
 import wandb
 import shutil
 import chromadb
-from openai import OpenAI
+# from openai import OpenAI
 from typing import List
 import numpy as np
 import pytesseract as pt
@@ -17,10 +17,12 @@
 from pdf2image import convert_from_path
 from langchain.schema import Document
 from langchain_text_splitters import RecursiveCharacterTextSplitter
+from sentence_transformers import SentenceTransformer

 logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
 logger = logging.getLogger()

+os.environ["TOKENIZERS_PARALLELISM"] = "false"

 def extract_text_from_pdf_ocr(pdf_path):
     try:
@@ -50,41 +52,46 @@ def go(args):
     run.config.update(args)

     # Setup the Gemini client
-    client = OpenAI(
-        api_key=args.gemini_api_key,
-        base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
-    )
+    # client = OpenAI(
+    #     api_key=args.gemini_api_key,
+    #     base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
+    # )

-    def get_google_embedding(text: str) -> List[float]:
-        response = client.embeddings.create(
-            model="text-embedding-004",
-            input=text
-        )
-        return response.data[0].embedding
+    # def get_google_embedding(text: str) -> List[float]:
+    #     response = client.embeddings.create(
+    #         model="text-embedding-004",
+    #         input=text
+    #     )
+    #     return response.data[0].embedding

-    class GeminiEmbeddingFunction(object):
-        def __init__(self, api_key: str, base_url: str, model_name: str):
-            self.client = OpenAI(
-                api_key=args.gemini_api_key,
-                base_url=base_url
-            )
-            self.model_name = model_name
+    # class GeminiEmbeddingFunction(object):
+    #     def __init__(self, api_key: str, base_url: str, model_name: str):
+    #         self.client = OpenAI(
+    #             api_key=args.gemini_api_key,
+    #             base_url=base_url
+    #         )
+    #         self.model_name = model_name
+
+    #     def __call__(self, input: List[str]) -> List[List[float]]:
+    #         all_embeddings = []
+    #         for text in input:
+    #             response = self.client.embeddings.create(input=text, model=self.model_name)
+    #             embeddings = [record.embedding for record in response.data]
+    #             all_embeddings.append(np.array(embeddings[0]))
+    #         return all_embeddings
+
+    class SentenceTransformerEmbeddingFunction(object):
+        def __init__(self, model_name: str):
+            self.model = SentenceTransformer(model_name)

         def __call__(self, input: List[str]) -> List[List[float]]:
-            all_embeddings = []
-            for text in input:
-                response = self.client.embeddings.create(input=text, model=self.model_name)
-                embeddings = [record.embedding for record in response.data]
-                all_embeddings.append(np.array(embeddings[0]))
-            return all_embeddings
+            embeddings = self.model.encode(input)
+            return embeddings.tolist()

     # Define embedding function from the CLI/config parameter
-    gemini_ef = GeminiEmbeddingFunction(
-        api_key=args.gemini_api_key,
-        base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
-        model_name="text-embedding-004"
-    )
+    model_name = args.embedding_model
+    embedding_function = SentenceTransformerEmbeddingFunction(model_name)

     # Create database, delete the database directory if it exists
     db_folder = "chroma_db"
@@ -95,7 +102,7 @@ def go(args):
     chroma_client = chromadb.PersistentClient(path=db_path)

     collection_name = "rag_experiment"
-    db = chroma_client.create_collection(name=collection_name, embedding_function=gemini_ef)
+    db = chroma_client.create_collection(name=collection_name, embedding_function=embedding_function)

     logger.info("Downloading artifact")
@@ -126,7 +133,6 @@ def go(args):
                 metadatas=[{"filename": file}],
                 ids=[file[:-4] + str(i)])

-
 if __name__ == "__main__":
@@ -161,10 +167,10 @@ if __name__ == "__main__":
     )

     parser.add_argument(
-        "--gemini_api_key",
+        "--embedding_model",
         type=str,
-        help="API key for the Gemini service",
-        required=True
+        default="paraphrase-multilingual-mpnet-base-v2",
+        help="Sentence Transformer model name"
     )
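
Note (not part of the patch): the sketch below illustrates how the SentenceTransformerEmbeddingFunction introduced in run.py is expected to plug into ChromaDB end to end. It assumes chromadb >= 0.4 and sentence-transformers are installed; the database path, collection name, documents, and ids are hypothetical stand-ins, not values from this repository.

# Illustrative sketch only -- not part of the patch.
# Assumes chromadb >= 0.4 and sentence-transformers are installed;
# the path, collection name, texts, and ids below are hypothetical.
from typing import List

import chromadb
from sentence_transformers import SentenceTransformer


class SentenceTransformerEmbeddingFunction(object):
    """Mirrors the class added in run.py: ChromaDB calls it with a list of
    strings and expects one embedding vector (list of floats) per string."""

    def __init__(self, model_name: str):
        self.model = SentenceTransformer(model_name)

    def __call__(self, input: List[str]) -> List[List[float]]:
        return self.model.encode(input).tolist()


if __name__ == "__main__":
    embedding_function = SentenceTransformerEmbeddingFunction(
        "paraphrase-multilingual-mpnet-base-v2"
    )

    client = chromadb.PersistentClient(path="chroma_db_demo")  # hypothetical path
    collection = client.create_collection(
        name="rag_experiment_demo", embedding_function=embedding_function
    )

    # Store a couple of toy chunks, mirroring how run.py stores OCR'd pages.
    collection.add(
        documents=["first scanned page text", "second scanned page text"],
        metadatas=[{"filename": "doc.pdf"}, {"filename": "doc.pdf"}],
        ids=["doc0", "doc1"],
    )

    # The multilingual model also allows queries phrased in other languages.
    results = collection.query(query_texts=["scanned page"], n_results=1)
    print(results["documents"])

Because the sentence-transformers model is downloaded and run locally, no API key is needed at embedding time, which is why gemini_api_key disappears from MLproject, main.py, and the argparse options in this patch.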