mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-06 23:15:53 +08:00
add BaseRetrieverService, VectorstoreRetrieverService, EnsembleRetrieverService
This commit is contained in:
parent
6e9e31a32c
commit
00869bd839
@ -49,7 +49,9 @@ def search_local_knowledgebase(
|
|||||||
database: str = Field(description="Database for Knowledge Search", choices=list_kbs().data),
|
database: str = Field(description="Database for Knowledge Search", choices=list_kbs().data),
|
||||||
query: str = Field(description="Query for Knowledge Search"),
|
query: str = Field(description="Query for Knowledge Search"),
|
||||||
):
|
):
|
||||||
''''''
|
"""
|
||||||
|
本地知识库检索
|
||||||
|
"""
|
||||||
tool_config = get_tool_config("search_local_knowledgebase")
|
tool_config = get_tool_config("search_local_knowledgebase")
|
||||||
ret = search_knowledgebase(query=query, database=database, config=tool_config)
|
ret = search_knowledgebase(query=query, database=database, config=tool_config)
|
||||||
return KBToolOutput(ret, database=database)
|
return KBToolOutput(ret, database=database)
|
||||||
|
|||||||
@ -0,0 +1,3 @@
|
|||||||
|
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
|
||||||
|
from chatchat.server.file_rag.retrievers.vectorstore import VectorstoreRetrieverService
|
||||||
|
from chatchat.server.file_rag.retrievers.ensemble import EnsembleRetrieverService
|
||||||
24
chatchat-server/chatchat/server/file_rag/retrievers/base.py
Normal file
24
chatchat-server/chatchat/server/file_rag/retrievers/base.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from langchain.vectorstores import VectorStore
|
||||||
|
from abc import ABCMeta, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRetrieverService(metaclass=ABCMeta):
    """Abstract interface for services that return documents related to a query.

    ``__init__`` forwards every keyword argument to :meth:`do_init`, so
    subclasses configure themselves there instead of overriding ``__init__``.
    """

    def __init__(self, **kwargs):
        self.do_init(**kwargs)

    @abstractmethod
    def do_init(self, **kwargs):
        """Initialise the service from the keyword arguments given to ``__init__``."""

    # Declared static because the signature takes no ``self``; without the
    # decorator the method could only be called on the class, never on an
    # instance. ``@staticmethod`` must be the outer decorator when combined
    # with ``@abstractmethod``.
    @staticmethod
    @abstractmethod
    def from_vectorstore(
        vectorstore: VectorStore,
        top_k: int,
        # ``int or float`` evaluates to plain ``int``; ``float`` also accepts
        # ints for type checkers.
        score_threshold: float,
    ):
        """Build a retriever service on top of ``vectorstore``.

        Args:
            vectorstore: The vector store to retrieve from.
            top_k: Maximum number of documents to retrieve.
            score_threshold: Relevance score threshold applied to results.
        """

    @abstractmethod
    def get_related_documents(self, query: str):
        """Return the documents related to ``query``."""
|
||||||
@ -0,0 +1,45 @@
|
|||||||
|
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
|
||||||
|
from langchain.vectorstores import VectorStore
|
||||||
|
from langchain_core.retrievers import BaseRetriever
|
||||||
|
from langchain.retrievers import BM25Retriever, EnsembleRetriever
|
||||||
|
|
||||||
|
|
||||||
|
class EnsembleRetrieverService(BaseRetrieverService):
    """Retriever service combining a BM25 retriever with a vector-store retriever."""

    def do_init(
        self,
        retriever: BaseRetriever = None,
        top_k: int = 5,
    ):
        """Store the wrapped retriever and the result cap.

        Args:
            retriever: The retriever queried by :meth:`get_related_documents`.
            top_k: Maximum number of documents returned per query.
        """
        self.vs = None
        self.top_k = top_k
        # Keep the retriever that was passed in; storing ``None`` here made
        # every later query fail on ``None.get_relevant_documents``.
        self.retriever = retriever

    @staticmethod
    def from_vectorstore(
        vectorstore: VectorStore,
        top_k: int,
        # ``int or float`` evaluates to plain ``int``; ``float`` also accepts
        # ints for type checkers.
        score_threshold: float,
    ):
        """Build an ensemble (BM25 + vector search) service from ``vectorstore``.

        Args:
            vectorstore: Vector store providing both the similarity retriever
                and the documents indexed by BM25.
            top_k: Maximum number of documents each retriever, and the
                resulting service, returns.
            score_threshold: Threshold passed to the
                ``similarity_score_threshold`` search.

        Returns:
            An ``EnsembleRetrieverService`` wrapping the combined retriever.
        """
        faiss_retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "score_threshold": score_threshold,
                "k": top_k,
            },
        )

        # Function-scope import, as in the original: ``cutword`` is only
        # required when this builder is actually used.
        from cutword import Cutter

        cutter = Cutter()
        # NOTE(review): this reads the docstore's private ``_dict``; presumably
        # only in-memory docstores (e.g. the one FAISS uses) expose it - confirm
        # before using with other vector stores.
        docs = list(vectorstore.docstore._dict.values())
        bm25_retriever = BM25Retriever.from_documents(
            docs,
            preprocess_func=cutter.cutword,
        )
        bm25_retriever.k = top_k

        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
        )
        # Forward ``top_k`` so the service does not silently fall back to the
        # default of 5 when truncating results.
        return EnsembleRetrieverService(retriever=ensemble_retriever, top_k=top_k)

    def get_related_documents(self, query: str):
        """Return at most ``top_k`` documents the wrapped retriever finds for ``query``."""
        return self.retriever.get_relevant_documents(query)[:self.top_k]

    def get_relevant_documents(self, query: str):
        """Alias of :meth:`get_related_documents`.

        ``FaissKBService.do_search`` calls the service under this name.
        """
        return self.get_related_documents(query)
|
||||||
@ -0,0 +1,33 @@
|
|||||||
|
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
|
||||||
|
from langchain.vectorstores import VectorStore
|
||||||
|
from langchain_core.retrievers import BaseRetriever
|
||||||
|
|
||||||
|
|
||||||
|
class VectorstoreRetrieverService(BaseRetrieverService):
    """Retriever service backed by a vector store's similarity retriever."""

    def do_init(
        self,
        retriever: BaseRetriever = None,
        top_k: int = 5,
    ):
        """Store the wrapped retriever and the result cap.

        Args:
            retriever: The retriever queried by :meth:`get_related_documents`.
            top_k: Maximum number of documents returned per query.
        """
        self.vs = None
        self.top_k = top_k
        # Keep the retriever that was passed in; storing ``None`` here made
        # every later query fail on ``None.get_relevant_documents``.
        self.retriever = retriever

    @staticmethod
    def from_vectorstore(
        vectorstore: VectorStore,
        top_k: int,
        # ``int or float`` evaluates to plain ``int``; ``float`` also accepts
        # ints for type checkers.
        score_threshold: float,
    ):
        """Build a service around ``vectorstore``'s score-threshold retriever.

        Args:
            vectorstore: Vector store to retrieve from.
            top_k: Maximum number of documents the retriever, and the
                resulting service, returns.
            score_threshold: Threshold passed to the
                ``similarity_score_threshold`` search.

        Returns:
            A ``VectorstoreRetrieverService`` wrapping the retriever.
        """
        retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "score_threshold": score_threshold,
                "k": top_k,
            },
        )
        # Forward ``top_k`` so the service does not silently fall back to the
        # default of 5 when truncating results.
        return VectorstoreRetrieverService(retriever=retriever, top_k=top_k)

    def get_related_documents(self, query: str):
        """Return at most ``top_k`` documents the wrapped retriever finds for ``query``."""
        return self.retriever.get_relevant_documents(query)[:self.top_k]

    def get_relevant_documents(self, query: str):
        """Alias of :meth:`get_related_documents`.

        ``FaissKBService.do_search`` calls the service under this name.
        """
        return self.get_related_documents(query)
|
||||||
8
chatchat-server/chatchat/server/file_rag/utils.py
Normal file
8
chatchat-server/chatchat/server/file_rag/utils.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
from typing import Dict, Type

from chatchat.server.file_rag.retrievers import BaseRetrieverService, VectorstoreRetrieverService
|
||||||
|
|
||||||
|
# Registry mapping a retriever type name to the service class implementing it.
Retrivals: Dict[str, Type[BaseRetrieverService]] = {
    "vectorstore": VectorstoreRetrieverService,
}


def get_Retriever(type: str = "vectorstore") -> Type[BaseRetrieverService]:
    """Look up the retriever service class registered under ``type``.

    The class itself is returned, not an instance; callers build an instance
    from it (for example via ``from_vectorstore``).

    Args:
        type: Key into ``Retrivals``. The name shadows the builtin but is kept
            because callers may pass it by keyword.

    Returns:
        The registered ``BaseRetrieverService`` subclass.

    Raises:
        KeyError: If no service is registered under ``type``.
    """
    return Retrivals[type]
|
||||||
@ -8,6 +8,7 @@ from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get
|
|||||||
from chatchat.server.utils import get_Embeddings
|
from chatchat.server.utils import get_Embeddings
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from typing import List, Dict, Optional, Tuple
|
from typing import List, Dict, Optional, Tuple
|
||||||
|
from chatchat.server.file_rag.utils import get_Retriever
|
||||||
|
|
||||||
|
|
||||||
class FaissKBService(KBService):
|
class FaissKBService(KBService):
|
||||||
@ -62,10 +63,15 @@ class FaissKBService(KBService):
|
|||||||
top_k: int,
|
top_k: int,
|
||||||
score_threshold: float = SCORE_THRESHOLD,
|
score_threshold: float = SCORE_THRESHOLD,
|
||||||
) -> List[Tuple[Document, float]]:
|
) -> List[Tuple[Document, float]]:
|
||||||
embed_func = get_Embeddings(self.embed_model)
|
|
||||||
embeddings = embed_func.embed_query(query)
|
|
||||||
with self.load_vector_store().acquire() as vs:
|
with self.load_vector_store().acquire() as vs:
|
||||||
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
|
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||||
|
vs,
|
||||||
|
top_k=top_k,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
)
|
||||||
|
docs = retriever.get_relevant_documents(query)
|
||||||
|
|
||||||
|
# docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
def do_add_doc(self,
|
def do_add_doc(self,
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from chatchat.configs import (
|
|||||||
TEXT_SPLITTER_NAME,
|
TEXT_SPLITTER_NAME,
|
||||||
)
|
)
|
||||||
import importlib
|
import importlib
|
||||||
from chatchat.server.text_splitter import zh_title_enhance as func_zh_title_enhance
|
from chatchat.server.file_rag.text_splitter import zh_title_enhance as func_zh_title_enhance
|
||||||
import langchain_community.document_loaders
|
import langchain_community.document_loaders
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter
|
from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user