Merge pull request #2820 from showmecodett/feat-chromadb

支持Chroma向量数据库
This commit is contained in:
zR 2024-01-29 20:59:15 +08:00 committed by GitHub
commit 4f5824e964
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 117 additions and 3 deletions

View File

@ -3,7 +3,7 @@ import os
# 默认使用的知识库
DEFAULT_KNOWLEDGE_BASE = "samples"
# 默认向量库/全文检索引擎类型。可选faiss, milvus(离线) & zilliz(在线), pgvector,全文检索引擎es
# 默认向量库/全文检索引擎类型。可选faiss, milvus(离线) & zilliz(在线), pgvector, 全文检索引擎es, chromadb
DEFAULT_VS_TYPE = "faiss"
# 缓存向量库数量针对FAISS
@ -110,7 +110,8 @@ kbs_config = {
"milvus_kwargs":{
"search_params":{"metric_type": "L2"}, #在此处增加search_params
"index_params":{"metric_type": "L2","index_type": "HNSW"} # 在此处增加index_params
}
},
"chromadb": {}
}
# TextSplitter配置项如果你不明白其中的含义就不要修改。

View File

@ -47,6 +47,7 @@ llama-index==0.9.35
# pymilvus==2.3.4
# psycopg2==2.9.9
# pgvector==0.2.4
# chromadb==0.4.13
#flash-attn==2.4.2 # For Orion-14B-Chat and Qwen-14B-Chat
#autoawq==0.1.8 # For Int4
#rapidocr_paddle[gpu]==1.3.11 # gpu acceleration for ocr of pdf and image files

View File

@ -30,4 +30,5 @@ watchdog~=3.0.0
# volcengine>=1.0.119
# pymilvus>=2.3.4
# psycopg2==2.9.9
# pgvector>=0.2.4
# pgvector>=0.2.4
# chromadb==0.4.13

View File

@ -47,6 +47,7 @@ class SupportedVSType:
ZILLIZ = 'zilliz'
PG = 'pg'
ES = 'es'
CHROMADB = 'chromadb'
class KBService(ABC):
@ -319,6 +320,9 @@ class KBServiceFactory:
elif SupportedVSType.ES == vector_store_type:
from server.knowledge_base.kb_service.es_kb_service import ESKBService
return ESKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.CHROMADB == vector_store_type:
from server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
return ChromaKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
return DefaultKBService(kb_name)

View File

@ -0,0 +1,107 @@
import uuid
from typing import Any, Dict, List, Tuple
import chromadb
from chromadb.api.types import (GetResult, QueryResult)
from langchain.docstore.document import Document
from configs import SCORE_THRESHOLD
from server.knowledge_base.kb_service.base import (EmbeddingsFunAdapter,
KBService, SupportedVSType)
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
def _get_result_to_documents(get_result: GetResult) -> List[Document]:
    """Convert a ChromaDB ``GetResult`` into a list of langchain ``Document``s.

    Returns an empty list when the result contains no documents. A
    ``GetResult`` may carry ``None`` entries inside ``metadatas``; each
    per-item metadata is normalized to ``{}`` so ``Document`` always
    receives a dict (consistent with ``_results_to_docs_and_scores``).
    """
    docs = get_result['documents']
    if not docs:
        return []
    # Whole-list fallback when Chroma returned no metadatas at all.
    metadatas = get_result['metadatas'] or [{}] * len(docs)
    return [
        Document(page_content=page_content, metadata=metadata or {})
        for page_content, metadata in zip(docs, metadatas)
    ]
def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
    """Convert a ChromaDB ``QueryResult`` into (Document, distance) pairs.

    Adapted from ``langchain_community.vectorstores.chroma.Chroma``.
    """
    # TODO: Chroma can do batch querying,
    # we shouldn't hard code to the 1st result
    docs_and_scores = []
    triples = zip(
        results["documents"][0],
        results["metadatas"][0],
        results["distances"][0],
    )
    for content, metadata, distance in triples:
        doc = Document(page_content=content, metadata=metadata or {})
        docs_and_scores.append((doc, distance))
    return docs_and_scores
class ChromaKBService(KBService):
    """Knowledge-base service backed by a persistent ChromaDB collection.

    One Chroma collection, named after the knowledge base, stores the
    documents, their embeddings and their metadata under ``vs_path``.
    """

    vs_path: str   # on-disk path of the persistent Chroma store
    kb_path: str   # on-disk path of the knowledge base itself

    client = None      # chromadb.PersistentClient, set in do_init()
    collection = None  # active chroma collection, set in do_init()

    def vs_type(self) -> str:
        # Identifies this backend to KBServiceFactory dispatch.
        return SupportedVSType.CHROMADB

    def get_vs_path(self) -> str:
        return get_vs_path(self.kb_name, self.embed_model)

    def get_kb_path(self) -> str:
        return get_kb_path(self.kb_name)

    def do_init(self) -> None:
        self.kb_path = self.get_kb_path()
        self.vs_path = self.get_vs_path()
        self.client = chromadb.PersistentClient(path=self.vs_path)
        self.collection = self.client.get_or_create_collection(self.kb_name)

    def do_create_kb(self) -> None:
        # In ChromaDB, creating a KB is equivalent to creating a collection.
        self.collection = self.client.get_or_create_collection(self.kb_name)

    def do_drop_kb(self):
        # Dropping a KB is equivalent to deleting a collection in ChromaDB.
        try:
            self.client.delete_collection(self.kb_name)
        except ValueError as e:
            # Deleting a collection that never existed is not an error here.
            if not str(e) == f"Collection {self.kb_name} does not exist.":
                raise e

    def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[Tuple[Document, float]]:
        """Embed *query* and return up to *top_k* (Document, distance) pairs.

        NOTE(review): ``score_threshold`` is accepted for interface
        compatibility but is not applied — Chroma returns raw distances
        here; confirm the intended threshold semantics before filtering.
        """
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k)
        return _results_to_docs_and_scores(query_result)

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Embed *docs* and add them to the collection in one batched call.

        Returns a list of ``{"id": ..., "metadata": ...}`` entries, one per
        added document, in input order.
        """
        data = self._docs_to_embeddings(docs)  # {"texts", "embeddings", "metadatas"}
        texts = data["texts"]
        if not texts:
            # Chroma rejects an empty add; nothing to do.
            return []
        ids = [str(uuid.uuid1()) for _ in range(len(texts))]
        # Single batched add instead of one round-trip per document.
        self.collection.add(ids=ids,
                            embeddings=data["embeddings"],
                            metadatas=data["metadatas"],
                            documents=texts)
        return [{"id": _id, "metadata": metadata}
                for _id, metadata in zip(ids, data["metadatas"])]

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        get_result: GetResult = self.collection.get(ids=ids)
        return _get_result_to_documents(get_result)

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        self.collection.delete(ids=ids)
        return True

    def do_clear_vs(self):
        # Clearing the vector store = dropping and recreating the collection.
        # Recreating is required so ``self.collection`` stays a usable handle
        # (the drop alone would leave it pointing at a deleted collection).
        self.do_drop_kb()
        self.collection = self.client.get_or_create_collection(self.kb_name)

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        # Remove every chunk whose metadata "source" matches the file path.
        return self.collection.delete(where={"source": kb_file.filepath})