This commit is contained in:
leehk 2025-02-22 15:35:05 +08:00
parent 702b1eb874
commit 82435ebdbd
5 changed files with 46 additions and 38 deletions

View File

@ -6,4 +6,5 @@ main:
etl:
document_folder: "documents"
path_document_folder: "../../../../data"
embedding_model: "paraphrase-multilingual-mpnet-base-v2"

View File

@ -64,7 +64,7 @@ def go(config: DictConfig):
"output_artifact": "chromdb.zip",
"output_type": "chromdb",
"output_description": "Scanned Documents in pdf to be read amd stored in chromdb",
"gemini_api_key": GEMINI_API_KEY,
"embedding_model": config["etl"]["embedding_model"]
},
)

View File

@ -21,8 +21,8 @@ entry_points:
description: Description for the artifact
type: string
gemini_api_key:
description: API key for Gemini
embedding_model:
description: Fully-qualified name for the embedding model
type: string
@ -31,4 +31,4 @@ entry_points:
--output_artifact {output_artifact} \
--output_type {output_type} \
--output_description {output_description} \
--gemini_api_key {gemini_api_key}
--embedding_model {embedding_model}

View File

@ -10,6 +10,7 @@ build_dependencies:
- pytesseract
- pdf2image
- langchain
- sentence_transformers
# Dependencies required to run the project.
dependencies:
- mlflow==2.8.1

View File

@ -9,7 +9,7 @@ import wandb
import shutil
import chromadb
from openai import OpenAI
# from openai import OpenAI
from typing import List
import numpy as np
import pytesseract as pt
@ -17,10 +17,12 @@ from pdf2image import convert_from_path
from langchain.schema import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
logger = logging.getLogger()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def extract_text_from_pdf_ocr(pdf_path):
try:
@ -50,41 +52,46 @@ def go(args):
run.config.update(args)
# Setup the Gemini client
client = OpenAI(
api_key=args.gemini_api_key,
base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
)
# client = OpenAI(
# api_key=args.gemini_api_key,
# base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
# )
def get_google_embedding(text: str) -> List[float]:
response = client.embeddings.create(
model="text-embedding-004",
input=text
)
return response.data[0].embedding
# def get_google_embedding(text: str) -> List[float]:
# response = client.embeddings.create(
# model="text-embedding-004",
# input=text
# )
# return response.data[0].embedding
class GeminiEmbeddingFunction(object):
def __init__(self, api_key: str, base_url: str, model_name: str):
self.client = OpenAI(
api_key=args.gemini_api_key,
base_url=base_url
)
self.model_name = model_name
# class GeminiEmbeddingFunction(object):
# def __init__(self, api_key: str, base_url: str, model_name: str):
# self.client = OpenAI(
# api_key=args.gemini_api_key,
# base_url=base_url
# )
# self.model_name = model_name
# def __call__(self, input: List[str]) -> List[List[float]]:
# all_embeddings = []
# for text in input:
# response = self.client.embeddings.create(input=text, model=self.model_name)
# embeddings = [record.embedding for record in response.data]
# all_embeddings.append(np.array(embeddings[0]))
# return all_embeddings
class SentenceTransformerEmbeddingFunction(object):
def __init__(self, model_name: str):
self.model = SentenceTransformer(model_name)
def __call__(self, input: List[str]) -> List[List[float]]:
all_embeddings = []
for text in input:
response = self.client.embeddings.create(input=text, model=self.model_name)
embeddings = [record.embedding for record in response.data]
all_embeddings.append(np.array(embeddings[0]))
return all_embeddings
embeddings = self.model.encode(input)
return embeddings.tolist()
# Define embedding function
gemini_ef = GeminiEmbeddingFunction(
api_key=args.gemini_api_key,
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
model_name="text-embedding-004"
)
# Take the model name from the --embedding_model CLI argument (declared in the
# argparse block of this script and forwarded by the MLproject entry point)
# instead of hard-coding it, so the value configured under etl.embedding_model
# is actually honoured. The argparse default keeps the previous model name.
model_name = args.embedding_model
embedding_function = SentenceTransformerEmbeddingFunction(model_name)
# Create database, delete the database directory if it exists
db_folder = "chroma_db"
@ -95,7 +102,7 @@ def go(args):
chroma_client = chromadb.PersistentClient(path=db_path)
collection_name = "rag_experiment"
db = chroma_client.create_collection(name=collection_name, embedding_function=gemini_ef)
db = chroma_client.create_collection(name=collection_name, embedding_function=embedding_function)
logger.info("Downloading artifact")
@ -126,7 +133,6 @@ def go(args):
metadatas=[{"filename": file}],
ids=[file[:-4] + str(i)])
if __name__ == "__main__":
@ -161,10 +167,10 @@ if __name__ == "__main__":
)
parser.add_argument(
"--gemini_api_key",
"--embedding_model",
type=str,
help="API key for the Gemini service",
required=True
default="paraphrase-multilingual-mpnet-base-v2",
help="Sentence Transformer model name"
)