Support Chroma

This commit is contained in:
showmecodett 2024-01-28 21:42:03 +08:00
parent a5e758bf82
commit c94938bc83
5 changed files with 121 additions and 3 deletions

View File

@ -3,7 +3,7 @@ import os
# 默认使用的知识库
DEFAULT_KNOWLEDGE_BASE = "samples"
# 默认向量库/全文检索引擎类型。可选faiss, milvus(离线) & zilliz(在线), pgvector,全文检索引擎es
# 默认向量库/全文检索引擎类型。可选faiss, milvus(离线) & zilliz(在线), pgvector, 全文检索引擎es, chromadb
DEFAULT_VS_TYPE = "faiss"
# 缓存向量库数量针对FAISS
@ -110,7 +110,8 @@ kbs_config = {
"milvus_kwargs":{
"search_params":{"metric_type": "L2"}, #在此处增加search_params
"index_params":{"metric_type": "L2","index_type": "HNSW"} # 在此处增加index_params
}
},
"chromadb": {}
}
# TextSplitter配置项如果你不明白其中的含义就不要修改。

View File

@ -47,6 +47,7 @@ llama-index==0.9.35
# pymilvus==2.3.4
# psycopg2==2.9.9
# pgvector==0.2.4
# chromadb==0.4.13
#flash-attn==2.4.2 # For Orion-14B-Chat and Qwen-14B-Chat
#autoawq==0.1.8 # For Int4
#rapidocr_paddle[gpu]==1.3.11 # gpu accelleration for ocr of pdf and image files

View File

@ -30,4 +30,5 @@ watchdog~=3.0.0
# volcengine>=1.0.119
# pymilvus>=2.3.4
# psycopg2==2.9.9
# pgvector>=0.2.4
# pgvector>=0.2.4
# chromadb==0.4.13

View File

@ -47,6 +47,7 @@ class SupportedVSType:
ZILLIZ = 'zilliz'
PG = 'pg'
ES = 'es'
CHROMADB = 'chromadb'
class KBService(ABC):
@ -319,6 +320,9 @@ class KBServiceFactory:
elif SupportedVSType.ES == vector_store_type:
from server.knowledge_base.kb_service.es_kb_service import ESKBService
return ESKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.CHROMADB == vector_store_type:
from server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
return ChromaKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
return DefaultKBService(kb_name)

View File

@ -0,0 +1,111 @@
import uuid
from typing import Any, Dict, List, Optional, Tuple
import chromadb
from chromadb.api.types import (ID, CollectionMetadata, Embedding,
EmbeddingFunction, GetResult, IDs, Include,
Metadata, OneOrMany, QueryResult, Where)
from langchain.docstore.document import Document
from langchain_community.vectorstores.chroma import Chroma
from configs import SCORE_THRESHOLD, logger
from server.knowledge_base.kb_service.base import (EmbeddingsFunAdapter,
KBService, SupportedVSType,
score_threshold_process)
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
def _get_result_to_documents(get_result: GetResult) -> List[Document]:
if not get_result['documents']:
return []
_metadatas = get_result['metadatas'] if get_result['metadatas'] else [{}] * len(get_result['documents'])
document_list = []
for page_content, metadata in zip(get_result['documents'], _metadatas):
document_list.append(Document(**{'page_content': page_content, 'metadata': metadata}))
return document_list
def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
"""
from langchain_community.vectorstores.chroma import Chroma
"""
return [
# TODO: Chroma can do batch querying,
# we shouldn't hard code to the 1st result
(Document(page_content=result[0], metadata=result[1] or {}), result[2])
for result in zip(
results["documents"][0],
results["metadatas"][0],
results["distances"][0],
)
]
class ChromaKBService(KBService):
vs_path: str
kb_path: str
client = None
collection = None
def vs_type(self) -> str:
return SupportedVSType.CHROMADB
def get_vs_path(self) -> str:
return get_vs_path(self.kb_name, self.embed_model)
def get_kb_path(self) -> str:
return get_kb_path(self.kb_name)
def do_init(self) -> None:
self.kb_path = self.get_kb_path()
self.vs_path = self.get_vs_path()
self.client = chromadb.PersistentClient(path=self.vs_path)
self.collection = self.client.get_or_create_collection(self.kb_name)
def do_create_kb(self) -> None:
# In ChromaDB, creating a KB is equivalent to creating a collection
self.collection = self.client.get_or_create_collection(self.kb_name)
def do_drop_kb(self):
# Dropping a KB is equivalent to deleting a collection in ChromaDB
try:
self.client.delete_collection(self.kb_name)
except ValueError as e:
if not str(e) == f"Collection {self.kb_name} does not exist.":
raise e
def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[Tuple[Document, float]]:
embed_func = EmbeddingsFunAdapter(self.embed_model)
embeddings = embed_func.embed_query(query)
query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k)
return _results_to_docs_and_scores(query_result)
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
doc_infos = []
data = self._docs_to_embeddings(docs)
print(data)
ids = [str(uuid.uuid1()) for _ in range(len(data["texts"]))]
for _id, text, embedding, metadata in zip(ids, data["texts"], data["embeddings"], data["metadatas"]):
self.collection.add(ids=_id, embeddings=embedding, metadatas=metadata, documents=text)
doc_infos.append({"id": _id, "metadata": metadata})
return doc_infos
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
get_result: GetResult = self.collection.get(ids=ids)
return _get_result_to_documents(get_result)
def del_doc_by_ids(self, ids: List[str]) -> bool:
self.collection.delete(ids=ids)
return True
def do_clear_vs(self):
# Clearing the vector store might be equivalent to dropping and recreating the collection
self.do_drop_kb()
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
return self.collection.delete(where={"source": kb_file.filepath})