mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-02 12:46:56 +08:00
Feature(File RAG): add file_rag in chatchat-server, add ensemble retriever and vectorstore retriever.
This commit is contained in:
parent
b110fcd01b
commit
b56283eb01
@ -1,13 +1,14 @@
|
||||
from urllib.parse import urlencode
|
||||
from chatchat.server.utils import get_tool_config
|
||||
from chatchat.server.pydantic_v1 import Field
|
||||
from .tools_registry import regist_tool, BaseToolOutput
|
||||
from chatchat.server.agent.tools_factory.tools_registry import regist_tool, BaseToolOutput
|
||||
from chatchat.server.knowledge_base.kb_api import list_kbs
|
||||
from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId
|
||||
from chatchat.configs import KB_INFO
|
||||
|
||||
|
||||
template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool. The 'database' should be one of the above [{key}]."
|
||||
template = ("Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on "
|
||||
"this knowledge use this tool. The 'database' should be one of the above [{key}].")
|
||||
KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
|
||||
template_knowledge = template.format(KB_info=KB_info_str, key="samples")
|
||||
|
||||
@ -49,7 +50,7 @@ def search_local_knowledgebase(
|
||||
database: str = Field(description="Database for Knowledge Search", choices=[kb.kb_name for kb in list_kbs().data]),
|
||||
query: str = Field(description="Query for Knowledge Search"),
|
||||
):
|
||||
''''''
|
||||
""""""
|
||||
tool_config = get_tool_config("search_local_knowledgebase")
|
||||
ret = search_knowledgebase(query=query, database=database, config=tool_config)
|
||||
return KBToolOutput(ret, database=database)
|
||||
|
||||
@ -63,10 +63,10 @@ def upload_temp_docs(
|
||||
chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
|
||||
zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
"""
|
||||
将文件保存到临时目录,并进行向量化。
|
||||
返回临时目录名称作为ID,同时也是临时向量库的ID。
|
||||
'''
|
||||
"""
|
||||
if prev_id is not None:
|
||||
memo_faiss_pool.pop(prev_id)
|
||||
|
||||
@ -134,7 +134,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
|
||||
docs = [x[0] for x in docs]
|
||||
|
||||
context = "\n".join([doc.page_content for doc in docs])
|
||||
if len(docs) == 0: ## 如果没有找到相关文档,使用Empty模板
|
||||
if len(docs) == 0: # 如果没有找到相关文档,使用Empty模板
|
||||
prompt_template = get_prompt_template("knowledge_base_chat", "empty")
|
||||
else:
|
||||
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import List
|
||||
from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
|
||||
from chatchat.server.document_loaders.ocr import get_ocr
|
||||
from chatchat.server.file_rag.document_loaders.ocr import get_ocr
|
||||
|
||||
|
||||
class RapidOCRLoader(UnstructuredFileLoader):
|
||||
@ -4,7 +4,7 @@ import cv2
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from chatchat.configs import PDF_OCR_THRESHOLD
|
||||
from chatchat.server.document_loaders.ocr import get_ocr
|
||||
from chatchat.server.file_rag.document_loaders.ocr import get_ocr
|
||||
import tqdm
|
||||
|
||||
|
||||
@ -0,0 +1,3 @@
|
||||
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
|
||||
from chatchat.server.file_rag.retrievers.vectorstore import VectorstoreRetrieverService
|
||||
from chatchat.server.file_rag.retrievers.ensemble import EnsembleRetrieverService
|
||||
@ -0,0 +1,24 @@
|
||||
from langchain.vectorstores import VectorStore
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
|
||||
class BaseRetrieverService(metaclass=ABCMeta):
    """Abstract interface for file-RAG retriever services.

    Concrete implementations (e.g. the vectorstore and ensemble
    retriever services) must provide ``do_init``, ``from_vectorstore``
    and ``get_relevant_documents``.
    """

    def __init__(self, **kwargs):
        # All construction is delegated to the subclass hook so that
        # subclasses only override do_init(), never __init__().
        self.do_init(**kwargs)

    @abstractmethod
    def do_init(self, **kwargs):
        """Initialize subclass-specific state; called from __init__."""
        pass

    # Declared @staticmethod to match the concrete implementations,
    # which expose from_vectorstore as a static factory.
    @staticmethod
    @abstractmethod
    def from_vectorstore(
        vectorstore: "VectorStore",
        top_k: int,
        # The original annotation `int or float` evaluated to plain
        # `int` at runtime; `float` accepts ints per PEP 484.
        score_threshold: float,
    ):
        """Build a retriever service from an existing vector store.

        Args:
            vectorstore: the backing vector store to wrap.
            top_k: maximum number of documents to return per query.
            score_threshold: minimum similarity score for a hit.

        Returns:
            A fully initialized retriever service instance.
        """
        pass

    @abstractmethod
    def get_relevant_documents(self, query: str):
        """Return the documents relevant to *query*."""
        pass
|
||||
@ -0,0 +1,47 @@
|
||||
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
|
||||
from langchain.vectorstores import VectorStore
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_community.retrievers import BM25Retriever
|
||||
from langchain.retrievers import EnsembleRetriever
|
||||
|
||||
|
||||
class EnsembleRetrieverService(BaseRetrieverService):
    """Hybrid retrieval: BM25 keyword search fused with vector search.

    Wraps a LangChain ``EnsembleRetriever`` that merges BM25 and
    similarity-score-threshold results with equal weights.
    """

    def do_init(
        self,
        retriever: "BaseRetriever" = None,
        top_k: int = 5,
    ):
        """Store the wrapped retriever and the result cutoff.

        Args:
            retriever: the pre-built ensemble retriever to delegate to.
            top_k: maximum number of documents returned per query.
        """
        self.vs = None
        self.top_k = top_k
        self.retriever = retriever

    @staticmethod
    def from_vectorstore(
        vectorstore: "VectorStore",
        top_k: int,
        # `int or float` evaluated to `int` at runtime; `float`
        # accepts both per PEP 484.
        score_threshold: float,
    ):
        """Build an ensemble (BM25 + vector) retriever from *vectorstore*.

        Args:
            vectorstore: the backing vector store; its docstore supplies
                the corpus for the BM25 side.
            top_k: maximum hits per retriever.
            score_threshold: similarity threshold for the vector side.

        Returns:
            An ``EnsembleRetrieverService`` wrapping both retrievers.
        """
        faiss_retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "score_threshold": score_threshold,
                "k": top_k,
            },
        )
        # TODO: replace cutword with a tokenizer that does not pull in
        # torch (original note: 换个不用torch的实现方式).
        from cutword.cutword import Cutter
        cutter = Cutter()
        # NOTE(review): reaches into the docstore's private `_dict`;
        # this only works for stores (e.g. FAISS) exposing that
        # attribute — confirm callers never pass other store types.
        docs = list(vectorstore.docstore._dict.values())
        bm25_retriever = BM25Retriever.from_documents(
            docs,
            preprocess_func=cutter.cutword,
        )
        bm25_retriever.k = top_k
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
        )
        return EnsembleRetrieverService(retriever=ensemble_retriever)

    def get_relevant_documents(self, query: str):
        """Run the ensemble retriever and truncate to ``top_k`` results."""
        return self.retriever.get_relevant_documents(query)[:self.top_k]
|
||||
@ -0,0 +1,33 @@
|
||||
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
|
||||
from langchain.vectorstores import VectorStore
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
class VectorstoreRetrieverService(BaseRetrieverService):
    """Plain vector-store retrieval using similarity-score-threshold search."""

    def do_init(
        self,
        retriever: "BaseRetriever" = None,
        top_k: int = 5,
    ):
        """Store the wrapped retriever and the result cutoff.

        Args:
            retriever: the vector-store retriever to delegate to.
            top_k: maximum number of documents returned per query.
        """
        self.vs = None
        self.top_k = top_k
        self.retriever = retriever

    @staticmethod
    def from_vectorstore(
        vectorstore: "VectorStore",
        top_k: int,
        # `int or float` evaluated to `int` at runtime; `float`
        # accepts both per PEP 484.
        score_threshold: float,
    ):
        """Build a similarity-threshold retriever from *vectorstore*.

        Args:
            vectorstore: the backing vector store.
            top_k: maximum hits per query.
            score_threshold: minimum similarity score for a hit.

        Returns:
            A ``VectorstoreRetrieverService`` wrapping the retriever.
        """
        retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "score_threshold": score_threshold,
                "k": top_k,
            },
        )
        return VectorstoreRetrieverService(retriever=retriever)

    def get_relevant_documents(self, query: str):
        """Run the retriever and truncate to ``top_k`` results."""
        return self.retriever.get_relevant_documents(query)[:self.top_k]
|
||||
13
libs/chatchat-server/chatchat/server/file_rag/utils.py
Normal file
13
libs/chatchat-server/chatchat/server/file_rag/utils.py
Normal file
@ -0,0 +1,13 @@
|
||||
from chatchat.server.file_rag.retrievers import (
|
||||
BaseRetrieverService,
|
||||
VectorstoreRetrieverService,
|
||||
EnsembleRetrieverService,
|
||||
)
|
||||
|
||||
# Registry of retriever strategies, keyed by the short name passed to
# get_Retriever(). NOTE(review): "Retrivals" misspells "Retrievals" but
# is kept because other modules may already import it by this name.
Retrivals = {
    "vectorstore": VectorstoreRetrieverService,
    "ensemble": EnsembleRetrieverService,
}


def get_Retriever(type: str = "vectorstore") -> "type[BaseRetrieverService]":
    """Return the retriever service *class* registered under ``type``.

    The original annotation claimed an instance was returned, but every
    caller chains ``.from_vectorstore(...)`` on the result, so the value
    is the class itself; the annotation is corrected accordingly.

    Args:
        type: registry key ("vectorstore" or "ensemble"). The name
            shadows the builtin ``type`` but is kept so existing
            keyword-argument callers keep working.

    Raises:
        KeyError: if *type* is not a registered strategy name.
    """
    return Retrivals[type]
|
||||
@ -37,7 +37,8 @@ def search_docs(
|
||||
if kb is not None:
|
||||
if query:
|
||||
docs = kb.search_docs(query, top_k, score_threshold)
|
||||
data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
|
||||
# data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
|
||||
data = [DocumentWithVSId(**x.dict(), id=x.metadata.get("id")) for x in docs]
|
||||
elif file_name or metadata:
|
||||
data = kb.list_docs(file_name=file_name, metadata=metadata)
|
||||
for d in data:
|
||||
|
||||
@ -9,6 +9,7 @@ from chatchat.configs import SCORE_THRESHOLD
|
||||
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||
from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
|
||||
from chatchat.server.utils import get_Embeddings
|
||||
from chatchat.server.file_rag.utils import get_Retriever
|
||||
|
||||
|
||||
def _get_result_to_documents(get_result: GetResult) -> List[Document]:
|
||||
@ -75,10 +76,13 @@ class ChromaKBService(KBService):
|
||||
|
||||
def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[
|
||||
Tuple[Document, float]]:
|
||||
embed_func = get_Embeddings(self.embed_model)
|
||||
embeddings = embed_func.embed_query(query)
|
||||
query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k)
|
||||
return _results_to_docs_and_scores(query_result)
|
||||
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||
self.collection,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
docs = retriever.get_relevant_documents(query)
|
||||
return docs
|
||||
|
||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||
doc_infos = []
|
||||
|
||||
@ -8,6 +8,7 @@ from chatchat.server.knowledge_base.utils import KnowledgeFile
|
||||
from chatchat.server.utils import get_Embeddings
|
||||
from elasticsearch import Elasticsearch, BadRequestError
|
||||
from chatchat.configs import kbs_config, KB_ROOT_PATH
|
||||
from chatchat.server.file_rag.utils import get_Retriever
|
||||
|
||||
import logging
|
||||
|
||||
@ -107,8 +108,12 @@ class ESKBService(KBService):
|
||||
|
||||
def do_search(self, query:str, top_k: int, score_threshold: float):
|
||||
# 文本相似性检索
|
||||
docs = self.db.similarity_search_with_score(query=query,
|
||||
k=top_k)
|
||||
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||
self.db,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
docs = retriever.get_relevant_documents(query)
|
||||
return docs
|
||||
|
||||
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
||||
|
||||
@ -5,9 +5,9 @@ from chatchat.configs import SCORE_THRESHOLD
|
||||
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||
from chatchat.server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
|
||||
from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
|
||||
from chatchat.server.utils import get_Embeddings
|
||||
from langchain.docstore.document import Document
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from typing import List, Dict, Tuple
|
||||
from chatchat.server.file_rag.utils import get_Retriever
|
||||
|
||||
|
||||
class FaissKBService(KBService):
|
||||
@ -62,10 +62,13 @@ class FaissKBService(KBService):
|
||||
top_k: int,
|
||||
score_threshold: float = SCORE_THRESHOLD,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
embed_func = get_Embeddings(self.embed_model)
|
||||
embeddings = embed_func.embed_query(query)
|
||||
with self.load_vector_store().acquire() as vs:
|
||||
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
|
||||
retriever = get_Retriever("ensemble").from_vectorstore(
|
||||
vs,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
docs = retriever.get_relevant_documents(query)
|
||||
return docs
|
||||
|
||||
def do_add_doc(self,
|
||||
|
||||
@ -10,7 +10,7 @@ from chatchat.server.db.repository import list_file_num_docs_id_by_kb_name_and_f
|
||||
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType, \
|
||||
score_threshold_process
|
||||
from chatchat.server.knowledge_base.utils import KnowledgeFile
|
||||
from chatchat.server.utils import get_Embeddings
|
||||
from chatchat.server.file_rag.utils import get_Retriever
|
||||
|
||||
|
||||
class MilvusKBService(KBService):
|
||||
@ -67,10 +67,16 @@ class MilvusKBService(KBService):
|
||||
|
||||
def do_search(self, query: str, top_k: int, score_threshold: float):
|
||||
self._load_milvus()
|
||||
embed_func = get_Embeddings(self.embed_model)
|
||||
embeddings = embed_func.embed_query(query)
|
||||
docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
|
||||
return score_threshold_process(score_threshold, top_k, docs)
|
||||
# embed_func = get_Embeddings(self.embed_model)
|
||||
# embeddings = embed_func.embed_query(query)
|
||||
# docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
|
||||
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||
self.milvus,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
docs = retriever.get_relevant_documents(query)
|
||||
return docs
|
||||
|
||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||
for doc in docs:
|
||||
|
||||
@ -15,6 +15,7 @@ import shutil
|
||||
import sqlalchemy
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.orm import Session
|
||||
from chatchat.server.file_rag.utils import get_Retriever
|
||||
|
||||
|
||||
class PGKBService(KBService):
|
||||
@ -60,10 +61,13 @@ class PGKBService(KBService):
|
||||
shutil.rmtree(self.kb_path)
|
||||
|
||||
def do_search(self, query: str, top_k: int, score_threshold: float):
|
||||
embed_func = get_Embeddings(self.embed_model)
|
||||
embeddings = embed_func.embed_query(query)
|
||||
docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k)
|
||||
return score_threshold_process(score_threshold, top_k, docs)
|
||||
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||
self.pg_vector,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
docs = retriever.get_relevant_documents(query)
|
||||
return docs
|
||||
|
||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||
ids = self.pg_vector.add_documents(docs)
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import List, Dict, Optional
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from typing import List, Dict
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores import Zilliz
|
||||
from chatchat.configs import kbs_config
|
||||
@ -7,6 +6,7 @@ from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedV
|
||||
score_threshold_process
|
||||
from chatchat.server.knowledge_base.utils import KnowledgeFile
|
||||
from chatchat.server.utils import get_Embeddings
|
||||
from chatchat.server.file_rag.utils import get_Retriever
|
||||
|
||||
|
||||
class ZillizKBService(KBService):
|
||||
@ -60,10 +60,13 @@ class ZillizKBService(KBService):
|
||||
|
||||
def do_search(self, query: str, top_k: int, score_threshold: float):
|
||||
self._load_zilliz()
|
||||
embed_func = get_Embeddings(self.embed_model)
|
||||
embeddings = embed_func.embed_query(query)
|
||||
docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k)
|
||||
return score_threshold_process(score_threshold, top_k, docs)
|
||||
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||
self.zilliz,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
docs = retriever.get_relevant_documents(query)
|
||||
return docs
|
||||
|
||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||
for doc in docs:
|
||||
|
||||
@ -10,7 +10,7 @@ from chatchat.configs import (
|
||||
TEXT_SPLITTER_NAME,
|
||||
)
|
||||
import importlib
|
||||
from chatchat.server.text_splitter import zh_title_enhance as func_zh_title_enhance
|
||||
from chatchat.server.file_rag.text_splitter import zh_title_enhance as func_zh_title_enhance
|
||||
import langchain_community.document_loaders
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter
|
||||
|
||||
@ -24,28 +24,28 @@ chat_box = ChatBox(
|
||||
|
||||
|
||||
def save_session():
|
||||
'''save session state to chat context'''
|
||||
"""save session state to chat context"""
|
||||
chat_box.context_from_session(exclude=["selected_page", "prompt", "cur_conv_name"])
|
||||
|
||||
|
||||
def restore_session():
|
||||
'''restore session state from chat context'''
|
||||
"""restore session state from chat context"""
|
||||
chat_box.context_to_session(exclude=["selected_page", "prompt", "cur_conv_name"])
|
||||
|
||||
|
||||
def rerun():
|
||||
'''
|
||||
"""
|
||||
save chat context before rerun
|
||||
'''
|
||||
"""
|
||||
save_session()
|
||||
st.rerun()
|
||||
|
||||
|
||||
def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
|
||||
'''
|
||||
"""
|
||||
返回消息历史。
|
||||
content_in_expander控制是否返回expander元素中的内容,一般导出的时候可以选上,传入LLM的history不需要
|
||||
'''
|
||||
"""
|
||||
|
||||
def filter(msg):
|
||||
content = [x for x in msg["elements"] if x._output_method in ["markdown", "text"]]
|
||||
@ -66,10 +66,10 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) ->
|
||||
|
||||
@st.cache_data
|
||||
def upload_temp_docs(files, _api: ApiRequest) -> str:
|
||||
'''
|
||||
"""
|
||||
将文件上传到临时目录,用于文件对话
|
||||
返回临时向量库ID
|
||||
'''
|
||||
"""
|
||||
return _api.upload_temp_docs(files).get("data", {}).get("id")
|
||||
|
||||
|
||||
@ -157,11 +157,13 @@ def dialogue_page(
|
||||
tools = list_tools(api)
|
||||
tool_names = ["None"] + list(tools)
|
||||
if use_agent:
|
||||
# selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", check_all=True, key="selected_tools")
|
||||
# selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具",
|
||||
# check_all=True, key="selected_tools")
|
||||
selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"],
|
||||
key="selected_tools")
|
||||
else:
|
||||
# selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", key="selected_tool")
|
||||
# selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具",
|
||||
# key="selected_tool")
|
||||
selected_tool = st.selectbox("选择工具", tool_names,
|
||||
format_func=lambda x: tools.get(x, {"title": "None"})["title"],
|
||||
key="selected_tool")
|
||||
@ -338,7 +340,7 @@ def dialogue_page(
|
||||
elif d.status == AgentStatus.agent_finish:
|
||||
text = d.choices[0].delta.content or ""
|
||||
chat_box.update_msg(text.replace("\n", "\n\n"))
|
||||
elif d.status == None: # not agent chat
|
||||
elif d.status is None: # not agent chat
|
||||
if getattr(d, "is_ref", False):
|
||||
chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete",
|
||||
title="参考资料"))
|
||||
|
||||
@ -29,6 +29,8 @@ unstructured = "~0.11.0"
|
||||
python-magic-bin = {version = "*", platform = "win32"}
|
||||
SQLAlchemy = "~2.0.25"
|
||||
faiss-cpu = "~1.7.4"
|
||||
cutword = "0.1.0"
|
||||
rank_bm25 = "0.2.2"
|
||||
# accelerate = "~0.24.1"
|
||||
# spacy = "~3.7.2"
|
||||
PyMuPDF = "~1.23.16"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user