mirror of
https://github.com/aimingmed/aimingmed-ai.git
synced 2026-02-09 00:43:33 +08:00
update
This commit is contained in:
parent
82435ebdbd
commit
7262180f84
@ -81,17 +81,10 @@ def go(args):
|
|||||||
# all_embeddings.append(np.array(embeddings[0]))
|
# all_embeddings.append(np.array(embeddings[0]))
|
||||||
# return all_embeddings
|
# return all_embeddings
|
||||||
|
|
||||||
class SentenceTransformerEmbeddingFunction(object):
|
|
||||||
def __init__(self, model_name: str):
|
|
||||||
self.model = SentenceTransformer(model_name)
|
|
||||||
|
|
||||||
def __call__(self, input: List[str]) -> List[List[float]]:
|
# Initialize embedding model (do this ONCE)
|
||||||
embeddings = self.model.encode(input)
|
model_embedding = SentenceTransformer('all-mpnet-base-v2') # Or a multilingual model
|
||||||
return embeddings.tolist()
|
|
||||||
|
|
||||||
# Define embedding function
|
|
||||||
model_name = 'paraphrase-multilingual-mpnet-base-v2'
|
|
||||||
embedding_function = SentenceTransformerEmbeddingFunction(model_name)
|
|
||||||
|
|
||||||
# Create database, delete the database directory if it exists
|
# Create database, delete the database directory if it exists
|
||||||
db_folder = "chroma_db"
|
db_folder = "chroma_db"
|
||||||
@ -102,7 +95,7 @@ def go(args):
|
|||||||
|
|
||||||
chroma_client = chromadb.PersistentClient(path=db_path)
|
chroma_client = chromadb.PersistentClient(path=db_path)
|
||||||
collection_name = "rag_experiment"
|
collection_name = "rag_experiment"
|
||||||
db = chroma_client.create_collection(name=collection_name, embedding_function=embedding_function)
|
db = chroma_client.create_collection(name=collection_name)
|
||||||
|
|
||||||
|
|
||||||
logger.info("Downloading artifact")
|
logger.info("Downloading artifact")
|
||||||
@ -131,8 +124,9 @@ def go(args):
|
|||||||
for i, split in enumerate(all_splits):
|
for i, split in enumerate(all_splits):
|
||||||
db.add(documents=[split.page_content],
|
db.add(documents=[split.page_content],
|
||||||
metadatas=[{"filename": file}],
|
metadatas=[{"filename": file}],
|
||||||
ids=[file[:-4] + str(i)])
|
ids=[f'{file[:-4]}-{str(i)}'],
|
||||||
|
embeddings=[model_embedding.encode(split.page_content)]
|
||||||
|
)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user