mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-28 17:53:33 +08:00
add BaseRetrieverService, VectorstoreRetrieverService, EnsembleRetrieverService
This commit is contained in:
parent
6e9e31a32c
commit
00869bd839
@ -49,7 +49,9 @@ def search_local_knowledgebase(
|
||||
database: str = Field(description="Database for Knowledge Search", choices=list_kbs().data),
|
||||
query: str = Field(description="Query for Knowledge Search"),
|
||||
):
|
||||
''''''
|
||||
"""
|
||||
本地知识库检索
|
||||
"""
|
||||
tool_config = get_tool_config("search_local_knowledgebase")
|
||||
ret = search_knowledgebase(query=query, database=database, config=tool_config)
|
||||
return KBToolOutput(ret, database=database)
|
||||
|
||||
@ -0,0 +1,3 @@
|
||||
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
|
||||
from chatchat.server.file_rag.retrievers.vectorstore import VectorstoreRetrieverService
|
||||
from chatchat.server.file_rag.retrievers.ensemble import EnsembleRetrieverService
|
||||
24
chatchat-server/chatchat/server/file_rag/retrievers/base.py
Normal file
24
chatchat-server/chatchat/server/file_rag/retrievers/base.py
Normal file
@ -0,0 +1,24 @@
|
||||
from langchain.vectorstores import VectorStore
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
|
||||
class BaseRetrieverService(metaclass=ABCMeta):
|
||||
def __init__(self, **kwargs):
|
||||
self.do_init(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def do_init(self, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def from_vectorstore(
|
||||
vectorstore: VectorStore,
|
||||
top_k: int,
|
||||
score_threshold: int or float,
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_related_documents(self, query: str):
|
||||
pass
|
||||
@ -0,0 +1,45 @@
|
||||
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
|
||||
from langchain.vectorstores import VectorStore
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain.retrievers import BM25Retriever, EnsembleRetriever
|
||||
|
||||
|
||||
class EnsembleRetrieverService(BaseRetrieverService):
|
||||
def do_init(
|
||||
self,
|
||||
retriever: BaseRetriever = None,
|
||||
top_k: int = 5,
|
||||
):
|
||||
self.vs = None
|
||||
self.top_k = top_k
|
||||
self.retriever = None
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_vectorstore(
|
||||
vectorstore: VectorStore,
|
||||
top_k: int,
|
||||
score_threshold: int or float,
|
||||
):
|
||||
faiss_retriever = vectorstore.as_retriever(
|
||||
search_type="similarity_score_threshold",
|
||||
search_kwargs={
|
||||
"score_threshold": score_threshold,
|
||||
"k": top_k
|
||||
}
|
||||
)
|
||||
from cutword import Cutter
|
||||
cutter = Cutter()
|
||||
docs = list(vectorstore.docstore._dict.values())
|
||||
bm25_retriever = BM25Retriever.from_documents(
|
||||
docs,
|
||||
preprocess_func=cutter.cutword
|
||||
)
|
||||
bm25_retriever.k = top_k
|
||||
ensemble_retriever = EnsembleRetriever(
|
||||
retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
|
||||
)
|
||||
return EnsembleRetrieverService(retriever=ensemble_retriever)
|
||||
|
||||
def get_related_documents(self, query: str):
|
||||
self.retriever.get_relevant_documents(query)[:self.top_k]
|
||||
@ -0,0 +1,33 @@
|
||||
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
|
||||
from langchain.vectorstores import VectorStore
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
class VectorstoreRetrieverService(BaseRetrieverService):
|
||||
def do_init(
|
||||
self,
|
||||
retriever: BaseRetriever = None,
|
||||
top_k: int = 5,
|
||||
):
|
||||
self.vs = None
|
||||
self.top_k = top_k
|
||||
self.retriever = None
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_vectorstore(
|
||||
vectorstore: VectorStore,
|
||||
top_k: int,
|
||||
score_threshold: int or float,
|
||||
):
|
||||
retriever = vectorstore.as_retriever(
|
||||
search_type="similarity_score_threshold",
|
||||
search_kwargs={
|
||||
"score_threshold": score_threshold,
|
||||
"k": top_k
|
||||
}
|
||||
)
|
||||
return VectorstoreRetrieverService(retriever=retriever)
|
||||
|
||||
def get_related_documents(self, query: str):
|
||||
self.retriever.get_relevant_documents(query)[:self.top_k]
|
||||
8
chatchat-server/chatchat/server/file_rag/utils.py
Normal file
8
chatchat-server/chatchat/server/file_rag/utils.py
Normal file
@ -0,0 +1,8 @@
|
||||
from chatchat.server.file_rag.retrievers import BaseRetrieverService, VectorstoreRetrieverService
|
||||
|
||||
Retrivals = {
|
||||
"vectorstore": VectorstoreRetrieverService,
|
||||
}
|
||||
|
||||
def get_Retriever(type: str = "vectorstore") -> BaseRetrieverService:
|
||||
return Retrivals[type]
|
||||
@ -8,6 +8,7 @@ from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get
|
||||
from chatchat.server.utils import get_Embeddings
|
||||
from langchain.docstore.document import Document
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from chatchat.server.file_rag.utils import get_Retriever
|
||||
|
||||
|
||||
class FaissKBService(KBService):
|
||||
@ -62,10 +63,15 @@ class FaissKBService(KBService):
|
||||
top_k: int,
|
||||
score_threshold: float = SCORE_THRESHOLD,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
embed_func = get_Embeddings(self.embed_model)
|
||||
embeddings = embed_func.embed_query(query)
|
||||
with self.load_vector_store().acquire() as vs:
|
||||
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
|
||||
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||
vs,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
docs = retriever.get_relevant_documents(query)
|
||||
|
||||
# docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
|
||||
return docs
|
||||
|
||||
def do_add_doc(self,
|
||||
|
||||
@ -11,7 +11,7 @@ from chatchat.configs import (
|
||||
TEXT_SPLITTER_NAME,
|
||||
)
|
||||
import importlib
|
||||
from chatchat.server.text_splitter import zh_title_enhance as func_zh_title_enhance
|
||||
from chatchat.server.file_rag.text_splitter import zh_title_enhance as func_zh_title_enhance
|
||||
import langchain_community.document_loaders
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user