From b56283eb01699f03c228f466234e30e3d6550392 Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Mon, 10 Jun 2024 00:36:03 +0800 Subject: [PATCH] Feature(File RAG): add file_rag in chatchat-server, add ensemble retriever and vectorstore retriever. --- .../search_local_knowledgebase.py | 7 +-- .../chatchat/server/chat/file_chat.py | 6 +-- .../chatchat/server/file_rag/__init__.py | 0 .../document_loaders/FilteredCSVloader.py | 0 .../document_loaders/__init__.py | 0 .../document_loaders/mydocloader.py | 0 .../document_loaders/myimgloader.py | 2 +- .../document_loaders/mypdfloader.py | 2 +- .../document_loaders/mypptloader.py | 0 .../{ => file_rag}/document_loaders/ocr.py | 0 .../server/file_rag/retrievers/__init__.py | 3 ++ .../server/file_rag/retrievers/base.py | 24 ++++++++++ .../server/file_rag/retrievers/ensemble.py | 47 +++++++++++++++++++ .../server/file_rag/retrievers/vectorstore.py | 33 +++++++++++++ .../{ => file_rag}/text_splitter/__init__.py | 0 .../text_splitter/ali_text_splitter.py | 0 .../chinese_recursive_text_splitter.py | 0 .../text_splitter/chinese_text_splitter.py | 0 .../text_splitter/zh_title_enhance.py | 0 .../chatchat/server/file_rag/utils.py | 13 +++++ .../server/knowledge_base/kb_doc_api.py | 3 +- .../kb_service/chromadb_kb_service.py | 12 +++-- .../kb_service/es_kb_service.py | 9 +++- .../kb_service/faiss_kb_service.py | 13 +++-- .../kb_service/milvus_kb_service.py | 16 +++++-- .../kb_service/pg_kb_service.py | 12 +++-- .../kb_service/zilliz_kb_service.py | 15 +++--- .../chatchat/server/knowledge_base/utils.py | 2 +- .../chatchat/webui_pages/dialogue/dialogue.py | 24 +++++----- libs/chatchat-server/pyproject.toml | 2 + 30 files changed, 198 insertions(+), 47 deletions(-) create mode 100644 libs/chatchat-server/chatchat/server/file_rag/__init__.py rename libs/chatchat-server/chatchat/server/{ => file_rag}/document_loaders/FilteredCSVloader.py (100%) rename libs/chatchat-server/chatchat/server/{ => file_rag}/document_loaders/__init__.py (100%) rename libs/chatchat-server/chatchat/server/{ => file_rag}/document_loaders/mydocloader.py (100%) rename libs/chatchat-server/chatchat/server/{ => file_rag}/document_loaders/myimgloader.py (92%) rename libs/chatchat-server/chatchat/server/{ => file_rag}/document_loaders/mypdfloader.py (98%) rename libs/chatchat-server/chatchat/server/{ => file_rag}/document_loaders/mypptloader.py (100%) rename libs/chatchat-server/chatchat/server/{ => file_rag}/document_loaders/ocr.py (100%) create mode 100644 libs/chatchat-server/chatchat/server/file_rag/retrievers/__init__.py create mode 100644 libs/chatchat-server/chatchat/server/file_rag/retrievers/base.py create mode 100644 libs/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py create mode 100644 libs/chatchat-server/chatchat/server/file_rag/retrievers/vectorstore.py rename libs/chatchat-server/chatchat/server/{ => file_rag}/text_splitter/__init__.py (100%) rename libs/chatchat-server/chatchat/server/{ => file_rag}/text_splitter/ali_text_splitter.py (100%) rename libs/chatchat-server/chatchat/server/{ => file_rag}/text_splitter/chinese_recursive_text_splitter.py (100%) rename libs/chatchat-server/chatchat/server/{ => file_rag}/text_splitter/chinese_text_splitter.py (100%) rename libs/chatchat-server/chatchat/server/{ => file_rag}/text_splitter/zh_title_enhance.py (100%) create mode 100644 libs/chatchat-server/chatchat/server/file_rag/utils.py diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py index e7524f2a..5f8fde72 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py @@ -1,13 +1,14 @@ from urllib.parse import urlencode from chatchat.server.utils import get_tool_config from chatchat.server.pydantic_v1 import Field -from .tools_registry import regist_tool, BaseToolOutput +from chatchat.server.agent.tools_factory.tools_registry import regist_tool, BaseToolOutput from chatchat.server.knowledge_base.kb_api import list_kbs from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId from chatchat.configs import KB_INFO -template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool. The 'database' should be one of the above [{key}]." +template = ("Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on " + "this knowledge use this tool. The 'database' should be one of the above [{key}].") KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()]) template_knowledge = template.format(KB_info=KB_info_str, key="samples") @@ -49,7 +50,7 @@ def search_local_knowledgebase( database: str = Field(description="Database for Knowledge Search", choices=[kb.kb_name for kb in list_kbs().data]), query: str = Field(description="Query for Knowledge Search"), ): - '''''' + """""" tool_config = get_tool_config("search_local_knowledgebase") ret = search_knowledgebase(query=query, database=database, config=tool_config) return KBToolOutput(ret, database=database) diff --git a/libs/chatchat-server/chatchat/server/chat/file_chat.py b/libs/chatchat-server/chatchat/server/chat/file_chat.py index f2a8e67a..a0e67d0d 100644 --- a/libs/chatchat-server/chatchat/server/chat/file_chat.py +++ b/libs/chatchat-server/chatchat/server/chat/file_chat.py @@ -63,10 +63,10 @@ def upload_temp_docs( chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), ) -> BaseResponse: - ''' + """ 将文件保存到临时目录,并进行向量化。 返回临时目录名称作为ID,同时也是临时向量库的ID。 - ''' + """ if prev_id is not None: memo_faiss_pool.pop(prev_id) @@ -134,7 +134,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples= docs = [x[0] for x in docs] context = "\n".join([doc.page_content for doc in docs]) - if len(docs) == 0: ## 如果没有找到相关文档,使用Empty模板 + if len(docs) == 0: # 如果没有找到相关文档,使用Empty模板 prompt_template = get_prompt_template("knowledge_base_chat", "empty") else: prompt_template = get_prompt_template("knowledge_base_chat", prompt_name) diff --git a/libs/chatchat-server/chatchat/server/file_rag/__init__.py b/libs/chatchat-server/chatchat/server/file_rag/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/chatchat-server/chatchat/server/document_loaders/FilteredCSVloader.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/FilteredCSVloader.py similarity index 100% rename from libs/chatchat-server/chatchat/server/document_loaders/FilteredCSVloader.py rename to libs/chatchat-server/chatchat/server/file_rag/document_loaders/FilteredCSVloader.py diff --git a/libs/chatchat-server/chatchat/server/document_loaders/__init__.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/__init__.py similarity index 100% rename from libs/chatchat-server/chatchat/server/document_loaders/__init__.py rename to libs/chatchat-server/chatchat/server/file_rag/document_loaders/__init__.py diff --git a/libs/chatchat-server/chatchat/server/document_loaders/mydocloader.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mydocloader.py similarity index 100% rename from libs/chatchat-server/chatchat/server/document_loaders/mydocloader.py rename to libs/chatchat-server/chatchat/server/file_rag/document_loaders/mydocloader.py diff --git a/libs/chatchat-server/chatchat/server/document_loaders/myimgloader.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/myimgloader.py similarity index 92% rename from libs/chatchat-server/chatchat/server/document_loaders/myimgloader.py rename to libs/chatchat-server/chatchat/server/file_rag/document_loaders/myimgloader.py index 6b195cce..c6fda01e 100644 --- a/libs/chatchat-server/chatchat/server/document_loaders/myimgloader.py +++ b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/myimgloader.py @@ -1,6 +1,6 @@ from typing import List from langchain_community.document_loaders.unstructured import UnstructuredFileLoader -from chatchat.server.document_loaders.ocr import get_ocr +from chatchat.server.file_rag.document_loaders.ocr import get_ocr class RapidOCRLoader(UnstructuredFileLoader): diff --git a/libs/chatchat-server/chatchat/server/document_loaders/mypdfloader.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypdfloader.py similarity index 98% rename from libs/chatchat-server/chatchat/server/document_loaders/mypdfloader.py rename to libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypdfloader.py index 9e8796a4..c6a178f8 100644 --- a/libs/chatchat-server/chatchat/server/document_loaders/mypdfloader.py +++ b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypdfloader.py @@ -4,7 +4,7 @@ import cv2 from PIL import Image import numpy as np from chatchat.configs import PDF_OCR_THRESHOLD -from chatchat.server.document_loaders.ocr import get_ocr +from chatchat.server.file_rag.document_loaders.ocr import get_ocr import tqdm diff --git a/libs/chatchat-server/chatchat/server/document_loaders/mypptloader.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypptloader.py similarity index 100% rename from libs/chatchat-server/chatchat/server/document_loaders/mypptloader.py rename to libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypptloader.py diff --git a/libs/chatchat-server/chatchat/server/document_loaders/ocr.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/ocr.py similarity index 100% rename from libs/chatchat-server/chatchat/server/document_loaders/ocr.py rename to libs/chatchat-server/chatchat/server/file_rag/document_loaders/ocr.py diff --git a/libs/chatchat-server/chatchat/server/file_rag/retrievers/__init__.py b/libs/chatchat-server/chatchat/server/file_rag/retrievers/__init__.py new file mode 100644 index 00000000..2cf3617f --- /dev/null +++ b/libs/chatchat-server/chatchat/server/file_rag/retrievers/__init__.py @@ -0,0 +1,3 @@ +from chatchat.server.file_rag.retrievers.base import BaseRetrieverService +from chatchat.server.file_rag.retrievers.vectorstore import VectorstoreRetrieverService +from chatchat.server.file_rag.retrievers.ensemble import EnsembleRetrieverService \ No newline at end of file diff --git a/libs/chatchat-server/chatchat/server/file_rag/retrievers/base.py b/libs/chatchat-server/chatchat/server/file_rag/retrievers/base.py new file mode 100644 index 00000000..7e4d0646 --- /dev/null +++ b/libs/chatchat-server/chatchat/server/file_rag/retrievers/base.py @@ -0,0 +1,24 @@ +from langchain.vectorstores import VectorStore +from abc import ABCMeta, abstractmethod + + +class BaseRetrieverService(metaclass=ABCMeta): + def __init__(self, **kwargs): + self.do_init(**kwargs) + + @abstractmethod + def do_init(self, **kwargs): + pass + + + @abstractmethod + def from_vectorstore( + vectorstore: VectorStore, + top_k: int, + score_threshold: int or float, + ): + pass + + @abstractmethod + def get_relevant_documents(self, query: str): + pass diff --git a/libs/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py b/libs/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py new file mode 100644 index 00000000..cb09b633 --- /dev/null +++ b/libs/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py @@ -0,0 +1,47 @@ +from chatchat.server.file_rag.retrievers.base import BaseRetrieverService +from langchain.vectorstores import VectorStore +from langchain_core.retrievers import BaseRetriever +from langchain_community.retrievers import BM25Retriever +from langchain.retrievers import EnsembleRetriever + + +class EnsembleRetrieverService(BaseRetrieverService): + def do_init( + self, + retriever: BaseRetriever = None, + top_k: int = 5, + ): + self.vs = None + self.top_k = top_k + self.retriever = retriever + + + @staticmethod + def from_vectorstore( + vectorstore: VectorStore, + top_k: int, + score_threshold: int or float, + ): + faiss_retriever = vectorstore.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={ + "score_threshold": score_threshold, + "k": top_k + } + ) + # TODO: 换个不用torch的实现方式 + from cutword.cutword import Cutter + cutter = Cutter() + docs = list(vectorstore.docstore._dict.values()) + bm25_retriever = BM25Retriever.from_documents( + docs, + preprocess_func=cutter.cutword + ) + bm25_retriever.k = top_k + ensemble_retriever = EnsembleRetriever( + retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5] + ) + return EnsembleRetrieverService(retriever=ensemble_retriever) + + def get_relevant_documents(self, query: str): + return self.retriever.get_relevant_documents(query)[:self.top_k] diff --git a/libs/chatchat-server/chatchat/server/file_rag/retrievers/vectorstore.py b/libs/chatchat-server/chatchat/server/file_rag/retrievers/vectorstore.py new file mode 100644 index 00000000..b6d382fa --- /dev/null +++ b/libs/chatchat-server/chatchat/server/file_rag/retrievers/vectorstore.py @@ -0,0 +1,33 @@ +from chatchat.server.file_rag.retrievers.base import BaseRetrieverService +from langchain.vectorstores import VectorStore +from langchain_core.retrievers import BaseRetriever + + +class VectorstoreRetrieverService(BaseRetrieverService): + def do_init( + self, + retriever: BaseRetriever = None, + top_k: int = 5, + ): + self.vs = None + self.top_k = top_k + self.retriever = retriever + + + @staticmethod + def from_vectorstore( + vectorstore: VectorStore, + top_k: int, + score_threshold: int or float, + ): + retriever = vectorstore.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={ + "score_threshold": score_threshold, + "k": top_k + } + ) + return VectorstoreRetrieverService(retriever=retriever) + + def get_relevant_documents(self, query: str): + return self.retriever.get_relevant_documents(query)[:self.top_k] diff --git a/libs/chatchat-server/chatchat/server/text_splitter/__init__.py b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/__init__.py similarity index 100% rename from libs/chatchat-server/chatchat/server/text_splitter/__init__.py rename to libs/chatchat-server/chatchat/server/file_rag/text_splitter/__init__.py diff --git a/libs/chatchat-server/chatchat/server/text_splitter/ali_text_splitter.py b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/ali_text_splitter.py similarity index 100% rename from libs/chatchat-server/chatchat/server/text_splitter/ali_text_splitter.py rename to libs/chatchat-server/chatchat/server/file_rag/text_splitter/ali_text_splitter.py diff --git a/libs/chatchat-server/chatchat/server/text_splitter/chinese_recursive_text_splitter.py b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_recursive_text_splitter.py similarity index 100% rename from libs/chatchat-server/chatchat/server/text_splitter/chinese_recursive_text_splitter.py rename to libs/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_recursive_text_splitter.py diff --git a/libs/chatchat-server/chatchat/server/text_splitter/chinese_text_splitter.py b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_text_splitter.py similarity index 100% rename from libs/chatchat-server/chatchat/server/text_splitter/chinese_text_splitter.py rename to libs/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_text_splitter.py diff --git a/libs/chatchat-server/chatchat/server/text_splitter/zh_title_enhance.py b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/zh_title_enhance.py similarity index 100% rename from libs/chatchat-server/chatchat/server/text_splitter/zh_title_enhance.py rename to libs/chatchat-server/chatchat/server/file_rag/text_splitter/zh_title_enhance.py diff --git a/libs/chatchat-server/chatchat/server/file_rag/utils.py b/libs/chatchat-server/chatchat/server/file_rag/utils.py new file mode 100644 index 00000000..ddf64e3d --- /dev/null +++ b/libs/chatchat-server/chatchat/server/file_rag/utils.py @@ -0,0 +1,13 @@ +from chatchat.server.file_rag.retrievers import ( + BaseRetrieverService, + VectorstoreRetrieverService, + EnsembleRetrieverService, +) + +Retrivals = { + "vectorstore": VectorstoreRetrieverService, + "ensemble": EnsembleRetrieverService, +} + +def get_Retriever(type: str = "vectorstore") -> BaseRetrieverService: + return Retrivals[type] \ No newline at end of file diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py index 0e92c091..3e40786d 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py @@ -37,7 +37,8 @@ def search_docs( if kb is not None: if query: docs = kb.search_docs(query, top_k, score_threshold) - data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs] + # data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs] + data = [DocumentWithVSId(**x.dict(), id=x.metadata.get("id")) for x in docs] elif file_name or metadata: data = kb.list_docs(file_name=file_name, metadata=metadata) for d in data: diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/chromadb_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/chromadb_kb_service.py index 0834c87d..c6c46622 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/chromadb_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/chromadb_kb_service.py @@ -9,6 +9,7 @@ from chatchat.configs import SCORE_THRESHOLD from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path from chatchat.server.utils import get_Embeddings +from chatchat.server.file_rag.utils import get_Retriever def _get_result_to_documents(get_result: GetResult) -> List[Document]: @@ -75,10 +76,13 @@ class ChromaKBService(KBService): def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[ Tuple[Document, float]]: - embed_func = get_Embeddings(self.embed_model) - embeddings = embed_func.embed_query(query) - query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k) - return _results_to_docs_and_scores(query_result) + retriever = get_Retriever("vectorstore").from_vectorstore( + self.collection, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) + return docs def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: doc_infos = [] diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py index 19813bf1..aef63710 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py @@ -8,6 +8,7 @@ from chatchat.server.knowledge_base.utils import KnowledgeFile from chatchat.server.utils import get_Embeddings from elasticsearch import Elasticsearch, BadRequestError from chatchat.configs import kbs_config, KB_ROOT_PATH +from chatchat.server.file_rag.utils import get_Retriever import logging @@ -107,8 +108,12 @@ class ESKBService(KBService): def do_search(self, query:str, top_k: int, score_threshold: float): # 文本相似性检索 - docs = self.db.similarity_search_with_score(query=query, - k=top_k) + retriever = get_Retriever("vectorstore").from_vectorstore( + self.db, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) return docs def get_doc_by_ids(self, ids: List[str]) -> List[Document]: diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py index 95f7cd64..52738ae8 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py @@ -5,9 +5,9 @@ from chatchat.configs import SCORE_THRESHOLD from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType from chatchat.server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path -from chatchat.server.utils import get_Embeddings from langchain.docstore.document import Document -from typing import List, Dict, Optional, Tuple +from typing import List, Dict, Tuple +from chatchat.server.file_rag.utils import get_Retriever class FaissKBService(KBService): @@ -62,10 +62,13 @@ class FaissKBService(KBService): top_k: int, score_threshold: float = SCORE_THRESHOLD, ) -> List[Tuple[Document, float]]: - embed_func = get_Embeddings(self.embed_model) - embeddings = embed_func.embed_query(query) with self.load_vector_store().acquire() as vs: - docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold) + retriever = get_Retriever("ensemble").from_vectorstore( + vs, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) return docs def do_add_doc(self, diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py index ab0a77e9..8eddb5f4 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py @@ -10,7 +10,7 @@ from chatchat.server.db.repository import list_file_num_docs_id_by_kb_name_and_f from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType, \ score_threshold_process from chatchat.server.knowledge_base.utils import KnowledgeFile -from chatchat.server.utils import get_Embeddings +from chatchat.server.file_rag.utils import get_Retriever class MilvusKBService(KBService): @@ -67,10 +67,16 @@ class MilvusKBService(KBService): def do_search(self, query: str, top_k: int, score_threshold: float): self._load_milvus() - embed_func = get_Embeddings(self.embed_model) - embeddings = embed_func.embed_query(query) - docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k) - return score_threshold_process(score_threshold, top_k, docs) + # embed_func = get_Embeddings(self.embed_model) + # embeddings = embed_func.embed_query(query) + # docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k) + retriever = get_Retriever("vectorstore").from_vectorstore( + self.milvus, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) + return docs def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: for doc in docs: diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/pg_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/pg_kb_service.py index 8c3a0cf6..473c7f30 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/pg_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/pg_kb_service.py @@ -15,6 +15,7 @@ import shutil import sqlalchemy from sqlalchemy.engine.base import Engine from sqlalchemy.orm import Session +from chatchat.server.file_rag.utils import get_Retriever class PGKBService(KBService): @@ -60,10 +61,13 @@ class PGKBService(KBService): shutil.rmtree(self.kb_path) def do_search(self, query: str, top_k: int, score_threshold: float): - embed_func = get_Embeddings(self.embed_model) - embeddings = embed_func.embed_query(query) - docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k) - return score_threshold_process(score_threshold, top_k, docs) + retriever = get_Retriever("vectorstore").from_vectorstore( + self.pg_vector, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) + return docs def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: ids = self.pg_vector.add_documents(docs) diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/zilliz_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/zilliz_kb_service.py index 51e21b10..336eaa48 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/zilliz_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/zilliz_kb_service.py @@ -1,5 +1,4 @@ -from typing import List, Dict, Optional -from langchain.embeddings.base import Embeddings +from typing import List, Dict from langchain.schema import Document from langchain.vectorstores import Zilliz from chatchat.configs import kbs_config @@ -7,6 +6,7 @@ from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedV score_threshold_process from chatchat.server.knowledge_base.utils import KnowledgeFile from chatchat.server.utils import get_Embeddings +from chatchat.server.file_rag.utils import get_Retriever class ZillizKBService(KBService): @@ -60,10 +60,13 @@ class ZillizKBService(KBService): def do_search(self, query: str, top_k: int, score_threshold: float): self._load_zilliz() - embed_func = get_Embeddings(self.embed_model) - embeddings = embed_func.embed_query(query) - docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k) - return score_threshold_process(score_threshold, top_k, docs) + retriever = get_Retriever("vectorstore").from_vectorstore( + self.zilliz, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) + return docs def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: for doc in docs: diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/utils.py b/libs/chatchat-server/chatchat/server/knowledge_base/utils.py index c5dd442b..a03393dd 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/utils.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/utils.py @@ -10,7 +10,7 @@ from chatchat.configs import ( TEXT_SPLITTER_NAME, ) import importlib -from chatchat.server.text_splitter import zh_title_enhance as func_zh_title_enhance +from chatchat.server.file_rag.text_splitter import zh_title_enhance as func_zh_title_enhance import langchain_community.document_loaders from langchain.docstore.document import Document from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter diff --git a/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py b/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py index cdd40aa6..4ec3a80b 100644 --- a/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py +++ b/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py @@ -24,28 +24,28 @@ chat_box = ChatBox( def save_session(): - '''save session state to chat context''' + """save session state to chat context""" chat_box.context_from_session(exclude=["selected_page", "prompt", "cur_conv_name"]) def restore_session(): - '''restore sesstion state from chat context''' + """restore sesstion state from chat context""" chat_box.context_to_session(exclude=["selected_page", "prompt", "cur_conv_name"]) def rerun(): - ''' + """ save chat context before rerun - ''' + """ save_session() st.rerun() def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]: - ''' + """ 返回消息历史。 content_in_expander控制是否返回expander元素中的内容,一般导出的时候可以选上,传入LLM的history不需要 - ''' + """ def filter(msg): content = [x for x in msg["elements"] if x._output_method in ["markdown", "text"]] @@ -66,10 +66,10 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) -> @st.cache_data def upload_temp_docs(files, _api: ApiRequest) -> str: - ''' + """ 将文件上传到临时目录,用于文件对话 返回临时向量库ID - ''' + """ return _api.upload_temp_docs(files).get("data", {}).get("id") @@ -157,11 +157,13 @@ def dialogue_page( tools = list_tools(api) tool_names = ["None"] + list(tools) if use_agent: - # selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", check_all=True, key="selected_tools") + # selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", + # check_all=True, key="selected_tools") selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"], key="selected_tools") else: - # selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", key="selected_tool") + # selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", + # key="selected_tool") selected_tool = st.selectbox("选择工具", tool_names, format_func=lambda x: tools.get(x, {"title": "None"})["title"], key="selected_tool") @@ -338,7 +340,7 @@ def dialogue_page( elif d.status == AgentStatus.agent_finish: text = d.choices[0].delta.content or "" chat_box.update_msg(text.replace("\n", "\n\n")) - elif d.status == None: # not agent chat + elif d.status is None: # not agent chat if getattr(d, "is_ref", False): chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete", title="参考资料")) diff --git a/libs/chatchat-server/pyproject.toml b/libs/chatchat-server/pyproject.toml index e105db24..136b8b1b 100644 --- a/libs/chatchat-server/pyproject.toml +++ b/libs/chatchat-server/pyproject.toml @@ -29,6 +29,8 @@ unstructured = "~0.11.0" python-magic-bin = {version = "*", platform = "win32"} SQLAlchemy = "~2.0.25" faiss-cpu = "~1.7.4" +cutword = "0.1.0" +rank_bm25 = "0.2.2" # accelerate = "~0.24.1" # spacy = "~3.7.2" PyMuPDF = "~1.23.16"