mirror of
https://github.com/aimingmed/aimingmed-ai.git
synced 2026-01-19 13:23:23 +08:00
update
This commit is contained in:
parent
702b1eb874
commit
82435ebdbd
@ -6,4 +6,5 @@ main:
|
||||
etl:
|
||||
document_folder: "documents"
|
||||
path_document_folder: "../../../../data"
|
||||
embedding_model: "paraphrase-multilingual-mpnet-base-v2"
|
||||
|
||||
|
||||
@ -64,7 +64,7 @@ def go(config: DictConfig):
|
||||
"output_artifact": "chromdb.zip",
|
||||
"output_type": "chromdb",
|
||||
"output_description": "Scanned Documents in pdf to be read amd stored in chromdb",
|
||||
"gemini_api_key": GEMINI_API_KEY,
|
||||
"embedding_model": config["etl"]["embedding_model"]
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ -21,8 +21,8 @@ entry_points:
|
||||
description: Description for the artifact
|
||||
type: string
|
||||
|
||||
gemini_api_key:
|
||||
description: API key for Gemini
|
||||
embedding_model:
|
||||
description: Fully-qualified name for the embedding model
|
||||
type: string
|
||||
|
||||
|
||||
@ -31,4 +31,4 @@ entry_points:
|
||||
--output_artifact {output_artifact} \
|
||||
--output_type {output_type} \
|
||||
--output_description {output_description} \
|
||||
--gemini_api_key {gemini_api_key}
|
||||
--embedding_model {embedding_model}
|
||||
@ -10,6 +10,7 @@ build_dependencies:
|
||||
- pytesseract
|
||||
- pdf2image
|
||||
- langchain
|
||||
- sentence_transformers
|
||||
# Dependencies required to run the project.
|
||||
dependencies:
|
||||
- mlflow==2.8.1
|
||||
|
||||
@ -9,7 +9,7 @@ import wandb
|
||||
import shutil
|
||||
|
||||
import chromadb
|
||||
from openai import OpenAI
|
||||
# from openai import OpenAI
|
||||
from typing import List
|
||||
import numpy as np
|
||||
import pytesseract as pt
|
||||
@ -17,10 +17,12 @@ from pdf2image import convert_from_path
|
||||
from langchain.schema import Document
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
|
||||
logger = logging.getLogger()
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
def extract_text_from_pdf_ocr(pdf_path):
|
||||
try:
|
||||
@ -50,41 +52,46 @@ def go(args):
|
||||
run.config.update(args)
|
||||
|
||||
# Setup the Gemini client
|
||||
client = OpenAI(
|
||||
api_key=args.gemini_api_key,
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
|
||||
)
|
||||
# client = OpenAI(
|
||||
# api_key=args.gemini_api_key,
|
||||
# base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
|
||||
# )
|
||||
|
||||
|
||||
def get_google_embedding(text: str) -> List[float]:
|
||||
response = client.embeddings.create(
|
||||
model="text-embedding-004",
|
||||
input=text
|
||||
)
|
||||
return response.data[0].embedding
|
||||
# def get_google_embedding(text: str) -> List[float]:
|
||||
# response = client.embeddings.create(
|
||||
# model="text-embedding-004",
|
||||
# input=text
|
||||
# )
|
||||
# return response.data[0].embedding
|
||||
|
||||
class GeminiEmbeddingFunction(object):
|
||||
def __init__(self, api_key: str, base_url: str, model_name: str):
|
||||
self.client = OpenAI(
|
||||
api_key=args.gemini_api_key,
|
||||
base_url=base_url
|
||||
)
|
||||
self.model_name = model_name
|
||||
# class GeminiEmbeddingFunction(object):
|
||||
# def __init__(self, api_key: str, base_url: str, model_name: str):
|
||||
# self.client = OpenAI(
|
||||
# api_key=args.gemini_api_key,
|
||||
# base_url=base_url
|
||||
# )
|
||||
# self.model_name = model_name
|
||||
|
||||
# def __call__(self, input: List[str]) -> List[List[float]]:
|
||||
# all_embeddings = []
|
||||
# for text in input:
|
||||
# response = self.client.embeddings.create(input=text, model=self.model_name)
|
||||
# embeddings = [record.embedding for record in response.data]
|
||||
# all_embeddings.append(np.array(embeddings[0]))
|
||||
# return all_embeddings
|
||||
|
||||
class SentenceTransformerEmbeddingFunction(object):
|
||||
def __init__(self, model_name: str):
|
||||
self.model = SentenceTransformer(model_name)
|
||||
|
||||
def __call__(self, input: List[str]) -> List[List[float]]:
|
||||
all_embeddings = []
|
||||
for text in input:
|
||||
response = self.client.embeddings.create(input=text, model=self.model_name)
|
||||
embeddings = [record.embedding for record in response.data]
|
||||
all_embeddings.append(np.array(embeddings[0]))
|
||||
return all_embeddings
|
||||
embeddings = self.model.encode(input)
|
||||
return embeddings.tolist()
|
||||
|
||||
# Define embedding function
|
||||
gemini_ef = GeminiEmbeddingFunction(
|
||||
api_key=args.gemini_api_key,
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
model_name="text-embedding-004"
|
||||
)
|
||||
model_name = 'paraphrase-multilingual-mpnet-base-v2'
|
||||
embedding_function = SentenceTransformerEmbeddingFunction(model_name)
|
||||
|
||||
# Create database, delete the database directory if it exists
|
||||
db_folder = "chroma_db"
|
||||
@ -95,7 +102,7 @@ def go(args):
|
||||
|
||||
chroma_client = chromadb.PersistentClient(path=db_path)
|
||||
collection_name = "rag_experiment"
|
||||
db = chroma_client.create_collection(name=collection_name, embedding_function=gemini_ef)
|
||||
db = chroma_client.create_collection(name=collection_name, embedding_function=embedding_function)
|
||||
|
||||
|
||||
logger.info("Downloading artifact")
|
||||
@ -126,7 +133,6 @@ def go(args):
|
||||
metadatas=[{"filename": file}],
|
||||
ids=[file[:-4] + str(i)])
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -161,10 +167,10 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--gemini_api_key",
|
||||
"--embedding_model",
|
||||
type=str,
|
||||
help="API key for the Gemini service",
|
||||
required=True
|
||||
default="paraphrase-multilingual-mpnet-base-v2",
|
||||
help="Sentence Transformer model name"
|
||||
)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user