Merge branch 'dev' into dev_config_init

This commit is contained in:
glide-the 2024-06-10 22:00:22 +08:00
commit 72b1cab89a
37 changed files with 211 additions and 65 deletions

View File

@ -1,11 +1,13 @@
import os
from pathlib import Path
# chatchat 项目根目录
CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent)
import sys
sys.path.append(str(Path(__file__).parent))
from _basic_config import config_workspace
# 用户数据根目录
DATA_PATH = os.path.join(CHATCHAT_ROOT, "data")
DATA_PATH = config_workspace.get_config().DATA_PATH
# 默认使用的知识库
DEFAULT_KNOWLEDGE_BASE = "samples"

View File

@ -149,7 +149,7 @@ TOOL_CONFIG = {
"search_local_knowledgebase": {
"use": False,
"top_k": 3,
"score_threshold": 1,
"score_threshold": 1.0,
"conclude_prompt": {
"with_result":
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题"'

View File

@ -1,13 +1,14 @@
from urllib.parse import urlencode
from chatchat.server.utils import get_tool_config
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool, BaseToolOutput
from chatchat.server.agent.tools_factory.tools_registry import regist_tool, BaseToolOutput
from chatchat.server.knowledge_base.kb_api import list_kbs
from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId
from chatchat.configs import KB_INFO
template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get informationOnly local data on this knowledge use this tool. The 'database' should be one of the above [{key}]."
template = ("Use local knowledgebase from one or more of these:\n{KB_info}\n to get informationOnly local data on "
"this knowledge use this tool. The 'database' should be one of the above [{key}].")
KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
template_knowledge = template.format(KB_info=KB_info_str, key="samples")
@ -49,7 +50,7 @@ def search_local_knowledgebase(
database: str = Field(description="Database for Knowledge Search", choices=[kb.kb_name for kb in list_kbs().data]),
query: str = Field(description="Query for Knowledge Search"),
):
''''''
""""""
tool_config = get_tool_config("search_local_knowledgebase")
ret = search_knowledgebase(query=query, database=database, config=tool_config)
return KBToolOutput(ret, database=database)

View File

@ -63,10 +63,10 @@ def upload_temp_docs(
chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
) -> BaseResponse:
'''
"""
将文件保存到临时目录并进行向量化
返回临时目录名称作为ID同时也是临时向量库的ID
'''
"""
if prev_id is not None:
memo_faiss_pool.pop(prev_id)
@ -134,7 +134,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
docs = [x[0] for x in docs]
context = "\n".join([doc.page_content for doc in docs])
if len(docs) == 0: ## 如果没有找到相关文档使用Empty模板
if len(docs) == 0: # 如果没有找到相关文档使用Empty模板
prompt_template = get_prompt_template("knowledge_base_chat", "empty")
else:
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)

View File

@ -1,6 +1,6 @@
from typing import List
from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
from chatchat.server.document_loaders.ocr import get_ocr
from chatchat.server.file_rag.document_loaders.ocr import get_ocr
class RapidOCRLoader(UnstructuredFileLoader):

View File

@ -4,7 +4,7 @@ import cv2
from PIL import Image
import numpy as np
from chatchat.configs import PDF_OCR_THRESHOLD
from chatchat.server.document_loaders.ocr import get_ocr
from chatchat.server.file_rag.document_loaders.ocr import get_ocr
import tqdm

View File

@ -0,0 +1,3 @@
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
from chatchat.server.file_rag.retrievers.vectorstore import VectorstoreRetrieverService
from chatchat.server.file_rag.retrievers.ensemble import EnsembleRetrieverService

View File

@ -0,0 +1,24 @@
from langchain.vectorstores import VectorStore
from abc import ABCMeta, abstractmethod
class BaseRetrieverService(metaclass=ABCMeta):
    """Abstract base class for retriever services used by knowledge-base search.

    Subclasses must implement:
      * ``do_init``               -- store subclass-specific configuration,
      * ``from_vectorstore``      -- factory building a service from a vector store,
      * ``get_relevant_documents`` -- run the actual retrieval for a query.
    """

    def __init__(self, **kwargs):
        # All configuration is delegated to the subclass hook so subclasses
        # only override do_init instead of re-implementing __init__.
        self.do_init(**kwargs)

    @abstractmethod
    def do_init(self, **kwargs):
        """Store configuration passed to the constructor."""
        pass

    # Declared static to match every concrete implementation (the original
    # abstract signature had no ``self`` parameter either).
    @staticmethod
    @abstractmethod
    def from_vectorstore(
        vectorstore: VectorStore,
        top_k: int,
        # NOTE: was annotated ``int or float``, which evaluates to ``int`` at
        # runtime; ``float`` already accepts ints per PEP 484.
        score_threshold: float,
    ):
        """Build a retriever service on top of an existing vector store."""
        pass

    @abstractmethod
    def get_relevant_documents(self, query: str):
        """Return the documents relevant to *query*."""
        pass

View File

@ -0,0 +1,48 @@
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
from langchain.vectorstores import VectorStore
from langchain_core.retrievers import BaseRetriever
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
class EnsembleRetrieverService(BaseRetrieverService):
    """Hybrid retrieval service: BM25 keyword search ensembled (50/50)
    with vector similarity search over the same document collection."""

    def do_init(
        self,
        retriever: BaseRetriever = None,
        top_k: int = 5,
    ):
        self.vs = None
        self.top_k = top_k
        self.retriever = retriever

    @staticmethod
    def from_vectorstore(
        vectorstore: VectorStore,
        top_k: int,
        score_threshold: float,
    ):
        """Build an ensemble of a BM25 retriever and a similarity-score
        retriever from *vectorstore*.

        NOTE(review): reads ``vectorstore.docstore._dict`` — assumes a
        FAISS-style store exposing its in-memory docstore; verify for other
        vector store backends.
        """
        faiss_retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "score_threshold": score_threshold,
                "k": top_k,
            },
        )

        # TODO: switch to a tokenizer implementation that does not need torch
        # from cutword.cutword import Cutter
        import jieba
        # cutter = Cutter()

        docs = list(vectorstore.docstore._dict.values())
        bm25_retriever = BM25Retriever.from_documents(
            docs,
            preprocess_func=jieba.lcut_for_search,
        )
        bm25_retriever.k = top_k
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
        )
        # Bug fix: propagate top_k. Previously only ``retriever`` was passed,
        # so get_relevant_documents always truncated to the default of 5
        # regardless of the requested top_k.
        return EnsembleRetrieverService(retriever=ensemble_retriever, top_k=top_k)

    def get_relevant_documents(self, query: str):
        """Run the ensemble retrieval and cap the result at ``top_k`` docs."""
        return self.retriever.get_relevant_documents(query)[:self.top_k]

View File

@ -0,0 +1,33 @@
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
from langchain.vectorstores import VectorStore
from langchain_core.retrievers import BaseRetriever
class VectorstoreRetrieverService(BaseRetrieverService):
    """Plain vector-similarity retrieval service with a score threshold."""

    def do_init(
        self,
        retriever: BaseRetriever = None,
        top_k: int = 5,
    ):
        self.vs = None
        self.top_k = top_k
        self.retriever = retriever

    @staticmethod
    def from_vectorstore(
        vectorstore: VectorStore,
        top_k: int,
        score_threshold: float,
    ):
        """Wrap *vectorstore* in a similarity-score-threshold retriever."""
        retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "score_threshold": score_threshold,
                "k": top_k,
            },
        )
        # Bug fix: propagate top_k. Previously only ``retriever`` was passed,
        # so get_relevant_documents always truncated to the default of 5
        # regardless of the requested top_k.
        return VectorstoreRetrieverService(retriever=retriever, top_k=top_k)

    def get_relevant_documents(self, query: str):
        """Run the retrieval and cap the result at ``top_k`` documents."""
        return self.retriever.get_relevant_documents(query)[:self.top_k]

View File

@ -0,0 +1,13 @@
from chatchat.server.file_rag.retrievers import (
BaseRetrieverService,
VectorstoreRetrieverService,
EnsembleRetrieverService,
)
# Registry of available retriever-service implementations, keyed by the
# retriever kind accepted by get_Retriever.
Retrivals = dict(
    vectorstore=VectorstoreRetrieverService,
    ensemble=EnsembleRetrieverService,
)


def get_Retriever(type: str = "vectorstore") -> BaseRetrieverService:
    """Return the retriever-service class registered under *type*.

    Raises ``KeyError`` when *type* is not a registered retriever kind.
    """
    return Retrivals[type]

View File

@ -110,7 +110,8 @@ class KBFaissPool(_FaissPool):
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
embeddings = get_Embeddings(embed_model=embed_model)
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True,
allow_dangerous_deserialization=True)
elif create:
# create an empty vector store
if not os.path.exists(vs_path):

View File

@ -28,7 +28,7 @@ def search_docs(
description="知识库匹配相关度阈值取值范围在0-1之间"
"SCORE越小相关度越高"
"取到1相当于不筛选建议设置在0.5左右",
ge=0, le=1),
ge=0.0, le=1.0),
file_name: str = Body("", description="文件名称,支持 sql 通配符"),
metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"),
) -> List[Dict]:
@ -37,7 +37,8 @@ def search_docs(
if kb is not None:
if query:
docs = kb.search_docs(query, top_k, score_threshold)
data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
# data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
data = [DocumentWithVSId(**x.dict(), id=x.metadata.get("id")) for x in docs]
elif file_name or metadata:
data = kb.list_docs(file_name=file_name, metadata=metadata)
for d in data:

View File

@ -9,6 +9,7 @@ from chatchat.configs import SCORE_THRESHOLD
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType
from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
from chatchat.server.utils import get_Embeddings
from chatchat.server.file_rag.utils import get_Retriever
def _get_result_to_documents(get_result: GetResult) -> List[Document]:
@ -75,10 +76,13 @@ class ChromaKBService(KBService):
def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[
Tuple[Document, float]]:
embed_func = get_Embeddings(self.embed_model)
embeddings = embed_func.embed_query(query)
query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k)
return _results_to_docs_and_scores(query_result)
retriever = get_Retriever("vectorstore").from_vectorstore(
self.collection,
top_k=top_k,
score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
doc_infos = []

View File

@ -8,6 +8,7 @@ from chatchat.server.knowledge_base.utils import KnowledgeFile
from chatchat.server.utils import get_Embeddings
from elasticsearch import Elasticsearch, BadRequestError
from chatchat.configs import kbs_config, KB_ROOT_PATH
from chatchat.server.file_rag.utils import get_Retriever
import logging
@ -107,8 +108,12 @@ class ESKBService(KBService):
def do_search(self, query:str, top_k: int, score_threshold: float):
# 文本相似性检索
docs = self.db.similarity_search_with_score(query=query,
k=top_k)
retriever = get_Retriever("vectorstore").from_vectorstore(
self.db,
top_k=top_k,
score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:

View File

@ -5,9 +5,9 @@ from chatchat.configs import SCORE_THRESHOLD
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType
from chatchat.server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
from chatchat.server.utils import get_Embeddings
from langchain.docstore.document import Document
from typing import List, Dict, Optional, Tuple
from typing import List, Dict, Tuple
from chatchat.server.file_rag.utils import get_Retriever
class FaissKBService(KBService):
@ -62,10 +62,13 @@ class FaissKBService(KBService):
top_k: int,
score_threshold: float = SCORE_THRESHOLD,
) -> List[Tuple[Document, float]]:
embed_func = get_Embeddings(self.embed_model)
embeddings = embed_func.embed_query(query)
with self.load_vector_store().acquire() as vs:
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
retriever = get_Retriever("ensemble").from_vectorstore(
vs,
top_k=top_k,
score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs
def do_add_doc(self,

View File

@ -10,7 +10,7 @@ from chatchat.server.db.repository import list_file_num_docs_id_by_kb_name_and_f
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType, \
score_threshold_process
from chatchat.server.knowledge_base.utils import KnowledgeFile
from chatchat.server.utils import get_Embeddings
from chatchat.server.file_rag.utils import get_Retriever
class MilvusKBService(KBService):
@ -67,10 +67,16 @@ class MilvusKBService(KBService):
def do_search(self, query: str, top_k: int, score_threshold: float):
self._load_milvus()
embed_func = get_Embeddings(self.embed_model)
embeddings = embed_func.embed_query(query)
docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
return score_threshold_process(score_threshold, top_k, docs)
# embed_func = get_Embeddings(self.embed_model)
# embeddings = embed_func.embed_query(query)
# docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
retriever = get_Retriever("vectorstore").from_vectorstore(
self.milvus,
top_k=top_k,
score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
for doc in docs:

View File

@ -15,6 +15,7 @@ import shutil
import sqlalchemy
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session
from chatchat.server.file_rag.utils import get_Retriever
class PGKBService(KBService):
@ -60,10 +61,13 @@ class PGKBService(KBService):
shutil.rmtree(self.kb_path)
def do_search(self, query: str, top_k: int, score_threshold: float):
embed_func = get_Embeddings(self.embed_model)
embeddings = embed_func.embed_query(query)
docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k)
return score_threshold_process(score_threshold, top_k, docs)
retriever = get_Retriever("vectorstore").from_vectorstore(
self.pg_vector,
top_k=top_k,
score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
ids = self.pg_vector.add_documents(docs)

View File

@ -1,5 +1,4 @@
from typing import List, Dict, Optional
from langchain.embeddings.base import Embeddings
from typing import List, Dict
from langchain.schema import Document
from langchain.vectorstores import Zilliz
from chatchat.configs import kbs_config
@ -7,6 +6,7 @@ from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedV
score_threshold_process
from chatchat.server.knowledge_base.utils import KnowledgeFile
from chatchat.server.utils import get_Embeddings
from chatchat.server.file_rag.utils import get_Retriever
class ZillizKBService(KBService):
@ -60,10 +60,13 @@ class ZillizKBService(KBService):
def do_search(self, query: str, top_k: int, score_threshold: float):
self._load_zilliz()
embed_func = get_Embeddings(self.embed_model)
embeddings = embed_func.embed_query(query)
docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k)
return score_threshold_process(score_threshold, top_k, docs)
retriever = get_Retriever("vectorstore").from_vectorstore(
self.zilliz,
top_k=top_k,
score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
for doc in docs:

View File

@ -10,7 +10,7 @@ from chatchat.configs import (
TEXT_SPLITTER_NAME,
)
import importlib
from chatchat.server.text_splitter import zh_title_enhance as func_zh_title_enhance
from chatchat.server.file_rag.text_splitter import zh_title_enhance as func_zh_title_enhance
import langchain_community.document_loaders
from langchain.docstore.document import Document
from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter

View File

@ -24,28 +24,28 @@ chat_box = ChatBox(
def save_session():
'''save session state to chat context'''
"""save session state to chat context"""
chat_box.context_from_session(exclude=["selected_page", "prompt", "cur_conv_name"])
def restore_session():
'''restore sesstion state from chat context'''
"""restore sesstion state from chat context"""
chat_box.context_to_session(exclude=["selected_page", "prompt", "cur_conv_name"])
def rerun():
'''
"""
save chat context before rerun
'''
"""
save_session()
st.rerun()
def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
'''
"""
返回消息历史
content_in_expander控制是否返回expander元素中的内容一般导出的时候可以选上传入LLM的history不需要
'''
"""
def filter(msg):
content = [x for x in msg["elements"] if x._output_method in ["markdown", "text"]]
@ -66,10 +66,10 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) ->
@st.cache_data
def upload_temp_docs(files, _api: ApiRequest) -> str:
'''
"""
将文件上传到临时目录用于文件对话
返回临时向量库ID
'''
"""
return _api.upload_temp_docs(files).get("data", {}).get("id")
@ -157,11 +157,13 @@ def dialogue_page(
tools = list_tools(api)
tool_names = ["None"] + list(tools)
if use_agent:
# selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", check_all=True, key="selected_tools")
# selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具",
# check_all=True, key="selected_tools")
selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"],
key="selected_tools")
else:
# selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", key="selected_tool")
# selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具",
# key="selected_tool")
selected_tool = st.selectbox("选择工具", tool_names,
format_func=lambda x: tools.get(x, {"title": "None"})["title"],
key="selected_tool")
@ -338,7 +340,7 @@ def dialogue_page(
elif d.status == AgentStatus.agent_finish:
text = d.choices[0].delta.content or ""
chat_box.update_msg(text.replace("\n", "\n\n"))
elif d.status == None: # not agent chat
elif d.status is None: # not agent chat
if getattr(d, "is_ref", False):
chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete",
title="参考资料"))

View File

@ -1,9 +1,6 @@
[virtualenvs]
in-project = true
[installer]
modern-installation = false
[plugins]
[plugins.pypi_mirror]
url = "https://pypi.tuna.tsinghua.edu.cn/simple"

View File

@ -29,6 +29,9 @@ unstructured = "~0.11.0"
python-magic-bin = {version = "*", platform = "win32"}
SQLAlchemy = "~2.0.25"
faiss-cpu = "~1.7.4"
#cutword = "0.1.0"
jieba = "0.42.1"
rank_bm25 = "0.2.2"
# accelerate = "~0.24.1"
# spacy = "~3.7.2"
PyMuPDF = "~1.23.16"
@ -264,4 +267,4 @@ dotenv = "dotenv:plugin"
[[tool.poetry.source]]
name = "tsinghua"
url = "https://pypi.tuna.tsinghua.edu.cn/simple/"
priority = "default"
priority = "primary"

View File

@ -1,9 +1,6 @@
[virtualenvs]
in-project = true
[installer]
modern-installation = false
[plugins]
[plugins.pypi_mirror]
url = "https://pypi.tuna.tsinghua.edu.cn/simple"

View File

@ -1,9 +1,5 @@
[virtualenvs]
in-project = true
[installer]
modern-installation = false
[plugins]
[plugins.pypi_mirror]
url = "https://pypi.tuna.tsinghua.edu.cn/simple"

View File

@ -137,7 +137,7 @@ model_format = None
model_quant = None
if model_type == "LLM":
cur_family = xf_llm.LLMFamilyV1.model_validate(cur_reg)
cur_family = xf_llm.LLMFamilyV1.parse_obj(cur_reg)
cur_spec = None
model_formats = []
for spec in cur_reg["model_specs"]: