This commit is contained in:
leehk 2025-02-22 19:07:22 +08:00
parent 82435ebdbd
commit 7262180f84

View File

@ -81,17 +81,10 @@ def go(args):
# all_embeddings.append(np.array(embeddings[0]))
# return all_embeddings
class SentenceTransformerEmbeddingFunction(object):
def __init__(self, model_name: str):
self.model = SentenceTransformer(model_name)
def __call__(self, input: List[str]) -> List[List[float]]:
embeddings = self.model.encode(input)
return embeddings.tolist()
# Initialize embedding model (do this ONCE)
model_embedding = SentenceTransformer('all-mpnet-base-v2') # Or a multilingual model
# Define embedding function
model_name = 'paraphrase-multilingual-mpnet-base-v2'
embedding_function = SentenceTransformerEmbeddingFunction(model_name)
# Create database, delete the database directory if it exists
db_folder = "chroma_db"
@ -102,7 +95,7 @@ def go(args):
chroma_client = chromadb.PersistentClient(path=db_path)
collection_name = "rag_experiment"
db = chroma_client.create_collection(name=collection_name, embedding_function=embedding_function)
db = chroma_client.create_collection(name=collection_name)
logger.info("Downloading artifact")
@ -131,8 +124,9 @@ def go(args):
for i, split in enumerate(all_splits):
db.add(documents=[split.page_content],
metadatas=[{"filename": file}],
ids=[file[:-4] + str(i)])
ids=[f'{file[:-4]}-{str(i)}'],
embeddings=[model_embedding.encode(split.page_content)]
)
if __name__ == "__main__":