mirror of https://github.com/aimingmed/aimingmed-ai.git
commit 82435ebdbd (parent 702b1eb874): update
@@ -6,4 +6,5 @@ main:
 etl:
   document_folder: "documents"
   path_document_folder: "../../../../data"
+  embedding_model: "paraphrase-multilingual-mpnet-base-v2"
 
@@ -64,7 +64,7 @@ def go(config: DictConfig):
                 "output_artifact": "chromdb.zip",
                 "output_type": "chromdb",
                 "output_description": "Scanned Documents in pdf to be read amd stored in chromdb",
-                "gemini_api_key": GEMINI_API_KEY,
+                "embedding_model": config["etl"]["embedding_model"]
             },
         )
 
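Taken together, the two hunks above route the new etl.embedding_model key from the Hydra config into the ETL step as an MLflow run parameter, replacing the Gemini API key. A minimal sketch of that flow, assuming an OmegaConf/Hydra config shaped like the one in this commit (the inline config dict below is illustrative, not copied from the repo):

# Hypothetical sketch of the config-to-parameter flow added in this commit.
from omegaconf import OmegaConf

config = OmegaConf.create(
    {
        "etl": {
            "document_folder": "documents",
            "path_document_folder": "../../../../data",
            "embedding_model": "paraphrase-multilingual-mpnet-base-v2",
        }
    }
)

# Mirrors the parameters dict passed to the ETL step (output_* fields elided).
parameters = {
    "embedding_model": config["etl"]["embedding_model"],
}
print(parameters)  # {'embedding_model': 'paraphrase-multilingual-mpnet-base-v2'}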
@@ -21,8 +21,8 @@ entry_points:
         description: Description for the artifact
         type: string
 
-      gemini_api_key:
-        description: API key for Gemini
+      embedding_model:
+        description: Fully-qualified name for the embedding model
         type: string
 
 
@@ -31,4 +31,4 @@ entry_points:
         --output_artifact {output_artifact} \
         --output_type {output_type} \
         --output_description {output_description} \
-        --gemini_api_key {gemini_api_key}
+        --embedding_model {embedding_model}
@@ -10,6 +10,7 @@ build_dependencies:
   - pytesseract
   - pdf2image
   - langchain
+  - sentence_transformers
 # Dependencies required to run the project.
 dependencies:
   - mlflow==2.8.1
@@ -9,7 +9,7 @@ import wandb
 import shutil
 
 import chromadb
-from openai import OpenAI
+# from openai import OpenAI
 from typing import List
 import numpy as np
 import pytesseract as pt
@@ -17,10 +17,12 @@ from pdf2image import convert_from_path
 from langchain.schema import Document
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 
+from sentence_transformers import SentenceTransformer
 
 logging.basicConfig(level=logging.INFO, format="%(asctime)-15s %(message)s")
 logger = logging.getLogger()
 
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 def extract_text_from_pdf_ocr(pdf_path):
     try:
@@ -50,41 +52,46 @@ def go(args):
     run.config.update(args)
 
     # Setup the Gemini client
-    client = OpenAI(
-        api_key=args.gemini_api_key,
-        base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
-    )
+    # client = OpenAI(
+    #     api_key=args.gemini_api_key,
+    #     base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
+    # )
 
 
-    def get_google_embedding(text: str) -> List[float]:
-        response = client.embeddings.create(
-            model="text-embedding-004",
-            input=text
-        )
-        return response.data[0].embedding
+    # def get_google_embedding(text: str) -> List[float]:
+    #     response = client.embeddings.create(
+    #         model="text-embedding-004",
+    #         input=text
+    #     )
+    #     return response.data[0].embedding
 
-    class GeminiEmbeddingFunction(object):
-        def __init__(self, api_key: str, base_url: str, model_name: str):
-            self.client = OpenAI(
-                api_key=args.gemini_api_key,
-                base_url=base_url
-            )
-            self.model_name = model_name
+    # class GeminiEmbeddingFunction(object):
+    #     def __init__(self, api_key: str, base_url: str, model_name: str):
+    #         self.client = OpenAI(
+    #             api_key=args.gemini_api_key,
+    #             base_url=base_url
+    #         )
+    #         self.model_name = model_name
 
+    #     def __call__(self, input: List[str]) -> List[List[float]]:
+    #         all_embeddings = []
+    #         for text in input:
+    #             response = self.client.embeddings.create(input=text, model=self.model_name)
+    #             embeddings = [record.embedding for record in response.data]
+    #             all_embeddings.append(np.array(embeddings[0]))
+    #         return all_embeddings
+
+    class SentenceTransformerEmbeddingFunction(object):
+        def __init__(self, model_name: str):
+            self.model = SentenceTransformer(model_name)
+
         def __call__(self, input: List[str]) -> List[List[float]]:
-            all_embeddings = []
-            for text in input:
-                response = self.client.embeddings.create(input=text, model=self.model_name)
-                embeddings = [record.embedding for record in response.data]
-                all_embeddings.append(np.array(embeddings[0]))
-            return all_embeddings
+            embeddings = self.model.encode(input)
+            return embeddings.tolist()
 
     # Define embedding function
-    gemini_ef = GeminiEmbeddingFunction(
-        api_key=args.gemini_api_key,
-        base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
-        model_name="text-embedding-004"
-    )
+    model_name = 'paraphrase-multilingual-mpnet-base-v2'
+    embedding_function = SentenceTransformerEmbeddingFunction(model_name)
 
     # Create database, delete the database directory if it exists
     db_folder = "chroma_db"
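For a quick, pipeline-independent check of the new embedding callable, the class can be exercised on its own. A minimal sketch, assuming sentence-transformers is installed and the model can be downloaded; the sample strings are made up:

# Hypothetical smoke test for the SentenceTransformer-backed embedding function above.
from typing import List

from sentence_transformers import SentenceTransformer


class SentenceTransformerEmbeddingFunction(object):
    def __init__(self, model_name: str):
        self.model = SentenceTransformer(model_name)

    def __call__(self, input: List[str]) -> List[List[float]]:
        # encode() returns an array of shape (len(input), dim); Chroma expects plain lists
        return self.model.encode(input).tolist()


if __name__ == "__main__":
    ef = SentenceTransformerEmbeddingFunction("paraphrase-multilingual-mpnet-base-v2")
    vectors = ef(["扫描文档中的一段文字", "a sentence extracted from a scanned PDF"])
    print(len(vectors), len(vectors[0]))  # 2 vectors, 768 dimensions for this mpnet model

Keeping the parameter named input also matters: newer chromadb releases check that a custom embedding function is callable as ef(input), which this duck-typed class satisfies.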
@@ -95,7 +102,7 @@ def go(args):
 
     chroma_client = chromadb.PersistentClient(path=db_path)
     collection_name = "rag_experiment"
-    db = chroma_client.create_collection(name=collection_name, embedding_function=gemini_ef)
+    db = chroma_client.create_collection(name=collection_name, embedding_function=embedding_function)
 
 
     logger.info("Downloading artifact")
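As a design note, chromadb also ships a built-in sentence-transformers wrapper that could stand in for the hand-rolled class. A sketch of the equivalent wiring plus a query, assuming a chromadb version that provides chromadb.utils.embedding_functions (the path, ids, and documents here are illustrative):

# Hypothetical alternative using chromadb's bundled sentence-transformers wrapper.
import chromadb
from chromadb.utils import embedding_functions

ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="paraphrase-multilingual-mpnet-base-v2"
)

client = chromadb.PersistentClient(path="chroma_db")  # placeholder path
collection = client.get_or_create_collection(name="rag_experiment", embedding_function=ef)

# Documents and queries pass through the same multilingual embedding model.
collection.add(documents=["第一页的扫描文字", "text recovered from page two"], ids=["doc0", "doc1"])
hits = collection.query(query_texts=["page two"], n_results=1)
print(hits["ids"])  # e.g. [['doc1']]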
@@ -127,7 +134,6 @@ def go(args):
                    ids=[file[:-4] + str(i)])
 
 
-
 if __name__ == "__main__":
 
     parser = argparse.ArgumentParser(description="A very basic data cleaning")
@@ -161,10 +167,10 @@ if __name__ == "__main__":
     )
 
     parser.add_argument(
-        "--gemini_api_key",
+        "--embedding_model",
         type=str,
-        help="API key for the Gemini service",
-        required=True
+        default="paraphrase-multilingual-mpnet-base-v2",
+        help="Sentence Transformer model name"
     )
 
 
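One detail worth flagging: in the hunk that builds the embedding function, the model name appears hard-coded to 'paraphrase-multilingual-mpnet-base-v2', so the new --embedding_model flag (and the {embedding_model} substitution in MLproject) may not actually take effect inside go(). A minimal sketch of reading the flag instead, assuming no other (unshown) hunk already wires it through; the commented line marks where it would plug in:

# Hypothetical sketch: let the CLI flag drive the model choice rather than a literal.
import argparse

parser = argparse.ArgumentParser(description="A very basic data cleaning")
parser.add_argument(
    "--embedding_model",
    type=str,
    default="paraphrase-multilingual-mpnet-base-v2",
    help="Sentence Transformer model name",
)
args = parser.parse_args([])  # no flags supplied: falls back to the default

# inside go(args), instead of model_name = 'paraphrase-multilingual-mpnet-base-v2':
# embedding_function = SentenceTransformerEmbeddingFunction(args.embedding_model)
print(args.embedding_model)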