mirror of
https://github.com/aimingmed/aimingmed-ai.git
synced 2026-02-04 22:25:34 +08:00
update
This commit is contained in:
parent
82435ebdbd
commit
7262180f84
@ -81,17 +81,10 @@ def go(args):
|
||||
# all_embeddings.append(np.array(embeddings[0]))
|
||||
# return all_embeddings
|
||||
|
||||
class SentenceTransformerEmbeddingFunction(object):
|
||||
def __init__(self, model_name: str):
|
||||
self.model = SentenceTransformer(model_name)
|
||||
|
||||
def __call__(self, input: List[str]) -> List[List[float]]:
|
||||
embeddings = self.model.encode(input)
|
||||
return embeddings.tolist()
|
||||
# Initialize embedding model (do this ONCE)
|
||||
model_embedding = SentenceTransformer('all-mpnet-base-v2') # Or a multilingual model
|
||||
|
||||
# Define embedding function
|
||||
model_name = 'paraphrase-multilingual-mpnet-base-v2'
|
||||
embedding_function = SentenceTransformerEmbeddingFunction(model_name)
|
||||
|
||||
# Create database, delete the database directory if it exists
|
||||
db_folder = "chroma_db"
|
||||
@ -102,7 +95,7 @@ def go(args):
|
||||
|
||||
chroma_client = chromadb.PersistentClient(path=db_path)
|
||||
collection_name = "rag_experiment"
|
||||
db = chroma_client.create_collection(name=collection_name, embedding_function=embedding_function)
|
||||
db = chroma_client.create_collection(name=collection_name)
|
||||
|
||||
|
||||
logger.info("Downloading artifact")
|
||||
@ -131,8 +124,9 @@ def go(args):
|
||||
for i, split in enumerate(all_splits):
|
||||
db.add(documents=[split.page_content],
|
||||
metadatas=[{"filename": file}],
|
||||
ids=[file[:-4] + str(i)])
|
||||
|
||||
ids=[f'{file[:-4]}-{str(i)}'],
|
||||
embeddings=[model_embedding.encode(split.page_content)]
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user