add BaseRetrieverService, VectorstoreRetrieverService, EnsembleRetrieverService

This commit is contained in:
imClumsyPanda 2024-03-30 00:17:48 +08:00
parent 6e9e31a32c
commit 00869bd839
21 changed files with 126 additions and 5 deletions

View File

@ -49,7 +49,9 @@ def search_local_knowledgebase(
database: str = Field(description="Database for Knowledge Search", choices=list_kbs().data),
query: str = Field(description="Query for Knowledge Search"),
):
''''''
"""
本地知识库检索
"""
tool_config = get_tool_config("search_local_knowledgebase")
ret = search_knowledgebase(query=query, database=database, config=tool_config)
return KBToolOutput(ret, database=database)

View File

@ -0,0 +1,3 @@
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
from chatchat.server.file_rag.retrievers.vectorstore import VectorstoreRetrieverService
from chatchat.server.file_rag.retrievers.ensemble import EnsembleRetrieverService

View File

@ -0,0 +1,24 @@
from langchain.vectorstores import VectorStore
from abc import ABCMeta, abstractmethod
class BaseRetrieverService(metaclass=ABCMeta):
    """Abstract interface for the retriever service wrappers.

    The constructor forwards every keyword argument to :meth:`do_init`, so
    subclasses configure themselves there instead of overriding ``__init__``.
    """

    def __init__(self, **kwargs):
        self.do_init(**kwargs)

    @abstractmethod
    def do_init(self, **kwargs):
        """Initialise the instance from the constructor's keyword arguments."""

    # The signature has no ``self`` and the concrete services implement this
    # as a static factory, so the abstract declaration is static as well.
    # ``@staticmethod`` must be the outermost decorator when combined with
    # ``@abstractmethod``.
    @staticmethod
    @abstractmethod
    def from_vectorstore(
        vectorstore: VectorStore,
        top_k: int,
        # ``int or float`` evaluates to plain ``int``; ``float`` also admits
        # ints for type checkers.
        score_threshold: float,
    ):
        """Build a retriever service on top of ``vectorstore``.

        Args:
            vectorstore: The vector store to search.
            top_k: Maximum number of documents to retrieve.
            score_threshold: Similarity score threshold for retrieval.
        """

    @abstractmethod
    def get_related_documents(self, query: str):
        """Return the documents related to ``query``."""

    def get_relevant_documents(self, query: str):
        """Return the documents related to ``query``.

        Alias of :meth:`get_related_documents`. ``FaissKBService.do_search``
        calls the service through this name, which would otherwise raise
        ``AttributeError``.
        """
        return self.get_related_documents(query)

View File

@ -0,0 +1,45 @@
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
from langchain.vectorstores import VectorStore
from langchain_core.retrievers import BaseRetriever
from langchain.retrievers import BM25Retriever, EnsembleRetriever
class EnsembleRetrieverService(BaseRetrieverService):
    """Retriever service that combines BM25 keyword search with vector search."""

    def do_init(
        self,
        retriever: BaseRetriever = None,
        top_k: int = 5,
    ):
        """Store the wrapped retriever and the result limit.

        Args:
            retriever: The retriever that queries are delegated to.
            top_k: Maximum number of documents returned per query.
        """
        self.vs = None
        self.top_k = top_k
        # Keep the retriever that was passed in. The previous code assigned
        # None here, so every query failed on a None retriever.
        self.retriever = retriever

    @staticmethod
    def from_vectorstore(
        vectorstore: VectorStore,
        top_k: int,
        # ``int or float`` evaluates to plain ``int``; ``float`` also admits
        # ints for type checkers.
        score_threshold: float,
    ):
        """Build an ensemble (BM25 + vector similarity) service.

        Args:
            vectorstore: Vector store providing both the similarity retriever
                and the documents indexed by BM25.
            top_k: Maximum number of documents each retriever returns, and
                the limit applied to the merged result.
            score_threshold: Score threshold passed to the vector retriever.

        Returns:
            An ``EnsembleRetrieverService`` wrapping an ``EnsembleRetriever``
            that weights both retrievers equally.
        """
        faiss_retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "score_threshold": score_threshold,
                "k": top_k,
            },
        )

        # Imported lazily so the dependency is only needed when an ensemble
        # retriever is actually built.
        from cutword import Cutter

        cutter = Cutter()
        # NOTE(review): reads the docstore's private ``_dict`` mapping, so this
        # presumably only works with in-memory docstores (e.g. FAISS) - confirm
        # before using with other vector stores.
        docs = list(vectorstore.docstore._dict.values())
        bm25_retriever = BM25Retriever.from_documents(
            docs,
            preprocess_func=cutter.cutword,
        )
        bm25_retriever.k = top_k

        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever],
            weights=[0.5, 0.5],
        )
        # Forward top_k so the service does not silently fall back to the
        # default of 5.
        return EnsembleRetrieverService(
            retriever=ensemble_retriever,
            top_k=top_k,
        )

    def get_related_documents(self, query: str):
        """Return at most ``top_k`` documents retrieved for ``query``."""
        # The original computed this list but never returned it.
        return self.retriever.get_relevant_documents(query)[:self.top_k]

View File

@ -0,0 +1,33 @@
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
from langchain.vectorstores import VectorStore
from langchain_core.retrievers import BaseRetriever
class VectorstoreRetrieverService(BaseRetrieverService):
    """Retriever service backed by a vector store similarity retriever."""

    def do_init(
        self,
        retriever: BaseRetriever = None,
        top_k: int = 5,
    ):
        """Store the wrapped retriever and the result limit.

        Args:
            retriever: The retriever that queries are delegated to.
            top_k: Maximum number of documents returned per query.
        """
        self.vs = None
        self.top_k = top_k
        # Keep the retriever that was passed in. The previous code assigned
        # None here, so every query failed on a None retriever.
        self.retriever = retriever

    @staticmethod
    def from_vectorstore(
        vectorstore: VectorStore,
        top_k: int,
        # ``int or float`` evaluates to plain ``int``; ``float`` also admits
        # ints for type checkers.
        score_threshold: float,
    ):
        """Build a service around ``vectorstore.as_retriever``.

        Args:
            vectorstore: Vector store to search.
            top_k: Maximum number of documents to retrieve.
            score_threshold: Score threshold passed to the retriever's
                ``similarity_score_threshold`` search.

        Returns:
            A ``VectorstoreRetrieverService`` wrapping the created retriever.
        """
        retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "score_threshold": score_threshold,
                "k": top_k,
            },
        )
        # Forward top_k so the service does not silently fall back to the
        # default of 5.
        return VectorstoreRetrieverService(retriever=retriever, top_k=top_k)

    def get_related_documents(self, query: str):
        """Return at most ``top_k`` documents retrieved for ``query``."""
        # The original computed this list but never returned it.
        return self.retriever.get_relevant_documents(query)[:self.top_k]

View File

@ -0,0 +1,8 @@
from chatchat.server.file_rag.retrievers import BaseRetrieverService, VectorstoreRetrieverService
# Registry mapping a retriever type name to the service class implementing it.
Retrivals = {
    "vectorstore": VectorstoreRetrieverService,
}


def get_Retriever(type: str = "vectorstore") -> "type[BaseRetrieverService]":
    """Look up the retriever service class registered under ``type``.

    The class itself is returned, not an instance; callers build an instance
    through the class, e.g. ``get_Retriever(...).from_vectorstore(...)``.

    Args:
        type: Key into ``Retrivals``. The name shadows the builtin but is
            kept because callers may pass it by keyword.

    Returns:
        The registered service class.

    Raises:
        KeyError: If ``type`` is not registered. The message lists the
            available type names.
    """
    try:
        return Retrivals[type]
    except KeyError:
        raise KeyError(
            f"Unknown retriever type {type!r}; "
            f"available types: {sorted(Retrivals)}"
        ) from None

View File

@ -8,6 +8,7 @@ from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get
from chatchat.server.utils import get_Embeddings
from langchain.docstore.document import Document
from typing import List, Dict, Optional, Tuple
from chatchat.server.file_rag.utils import get_Retriever
class FaissKBService(KBService):
@ -62,10 +63,15 @@ class FaissKBService(KBService):
top_k: int,
score_threshold: float = SCORE_THRESHOLD,
) -> List[Tuple[Document, float]]:
embed_func = get_Embeddings(self.embed_model)
embeddings = embed_func.embed_query(query)
with self.load_vector_store().acquire() as vs:
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
retriever = get_Retriever("vectorstore").from_vectorstore(
vs,
top_k=top_k,
score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
# docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
return docs
def do_add_doc(self,

View File

@ -11,7 +11,7 @@ from chatchat.configs import (
TEXT_SPLITTER_NAME,
)
import importlib
from chatchat.server.text_splitter import zh_title_enhance as func_zh_title_enhance
from chatchat.server.file_rag.text_splitter import zh_title_enhance as func_zh_title_enhance
import langchain_community.document_loaders
from langchain.docstore.document import Document
from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter