From 00869bd8391fa8dc0e1bdca2a8860c774d379ed3 Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Sat, 30 Mar 2024 00:17:48 +0800 Subject: [PATCH] add BaseRetrieverService, VectorstoreRetrieverService, EnsembleRetrieverService --- .../search_local_knowledgebase.py | 4 +- .../chatchat/server/file_rag/__init__.py | 0 .../document_loaders/FilteredCSVloader.py | 0 .../document_loaders/__init__.py | 0 .../document_loaders/mydocloader.py | 0 .../document_loaders/myimgloader.py | 0 .../document_loaders/mypdfloader.py | 0 .../document_loaders/mypptloader.py | 0 .../{ => file_rag}/document_loaders/ocr.py | 0 .../server/file_rag/retrievers/__init__.py | 3 ++ .../server/file_rag/retrievers/base.py | 24 ++++++++++ .../server/file_rag/retrievers/ensemble.py | 45 +++++++++++++++++++ .../server/file_rag/retrievers/vectorstore.py | 33 ++++++++++++++ .../{ => file_rag}/text_splitter/__init__.py | 0 .../text_splitter/ali_text_splitter.py | 0 .../chinese_recursive_text_splitter.py | 0 .../text_splitter/chinese_text_splitter.py | 0 .../text_splitter/zh_title_enhance.py | 0 .../chatchat/server/file_rag/utils.py | 8 ++++ .../kb_service/faiss_kb_service.py | 12 +++-- .../chatchat/server/knowledge_base/utils.py | 2 +- 21 files changed, 126 insertions(+), 5 deletions(-) create mode 100644 chatchat-server/chatchat/server/file_rag/__init__.py rename chatchat-server/chatchat/server/{ => file_rag}/document_loaders/FilteredCSVloader.py (100%) rename chatchat-server/chatchat/server/{ => file_rag}/document_loaders/__init__.py (100%) rename chatchat-server/chatchat/server/{ => file_rag}/document_loaders/mydocloader.py (100%) rename chatchat-server/chatchat/server/{ => file_rag}/document_loaders/myimgloader.py (100%) rename chatchat-server/chatchat/server/{ => file_rag}/document_loaders/mypdfloader.py (100%) rename chatchat-server/chatchat/server/{ => file_rag}/document_loaders/mypptloader.py (100%) rename chatchat-server/chatchat/server/{ => file_rag}/document_loaders/ocr.py (100%) create mode 100644 chatchat-server/chatchat/server/file_rag/retrievers/__init__.py create mode 100644 chatchat-server/chatchat/server/file_rag/retrievers/base.py create mode 100644 chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py create mode 100644 chatchat-server/chatchat/server/file_rag/retrievers/vectorstore.py rename chatchat-server/chatchat/server/{ => file_rag}/text_splitter/__init__.py (100%) rename chatchat-server/chatchat/server/{ => file_rag}/text_splitter/ali_text_splitter.py (100%) rename chatchat-server/chatchat/server/{ => file_rag}/text_splitter/chinese_recursive_text_splitter.py (100%) rename chatchat-server/chatchat/server/{ => file_rag}/text_splitter/chinese_text_splitter.py (100%) rename chatchat-server/chatchat/server/{ => file_rag}/text_splitter/zh_title_enhance.py (100%) create mode 100644 chatchat-server/chatchat/server/file_rag/utils.py diff --git a/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py b/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py index a34e8301..7a7ef3ff 100644 --- a/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py +++ b/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py @@ -49,7 +49,9 @@ def search_local_knowledgebase( database: str = Field(description="Database for Knowledge Search", choices=list_kbs().data), query: str = Field(description="Query for Knowledge Search"), ): - '''''' + """ + 本地知识库检索 + """ tool_config = get_tool_config("search_local_knowledgebase") ret = search_knowledgebase(query=query, database=database, config=tool_config) return KBToolOutput(ret, database=database) diff --git a/chatchat-server/chatchat/server/file_rag/__init__.py b/chatchat-server/chatchat/server/file_rag/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chatchat-server/chatchat/server/document_loaders/FilteredCSVloader.py b/chatchat-server/chatchat/server/file_rag/document_loaders/FilteredCSVloader.py similarity index 100% rename from chatchat-server/chatchat/server/document_loaders/FilteredCSVloader.py rename to chatchat-server/chatchat/server/file_rag/document_loaders/FilteredCSVloader.py diff --git a/chatchat-server/chatchat/server/document_loaders/__init__.py b/chatchat-server/chatchat/server/file_rag/document_loaders/__init__.py similarity index 100% rename from chatchat-server/chatchat/server/document_loaders/__init__.py rename to chatchat-server/chatchat/server/file_rag/document_loaders/__init__.py diff --git a/chatchat-server/chatchat/server/document_loaders/mydocloader.py b/chatchat-server/chatchat/server/file_rag/document_loaders/mydocloader.py similarity index 100% rename from chatchat-server/chatchat/server/document_loaders/mydocloader.py rename to chatchat-server/chatchat/server/file_rag/document_loaders/mydocloader.py diff --git a/chatchat-server/chatchat/server/document_loaders/myimgloader.py b/chatchat-server/chatchat/server/file_rag/document_loaders/myimgloader.py similarity index 100% rename from chatchat-server/chatchat/server/document_loaders/myimgloader.py rename to chatchat-server/chatchat/server/file_rag/document_loaders/myimgloader.py diff --git a/chatchat-server/chatchat/server/document_loaders/mypdfloader.py b/chatchat-server/chatchat/server/file_rag/document_loaders/mypdfloader.py similarity index 100% rename from chatchat-server/chatchat/server/document_loaders/mypdfloader.py rename to chatchat-server/chatchat/server/file_rag/document_loaders/mypdfloader.py diff --git a/chatchat-server/chatchat/server/document_loaders/mypptloader.py b/chatchat-server/chatchat/server/file_rag/document_loaders/mypptloader.py similarity index 100% rename from chatchat-server/chatchat/server/document_loaders/mypptloader.py rename to chatchat-server/chatchat/server/file_rag/document_loaders/mypptloader.py diff --git a/chatchat-server/chatchat/server/document_loaders/ocr.py b/chatchat-server/chatchat/server/file_rag/document_loaders/ocr.py similarity index 100% rename from chatchat-server/chatchat/server/document_loaders/ocr.py rename to chatchat-server/chatchat/server/file_rag/document_loaders/ocr.py diff --git a/chatchat-server/chatchat/server/file_rag/retrievers/__init__.py b/chatchat-server/chatchat/server/file_rag/retrievers/__init__.py new file mode 100644 index 00000000..2cf3617f --- /dev/null +++ b/chatchat-server/chatchat/server/file_rag/retrievers/__init__.py @@ -0,0 +1,3 @@ +from chatchat.server.file_rag.retrievers.base import BaseRetrieverService +from chatchat.server.file_rag.retrievers.vectorstore import VectorstoreRetrieverService +from chatchat.server.file_rag.retrievers.ensemble import EnsembleRetrieverService \ No newline at end of file diff --git a/chatchat-server/chatchat/server/file_rag/retrievers/base.py b/chatchat-server/chatchat/server/file_rag/retrievers/base.py new file mode 100644 index 00000000..1633bd20 --- /dev/null +++ b/chatchat-server/chatchat/server/file_rag/retrievers/base.py @@ -0,0 +1,24 @@ +from langchain.vectorstores import VectorStore +from abc import ABCMeta, abstractmethod + + +class BaseRetrieverService(metaclass=ABCMeta): + def __init__(self, **kwargs): + self.do_init(**kwargs) + + @abstractmethod + def do_init(self, **kwargs): + pass + + + @abstractmethod + def from_vectorstore( + vectorstore: VectorStore, + top_k: int, + score_threshold: int or float, + ): + pass + + @abstractmethod + def get_related_documents(self, query: str): + pass diff --git a/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py b/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py new file mode 100644 index 00000000..cb035a65 --- /dev/null +++ b/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py @@ -0,0 +1,45 @@ +from chatchat.server.file_rag.retrievers.base import BaseRetrieverService +from langchain.vectorstores import VectorStore +from langchain_core.retrievers import BaseRetriever +from langchain.retrievers import BM25Retriever, EnsembleRetriever + + +class EnsembleRetrieverService(BaseRetrieverService): + def do_init( + self, + retriever: BaseRetriever = None, + top_k: int = 5, + ): + self.vs = None + self.top_k = top_k + self.retriever = None + + + @staticmethod + def from_vectorstore( + vectorstore: VectorStore, + top_k: int, + score_threshold: int or float, + ): + faiss_retriever = vectorstore.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={ + "score_threshold": score_threshold, + "k": top_k + } + ) + from cutword import Cutter + cutter = Cutter() + docs = list(vectorstore.docstore._dict.values()) + bm25_retriever = BM25Retriever.from_documents( + docs, + preprocess_func=cutter.cutword + ) + bm25_retriever.k = top_k + ensemble_retriever = EnsembleRetriever( + retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5] + ) + return EnsembleRetrieverService(retriever=ensemble_retriever) + + def get_related_documents(self, query: str): + self.retriever.get_relevant_documents(query)[:self.top_k] diff --git a/chatchat-server/chatchat/server/file_rag/retrievers/vectorstore.py b/chatchat-server/chatchat/server/file_rag/retrievers/vectorstore.py new file mode 100644 index 00000000..d28ac0f7 --- /dev/null +++ b/chatchat-server/chatchat/server/file_rag/retrievers/vectorstore.py @@ -0,0 +1,33 @@ +from chatchat.server.file_rag.retrievers.base import BaseRetrieverService +from langchain.vectorstores import VectorStore +from langchain_core.retrievers import BaseRetriever + + +class VectorstoreRetrieverService(BaseRetrieverService): + def do_init( + self, + retriever: BaseRetriever = None, + top_k: int = 5, + ): + self.vs = None + self.top_k = top_k + self.retriever = None + + + @staticmethod + def from_vectorstore( + vectorstore: VectorStore, + top_k: int, + score_threshold: int or float, + ): + retriever = vectorstore.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={ + "score_threshold": score_threshold, + "k": top_k + } + ) + return VectorstoreRetrieverService(retriever=retriever) + + def get_related_documents(self, query: str): + self.retriever.get_relevant_documents(query)[:self.top_k] diff --git a/chatchat-server/chatchat/server/text_splitter/__init__.py b/chatchat-server/chatchat/server/file_rag/text_splitter/__init__.py similarity index 100% rename from chatchat-server/chatchat/server/text_splitter/__init__.py rename to chatchat-server/chatchat/server/file_rag/text_splitter/__init__.py diff --git a/chatchat-server/chatchat/server/text_splitter/ali_text_splitter.py b/chatchat-server/chatchat/server/file_rag/text_splitter/ali_text_splitter.py similarity index 100% rename from chatchat-server/chatchat/server/text_splitter/ali_text_splitter.py rename to chatchat-server/chatchat/server/file_rag/text_splitter/ali_text_splitter.py diff --git a/chatchat-server/chatchat/server/text_splitter/chinese_recursive_text_splitter.py b/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_recursive_text_splitter.py similarity index 100% rename from chatchat-server/chatchat/server/text_splitter/chinese_recursive_text_splitter.py rename to chatchat-server/chatchat/server/file_rag/text_splitter/chinese_recursive_text_splitter.py diff --git a/chatchat-server/chatchat/server/text_splitter/chinese_text_splitter.py b/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_text_splitter.py similarity index 100% rename from chatchat-server/chatchat/server/text_splitter/chinese_text_splitter.py rename to chatchat-server/chatchat/server/file_rag/text_splitter/chinese_text_splitter.py diff --git a/chatchat-server/chatchat/server/text_splitter/zh_title_enhance.py b/chatchat-server/chatchat/server/file_rag/text_splitter/zh_title_enhance.py similarity index 100% rename from chatchat-server/chatchat/server/text_splitter/zh_title_enhance.py rename to chatchat-server/chatchat/server/file_rag/text_splitter/zh_title_enhance.py diff --git a/chatchat-server/chatchat/server/file_rag/utils.py b/chatchat-server/chatchat/server/file_rag/utils.py new file mode 100644 index 00000000..d3fa0e42 --- /dev/null +++ b/chatchat-server/chatchat/server/file_rag/utils.py @@ -0,0 +1,8 @@ +from chatchat.server.file_rag.retrievers import BaseRetrieverService, VectorstoreRetrieverService + +Retrivals = { + "vectorstore": VectorstoreRetrieverService, +} + +def get_Retriever(type: str = "vectorstore") -> BaseRetrieverService: + return Retrivals[type] \ No newline at end of file diff --git a/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py b/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py index 95f7cd64..9a61b395 100644 --- a/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py @@ -8,6 +8,7 @@ from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get from chatchat.server.utils import get_Embeddings from langchain.docstore.document import Document from typing import List, Dict, Optional, Tuple +from chatchat.server.file_rag.utils import get_Retriever class FaissKBService(KBService): @@ -62,10 +63,15 @@ class FaissKBService(KBService): top_k: int, score_threshold: float = SCORE_THRESHOLD, ) -> List[Tuple[Document, float]]: - embed_func = get_Embeddings(self.embed_model) - embeddings = embed_func.embed_query(query) with self.load_vector_store().acquire() as vs: - docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold) + retriever = get_Retriever("vectorstore").from_vectorstore( + vs, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) + + # docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold) return docs def do_add_doc(self, diff --git a/chatchat-server/chatchat/server/knowledge_base/utils.py b/chatchat-server/chatchat/server/knowledge_base/utils.py index 423c56eb..35c39092 100644 --- a/chatchat-server/chatchat/server/knowledge_base/utils.py +++ b/chatchat-server/chatchat/server/knowledge_base/utils.py @@ -11,7 +11,7 @@ from chatchat.configs import ( TEXT_SPLITTER_NAME, ) import importlib -from chatchat.server.text_splitter import zh_title_enhance as func_zh_title_enhance +from chatchat.server.file_rag.text_splitter import zh_title_enhance as func_zh_title_enhance import langchain_community.document_loaders from langchain.docstore.document import Document from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter