Feature(File RAG): add file_rag in chatchat-server, add ensemble retriever and vectorstore retriever.

This commit is contained in:
imClumsyPanda 2024-06-10 00:36:03 +08:00
parent b110fcd01b
commit b56283eb01
30 changed files with 198 additions and 47 deletions

View File

@ -1,13 +1,14 @@
from urllib.parse import urlencode from urllib.parse import urlencode
from chatchat.server.utils import get_tool_config from chatchat.server.utils import get_tool_config
from chatchat.server.pydantic_v1 import Field from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool, BaseToolOutput from chatchat.server.agent.tools_factory.tools_registry import regist_tool, BaseToolOutput
from chatchat.server.knowledge_base.kb_api import list_kbs from chatchat.server.knowledge_base.kb_api import list_kbs
from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId
from chatchat.configs import KB_INFO from chatchat.configs import KB_INFO
template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get informationOnly local data on this knowledge use this tool. The 'database' should be one of the above [{key}]." template = ("Use local knowledgebase from one or more of these:\n{KB_info}\n to get informationOnly local data on "
"this knowledge use this tool. The 'database' should be one of the above [{key}].")
KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()]) KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
template_knowledge = template.format(KB_info=KB_info_str, key="samples") template_knowledge = template.format(KB_info=KB_info_str, key="samples")
@ -49,7 +50,7 @@ def search_local_knowledgebase(
database: str = Field(description="Database for Knowledge Search", choices=[kb.kb_name for kb in list_kbs().data]), database: str = Field(description="Database for Knowledge Search", choices=[kb.kb_name for kb in list_kbs().data]),
query: str = Field(description="Query for Knowledge Search"), query: str = Field(description="Query for Knowledge Search"),
): ):
'''''' """"""
tool_config = get_tool_config("search_local_knowledgebase") tool_config = get_tool_config("search_local_knowledgebase")
ret = search_knowledgebase(query=query, database=database, config=tool_config) ret = search_knowledgebase(query=query, database=database, config=tool_config)
return KBToolOutput(ret, database=database) return KBToolOutput(ret, database=database)

View File

@ -63,10 +63,10 @@ def upload_temp_docs(
chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
) -> BaseResponse: ) -> BaseResponse:
''' """
将文件保存到临时目录并进行向量化 将文件保存到临时目录并进行向量化
返回临时目录名称作为ID同时也是临时向量库的ID 返回临时目录名称作为ID同时也是临时向量库的ID
''' """
if prev_id is not None: if prev_id is not None:
memo_faiss_pool.pop(prev_id) memo_faiss_pool.pop(prev_id)
@ -134,7 +134,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
docs = [x[0] for x in docs] docs = [x[0] for x in docs]
context = "\n".join([doc.page_content for doc in docs]) context = "\n".join([doc.page_content for doc in docs])
if len(docs) == 0: ## 如果没有找到相关文档使用Empty模板 if len(docs) == 0: # 如果没有找到相关文档使用Empty模板
prompt_template = get_prompt_template("knowledge_base_chat", "empty") prompt_template = get_prompt_template("knowledge_base_chat", "empty")
else: else:
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name) prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)

View File

@ -1,6 +1,6 @@
from typing import List from typing import List
from langchain_community.document_loaders.unstructured import UnstructuredFileLoader from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
from chatchat.server.document_loaders.ocr import get_ocr from chatchat.server.file_rag.document_loaders.ocr import get_ocr
class RapidOCRLoader(UnstructuredFileLoader): class RapidOCRLoader(UnstructuredFileLoader):

View File

@ -4,7 +4,7 @@ import cv2
from PIL import Image from PIL import Image
import numpy as np import numpy as np
from chatchat.configs import PDF_OCR_THRESHOLD from chatchat.configs import PDF_OCR_THRESHOLD
from chatchat.server.document_loaders.ocr import get_ocr from chatchat.server.file_rag.document_loaders.ocr import get_ocr
import tqdm import tqdm

View File

@ -0,0 +1,3 @@
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
from chatchat.server.file_rag.retrievers.vectorstore import VectorstoreRetrieverService
from chatchat.server.file_rag.retrievers.ensemble import EnsembleRetrieverService

View File

@ -0,0 +1,24 @@
from langchain.vectorstores import VectorStore
from abc import ABCMeta, abstractmethod
class BaseRetrieverService(metaclass=ABCMeta):
    """Abstract interface for file-RAG retriever services.

    Concrete subclasses (e.g. vectorstore- or ensemble-based services)
    implement:

    * ``do_init``  -- called from ``__init__`` with the constructor kwargs;
    * ``from_vectorstore`` -- a static factory building a service from a
      vector store;
    * ``get_relevant_documents`` -- the actual retrieval call.
    """

    def __init__(self, **kwargs):
        # All construction is delegated to the subclass hook so subclasses
        # only override do_init instead of __init__.
        self.do_init(**kwargs)

    @abstractmethod
    def do_init(self, **kwargs):
        """Initialize subclass state from the constructor kwargs."""
        pass

    @staticmethod
    @abstractmethod
    def from_vectorstore(
        vectorstore: "VectorStore",
        top_k: int,
        # Was annotated `int or float`, which evaluates to just `int` at
        # runtime; `float` accepts ints too (PEP 484 numeric tower).
        score_threshold: float,
    ):
        """Build a retriever service on top of *vectorstore*.

        All known implementations define this as a ``@staticmethod``; the
        decorator is declared here as well so the contract is explicit.
        """
        pass

    @abstractmethod
    def get_relevant_documents(self, query: str):
        """Return documents relevant to *query*."""
        pass

View File

@ -0,0 +1,47 @@
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
from langchain.vectorstores import VectorStore
from langchain_core.retrievers import BaseRetriever
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
class EnsembleRetrieverService(BaseRetrieverService):
    """Retriever service combining BM25 keyword search with similarity search.

    Both retrievers are weighted equally (0.5 / 0.5) in the ensemble.
    """

    def do_init(
        self,
        retriever: BaseRetriever = None,
        top_k: int = 5,
    ):
        """Store the wrapped retriever and the result limit.

        :param retriever: pre-built langchain retriever to delegate to.
        :param top_k: maximum number of documents returned per query.
        """
        self.vs = None
        self.top_k = top_k
        self.retriever = retriever

    @staticmethod
    def from_vectorstore(
        vectorstore: VectorStore,
        top_k: int,
        # Was `int or float`, which evaluates to just `int`; `float` also
        # accepts ints per the PEP 484 numeric tower.
        score_threshold: float,
    ):
        """Build an ensemble (BM25 + similarity) service from *vectorstore*."""
        faiss_retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "score_threshold": score_threshold,
                "k": top_k,
            },
        )
        # TODO: 换个不用torch的实现方式 (replace with a tokenizer that does
        # not require torch)
        from cutword.cutword import Cutter
        cutter = Cutter()
        # NOTE(review): reaches into the docstore's private `_dict` — this
        # assumes a FAISS-style in-memory docstore; verify before using with
        # other vectorstore backends.
        docs = list(vectorstore.docstore._dict.values())
        bm25_retriever = BM25Retriever.from_documents(
            docs,
            preprocess_func=cutter.cutword,
        )
        bm25_retriever.k = top_k
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
        )
        return EnsembleRetrieverService(retriever=ensemble_retriever)

    def get_relevant_documents(self, query: str):
        """Return at most ``top_k`` documents relevant to *query*."""
        return self.retriever.get_relevant_documents(query)[:self.top_k]

View File

@ -0,0 +1,33 @@
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
from langchain.vectorstores import VectorStore
from langchain_core.retrievers import BaseRetriever
class VectorstoreRetrieverService(BaseRetrieverService):
    """Retriever service backed directly by a vector store's similarity search."""

    def do_init(
        self,
        retriever: BaseRetriever = None,
        top_k: int = 5,
    ):
        """Store the wrapped retriever and the result limit.

        :param retriever: pre-built langchain retriever to delegate to.
        :param top_k: maximum number of documents returned per query.
        """
        self.vs = None
        self.top_k = top_k
        self.retriever = retriever

    @staticmethod
    def from_vectorstore(
        vectorstore: VectorStore,
        top_k: int,
        # Was `int or float`, which evaluates to just `int`; `float` also
        # accepts ints per the PEP 484 numeric tower.
        score_threshold: float,
    ):
        """Build a similarity-score-threshold retriever service from *vectorstore*."""
        retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "score_threshold": score_threshold,
                "k": top_k,
            },
        )
        return VectorstoreRetrieverService(retriever=retriever)

    def get_relevant_documents(self, query: str):
        """Return at most ``top_k`` documents relevant to *query*."""
        return self.retriever.get_relevant_documents(query)[:self.top_k]

View File

@ -0,0 +1,13 @@
from typing import Type

from chatchat.server.file_rag.retrievers import (
    BaseRetrieverService,
    VectorstoreRetrieverService,
    EnsembleRetrieverService,
)
# NOTE(review): "Retrivals" is a misspelling of "Retrievals"; the name is kept
# because it is part of this module's importable surface.
Retrivals = {
    "vectorstore": VectorstoreRetrieverService,
    "ensemble": EnsembleRetrieverService,
}


def get_Retriever(type: str = "vectorstore") -> Type[BaseRetrieverService]:
    """Look up a retriever-service *class* by registry key.

    :param type: one of ``"vectorstore"`` or ``"ensemble"`` (the parameter
        name shadows the builtin ``type`` but is kept for caller
        compatibility with keyword usage).
    :raises KeyError: if *type* is not a registered retriever kind.
    :return: the service class (not an instance) — callers chain its
        ``from_vectorstore`` factory.
    """
    # Annotation fixed: the function returns the class itself, so the return
    # type is Type[BaseRetrieverService], not an instance.
    return Retrivals[type]

View File

@ -37,7 +37,8 @@ def search_docs(
if kb is not None: if kb is not None:
if query: if query:
docs = kb.search_docs(query, top_k, score_threshold) docs = kb.search_docs(query, top_k, score_threshold)
data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs] # data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
data = [DocumentWithVSId(**x.dict(), id=x.metadata.get("id")) for x in docs]
elif file_name or metadata: elif file_name or metadata:
data = kb.list_docs(file_name=file_name, metadata=metadata) data = kb.list_docs(file_name=file_name, metadata=metadata)
for d in data: for d in data:

View File

@ -9,6 +9,7 @@ from chatchat.configs import SCORE_THRESHOLD
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType
from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
from chatchat.server.utils import get_Embeddings from chatchat.server.utils import get_Embeddings
from chatchat.server.file_rag.utils import get_Retriever
def _get_result_to_documents(get_result: GetResult) -> List[Document]: def _get_result_to_documents(get_result: GetResult) -> List[Document]:
@ -75,10 +76,13 @@ class ChromaKBService(KBService):
def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[ def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[
Tuple[Document, float]]: Tuple[Document, float]]:
embed_func = get_Embeddings(self.embed_model) retriever = get_Retriever("vectorstore").from_vectorstore(
embeddings = embed_func.embed_query(query) self.collection,
query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k) top_k=top_k,
return _results_to_docs_and_scores(query_result) score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
doc_infos = [] doc_infos = []

View File

@ -8,6 +8,7 @@ from chatchat.server.knowledge_base.utils import KnowledgeFile
from chatchat.server.utils import get_Embeddings from chatchat.server.utils import get_Embeddings
from elasticsearch import Elasticsearch, BadRequestError from elasticsearch import Elasticsearch, BadRequestError
from chatchat.configs import kbs_config, KB_ROOT_PATH from chatchat.configs import kbs_config, KB_ROOT_PATH
from chatchat.server.file_rag.utils import get_Retriever
import logging import logging
@ -107,8 +108,12 @@ class ESKBService(KBService):
def do_search(self, query:str, top_k: int, score_threshold: float): def do_search(self, query:str, top_k: int, score_threshold: float):
# 文本相似性检索 # 文本相似性检索
docs = self.db.similarity_search_with_score(query=query, retriever = get_Retriever("vectorstore").from_vectorstore(
k=top_k) self.db,
top_k=top_k,
score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs return docs
def get_doc_by_ids(self, ids: List[str]) -> List[Document]: def get_doc_by_ids(self, ids: List[str]) -> List[Document]:

View File

@ -5,9 +5,9 @@ from chatchat.configs import SCORE_THRESHOLD
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType
from chatchat.server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss from chatchat.server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
from chatchat.server.utils import get_Embeddings
from langchain.docstore.document import Document from langchain.docstore.document import Document
from typing import List, Dict, Optional, Tuple from typing import List, Dict, Tuple
from chatchat.server.file_rag.utils import get_Retriever
class FaissKBService(KBService): class FaissKBService(KBService):
@ -62,10 +62,13 @@ class FaissKBService(KBService):
top_k: int, top_k: int,
score_threshold: float = SCORE_THRESHOLD, score_threshold: float = SCORE_THRESHOLD,
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
embed_func = get_Embeddings(self.embed_model)
embeddings = embed_func.embed_query(query)
with self.load_vector_store().acquire() as vs: with self.load_vector_store().acquire() as vs:
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold) retriever = get_Retriever("ensemble").from_vectorstore(
vs,
top_k=top_k,
score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs return docs
def do_add_doc(self, def do_add_doc(self,

View File

@ -10,7 +10,7 @@ from chatchat.server.db.repository import list_file_num_docs_id_by_kb_name_and_f
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType, \ from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType, \
score_threshold_process score_threshold_process
from chatchat.server.knowledge_base.utils import KnowledgeFile from chatchat.server.knowledge_base.utils import KnowledgeFile
from chatchat.server.utils import get_Embeddings from chatchat.server.file_rag.utils import get_Retriever
class MilvusKBService(KBService): class MilvusKBService(KBService):
@ -67,10 +67,16 @@ class MilvusKBService(KBService):
def do_search(self, query: str, top_k: int, score_threshold: float): def do_search(self, query: str, top_k: int, score_threshold: float):
self._load_milvus() self._load_milvus()
embed_func = get_Embeddings(self.embed_model) # embed_func = get_Embeddings(self.embed_model)
embeddings = embed_func.embed_query(query) # embeddings = embed_func.embed_query(query)
docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k) # docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
return score_threshold_process(score_threshold, top_k, docs) retriever = get_Retriever("vectorstore").from_vectorstore(
self.milvus,
top_k=top_k,
score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
for doc in docs: for doc in docs:

View File

@ -15,6 +15,7 @@ import shutil
import sqlalchemy import sqlalchemy
from sqlalchemy.engine.base import Engine from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from chatchat.server.file_rag.utils import get_Retriever
class PGKBService(KBService): class PGKBService(KBService):
@ -60,10 +61,13 @@ class PGKBService(KBService):
shutil.rmtree(self.kb_path) shutil.rmtree(self.kb_path)
def do_search(self, query: str, top_k: int, score_threshold: float): def do_search(self, query: str, top_k: int, score_threshold: float):
embed_func = get_Embeddings(self.embed_model) retriever = get_Retriever("vectorstore").from_vectorstore(
embeddings = embed_func.embed_query(query) self.pg_vector,
docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k) top_k=top_k,
return score_threshold_process(score_threshold, top_k, docs) score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
ids = self.pg_vector.add_documents(docs) ids = self.pg_vector.add_documents(docs)

View File

@ -1,5 +1,4 @@
from typing import List, Dict, Optional from typing import List, Dict
from langchain.embeddings.base import Embeddings
from langchain.schema import Document from langchain.schema import Document
from langchain.vectorstores import Zilliz from langchain.vectorstores import Zilliz
from chatchat.configs import kbs_config from chatchat.configs import kbs_config
@ -7,6 +6,7 @@ from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedV
score_threshold_process score_threshold_process
from chatchat.server.knowledge_base.utils import KnowledgeFile from chatchat.server.knowledge_base.utils import KnowledgeFile
from chatchat.server.utils import get_Embeddings from chatchat.server.utils import get_Embeddings
from chatchat.server.file_rag.utils import get_Retriever
class ZillizKBService(KBService): class ZillizKBService(KBService):
@ -60,10 +60,13 @@ class ZillizKBService(KBService):
def do_search(self, query: str, top_k: int, score_threshold: float): def do_search(self, query: str, top_k: int, score_threshold: float):
self._load_zilliz() self._load_zilliz()
embed_func = get_Embeddings(self.embed_model) retriever = get_Retriever("vectorstore").from_vectorstore(
embeddings = embed_func.embed_query(query) self.zilliz,
docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k) top_k=top_k,
return score_threshold_process(score_threshold, top_k, docs) score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
for doc in docs: for doc in docs:

View File

@ -10,7 +10,7 @@ from chatchat.configs import (
TEXT_SPLITTER_NAME, TEXT_SPLITTER_NAME,
) )
import importlib import importlib
from chatchat.server.text_splitter import zh_title_enhance as func_zh_title_enhance from chatchat.server.file_rag.text_splitter import zh_title_enhance as func_zh_title_enhance
import langchain_community.document_loaders import langchain_community.document_loaders
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter

View File

@ -24,28 +24,28 @@ chat_box = ChatBox(
def save_session(): def save_session():
'''save session state to chat context''' """save session state to chat context"""
chat_box.context_from_session(exclude=["selected_page", "prompt", "cur_conv_name"]) chat_box.context_from_session(exclude=["selected_page", "prompt", "cur_conv_name"])
def restore_session(): def restore_session():
'''restore session state from chat context''' """restore session state from chat context"""
chat_box.context_to_session(exclude=["selected_page", "prompt", "cur_conv_name"]) chat_box.context_to_session(exclude=["selected_page", "prompt", "cur_conv_name"])
def rerun(): def rerun():
''' """
save chat context before rerun save chat context before rerun
''' """
save_session() save_session()
st.rerun() st.rerun()
def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]: def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
''' """
返回消息历史 返回消息历史
content_in_expander控制是否返回expander元素中的内容一般导出的时候可以选上传入LLM的history不需要 content_in_expander控制是否返回expander元素中的内容一般导出的时候可以选上传入LLM的history不需要
''' """
def filter(msg): def filter(msg):
content = [x for x in msg["elements"] if x._output_method in ["markdown", "text"]] content = [x for x in msg["elements"] if x._output_method in ["markdown", "text"]]
@ -66,10 +66,10 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) ->
@st.cache_data @st.cache_data
def upload_temp_docs(files, _api: ApiRequest) -> str: def upload_temp_docs(files, _api: ApiRequest) -> str:
''' """
将文件上传到临时目录用于文件对话 将文件上传到临时目录用于文件对话
返回临时向量库ID 返回临时向量库ID
''' """
return _api.upload_temp_docs(files).get("data", {}).get("id") return _api.upload_temp_docs(files).get("data", {}).get("id")
@ -157,11 +157,13 @@ def dialogue_page(
tools = list_tools(api) tools = list_tools(api)
tool_names = ["None"] + list(tools) tool_names = ["None"] + list(tools)
if use_agent: if use_agent:
# selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", check_all=True, key="selected_tools") # selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具",
# check_all=True, key="selected_tools")
selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"], selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"],
key="selected_tools") key="selected_tools")
else: else:
# selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", key="selected_tool") # selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具",
# key="selected_tool")
selected_tool = st.selectbox("选择工具", tool_names, selected_tool = st.selectbox("选择工具", tool_names,
format_func=lambda x: tools.get(x, {"title": "None"})["title"], format_func=lambda x: tools.get(x, {"title": "None"})["title"],
key="selected_tool") key="selected_tool")
@ -338,7 +340,7 @@ def dialogue_page(
elif d.status == AgentStatus.agent_finish: elif d.status == AgentStatus.agent_finish:
text = d.choices[0].delta.content or "" text = d.choices[0].delta.content or ""
chat_box.update_msg(text.replace("\n", "\n\n")) chat_box.update_msg(text.replace("\n", "\n\n"))
elif d.status == None: # not agent chat elif d.status is None: # not agent chat
if getattr(d, "is_ref", False): if getattr(d, "is_ref", False):
chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete", chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete",
title="参考资料")) title="参考资料"))

View File

@ -29,6 +29,8 @@ unstructured = "~0.11.0"
python-magic-bin = {version = "*", platform = "win32"} python-magic-bin = {version = "*", platform = "win32"}
SQLAlchemy = "~2.0.25" SQLAlchemy = "~2.0.25"
faiss-cpu = "~1.7.4" faiss-cpu = "~1.7.4"
cutword = "0.1.0"
rank_bm25 = "0.2.2"
# accelerate = "~0.24.1" # accelerate = "~0.24.1"
# spacy = "~3.7.2" # spacy = "~3.7.2"
PyMuPDF = "~1.23.16" PyMuPDF = "~1.23.16"