配置中心知识库信息业务代码

This commit is contained in:
glide-the 2024-06-11 17:37:02 +08:00
parent 35c2f596f6
commit 4f9d63d9f4
3 changed files with 424 additions and 157 deletions

View File

@ -203,164 +203,196 @@ def _import_base_temp_dir() -> Any:
return config_basic_workspace.get_config().BASE_TEMP_DIR return config_basic_workspace.get_config().BASE_TEMP_DIR
def _import_ConfigKb() -> Any:
    """Load the ``ConfigKb`` name through the ``_kb_config.py`` entry of ``CONFIG_IMPORTS``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    loader = entry.get("load_mod")
    return loader(entry.get("module"), "ConfigKb")
def _import_ConfigKbFactory() -> Any:
    """Load the ``ConfigKbFactory`` name through the ``_kb_config.py`` entry of ``CONFIG_IMPORTS``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    loader = entry.get("load_mod")
    return loader(entry.get("module"), "ConfigKbFactory")
def _import_ConfigKbWorkSpace() -> Any:
    """Load the ``ConfigKbWorkSpace`` name through the ``_kb_config.py`` entry of ``CONFIG_IMPORTS``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    loader = entry.get("load_mod")
    return loader(entry.get("module"), "ConfigKbWorkSpace")
def _import_config_kb_workspace() -> Any:
    """Load the ``config_kb_workspace`` object through the ``_kb_config.py`` entry of ``CONFIG_IMPORTS``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    loader = entry.get("load_mod")
    return loader(entry.get("module"), "config_kb_workspace")
def _import_default_knowledge_base() -> Any:
    """Return ``DEFAULT_KNOWLEDGE_BASE`` from the config built by ``config_kb_workspace``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    workspace = entry.get("load_mod")(entry.get("module"), "config_kb_workspace")
    return workspace.get_config().DEFAULT_KNOWLEDGE_BASE
def _import_default_vs_type() -> Any:
    """Return ``DEFAULT_VS_TYPE`` from the config built by ``config_kb_workspace``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    workspace = entry.get("load_mod")(entry.get("module"), "config_kb_workspace")
    return workspace.get_config().DEFAULT_VS_TYPE
def _import_cached_vs_num() -> Any:
    """Return ``CACHED_VS_NUM`` from the config built by ``config_kb_workspace``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    workspace = entry.get("load_mod")(entry.get("module"), "config_kb_workspace")
    return workspace.get_config().CACHED_VS_NUM
def _import_cached_memo_vs_num() -> Any:
    """Return ``CACHED_MEMO_VS_NUM`` from the config built by ``config_kb_workspace``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    workspace = entry.get("load_mod")(entry.get("module"), "config_kb_workspace")
    return workspace.get_config().CACHED_MEMO_VS_NUM
def _import_chunk_size() -> Any:
    """Return ``CHUNK_SIZE`` from the config built by ``config_kb_workspace``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    workspace = entry.get("load_mod")(entry.get("module"), "config_kb_workspace")
    return workspace.get_config().CHUNK_SIZE
def _import_overlap_size() -> Any:
    """Return ``OVERLAP_SIZE`` from the config built by ``config_kb_workspace``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    workspace = entry.get("load_mod")(entry.get("module"), "config_kb_workspace")
    return workspace.get_config().OVERLAP_SIZE
def _import_vector_search_top_k() -> Any:
    """Return ``VECTOR_SEARCH_TOP_K`` from the config built by ``config_kb_workspace``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    workspace = entry.get("load_mod")(entry.get("module"), "config_kb_workspace")
    return workspace.get_config().VECTOR_SEARCH_TOP_K
def _import_score_threshold() -> Any:
    """Return ``SCORE_THRESHOLD`` from the config built by ``config_kb_workspace``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    workspace = entry.get("load_mod")(entry.get("module"), "config_kb_workspace")
    return workspace.get_config().SCORE_THRESHOLD
def _import_default_search_engine() -> Any:
    """Return ``DEFAULT_SEARCH_ENGINE`` from the config built by ``config_kb_workspace``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    workspace = entry.get("load_mod")(entry.get("module"), "config_kb_workspace")
    return workspace.get_config().DEFAULT_SEARCH_ENGINE
def _import_search_engine_top_k() -> Any:
    """Return ``SEARCH_ENGINE_TOP_K`` from the config built by ``config_kb_workspace``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    workspace = entry.get("load_mod")(entry.get("module"), "config_kb_workspace")
    return workspace.get_config().SEARCH_ENGINE_TOP_K
def _import_zh_title_enhance() -> Any:
    """Return ``ZH_TITLE_ENHANCE`` from the config built by ``config_kb_workspace``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    workspace = entry.get("load_mod")(entry.get("module"), "config_kb_workspace")
    return workspace.get_config().ZH_TITLE_ENHANCE
def _import_pdf_ocr_threshold() -> Any:
    """Return ``PDF_OCR_THRESHOLD`` from the config built by ``config_kb_workspace``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    workspace = entry.get("load_mod")(entry.get("module"), "config_kb_workspace")
    return workspace.get_config().PDF_OCR_THRESHOLD
def _import_kb_info() -> Any:
    """Return ``KB_INFO`` from the config built by ``config_kb_workspace``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    workspace = entry.get("load_mod")(entry.get("module"), "config_kb_workspace")
    return workspace.get_config().KB_INFO
def _import_kb_root_path() -> Any:
    """Return ``KB_ROOT_PATH`` from the config built by ``config_kb_workspace``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    workspace = entry.get("load_mod")(entry.get("module"), "config_kb_workspace")
    return workspace.get_config().KB_ROOT_PATH
def _import_db_root_path() -> Any:
    """Return ``DB_ROOT_PATH`` from the config built by ``config_kb_workspace``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    workspace = entry.get("load_mod")(entry.get("module"), "config_kb_workspace")
    return workspace.get_config().DB_ROOT_PATH
def _import_sqlalchemy_database_uri() -> Any:
    """Return ``SQLALCHEMY_DATABASE_URI`` from the config built by ``config_kb_workspace``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    workspace = entry.get("load_mod")(entry.get("module"), "config_kb_workspace")
    return workspace.get_config().SQLALCHEMY_DATABASE_URI
def _import_kbs_config() -> Any:
    """Return ``kbs_config`` from the config built by ``config_kb_workspace``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    workspace = entry.get("load_mod")(entry.get("module"), "config_kb_workspace")
    return workspace.get_config().kbs_config
def _import_text_splitter_dict() -> Any:
    """Return ``text_splitter_dict`` from the config built by ``config_kb_workspace``.

    The attribute is read under its lowercase name because that is how both
    ``ConfigKb`` declares it and ``ConfigKbFactory`` populates it; an
    upper-case ``TEXT_SPLITTER_DICT`` attribute is never set on the config.
    """
    kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
    load_mod = kb_config_load.get("load_mod")
    config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace")
    return config_kb_workspace.get_config().text_splitter_dict
def _import_text_splitter_name() -> Any:
    """Return ``TEXT_SPLITTER_NAME`` from the config built by ``config_kb_workspace``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    workspace = entry.get("load_mod")(entry.get("module"), "config_kb_workspace")
    return workspace.get_config().TEXT_SPLITTER_NAME
def _import_embedding_keyword_file() -> Any:
    """Return ``EMBEDDING_KEYWORD_FILE`` from the config built by ``config_kb_workspace``."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    workspace = entry.get("load_mod")(entry.get("module"), "config_kb_workspace")
    return workspace.get_config().EMBEDDING_KEYWORD_FILE
def _import_ConfigModel() -> Any: def _import_ConfigModel() -> Any:
@ -563,6 +595,14 @@ def _import_default_bind_host() -> Any:
return config_server_workspace.get_config().DEFAULT_BIND_HOST return config_server_workspace.get_config().DEFAULT_BIND_HOST
def _import_open_cross_domain() -> Any:
    """Return ``OPEN_CROSS_DOMAIN`` from the config built by ``config_server_workspace``."""
    entry = CONFIG_IMPORTS.get("_server_config.py")
    workspace = entry.get("load_mod")(entry.get("module"), "config_server_workspace")
    return workspace.get_config().OPEN_CROSS_DOMAIN
def _import_webui_server() -> Any: def _import_webui_server() -> Any:
server_config_load = CONFIG_IMPORTS.get("_server_config.py") server_config_load = CONFIG_IMPORTS.get("_server_config.py")
load_mod = server_config_load.get("load_mod") load_mod = server_config_load.get("load_mod")
@ -622,6 +662,15 @@ def __getattr__(name: str) -> Any:
return _import_media_path() return _import_media_path()
elif name == "BASE_TEMP_DIR": elif name == "BASE_TEMP_DIR":
return _import_base_temp_dir() return _import_base_temp_dir()
elif name == "ConfigKb":
return _import_ConfigKb()
elif name == "ConfigKbFactory":
return _import_ConfigKbFactory()
elif name == "ConfigKbWorkSpace":
return _import_ConfigKbWorkSpace()
elif name == "config_kb_workspace":
return _import_config_kb_workspace()
elif name == "DEFAULT_KNOWLEDGE_BASE": elif name == "DEFAULT_KNOWLEDGE_BASE":
return _import_default_knowledge_base() return _import_default_knowledge_base()
elif name == "DEFAULT_VS_TYPE": elif name == "DEFAULT_VS_TYPE":
@ -692,8 +741,10 @@ def __getattr__(name: str) -> Any:
return _import_prompt_templates() return _import_prompt_templates()
elif name == "HTTPX_DEFAULT_TIMEOUT": elif name == "HTTPX_DEFAULT_TIMEOUT":
return _import_httpx_default_timeout() return _import_httpx_default_timeout()
elif name == "OPEN_CROSS_DOMAIN": elif name == "DEFAULT_BIND_HOST":
return _import_default_bind_host() return _import_default_bind_host()
elif name == "OPEN_CROSS_DOMAIN":
return _import_open_cross_domain()
elif name == "WEBUI_SERVER": elif name == "WEBUI_SERVER":
return _import_webui_server() return _import_webui_server()
elif name == "API_SERVER": elif name == "API_SERVER":
@ -748,6 +799,7 @@ __all__ = [
"TOOL_CONFIG", "TOOL_CONFIG",
"PROMPT_TEMPLATES", "PROMPT_TEMPLATES",
"HTTPX_DEFAULT_TIMEOUT", "HTTPX_DEFAULT_TIMEOUT",
"DEFAULT_BIND_HOST",
"OPEN_CROSS_DOMAIN", "OPEN_CROSS_DOMAIN",
"WEBUI_SERVER", "WEBUI_SERVER",
"API_SERVER", "API_SERVER",
@ -764,6 +816,12 @@ __all__ = [
"config_model_workspace", "config_model_workspace",
"ConfigKb",
"ConfigKbFactory",
"ConfigKbWorkSpace",
"config_kb_workspace",
"ConfigServer", "ConfigServer",
"ConfigServerFactory", "ConfigServerFactory",
"ConfigServerWorkSpace", "ConfigServerWorkSpace",

View File

@ -1,134 +1,318 @@
import os import os
import json
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
import sys import sys
import logging
from typing import Any, Optional, Dict, Tuple
sys.path.append(str(Path(__file__).parent)) sys.path.append(str(Path(__file__).parent))
import _core_config as core_config
from _basic_config import config_basic_workspace from _basic_config import config_basic_workspace
# 用户数据根目录
DATA_PATH = config_basic_workspace.get_config().DATA_PATH
class ConfigKb(core_config.Config):
    """Knowledge-base configuration values; every field defaults to ``None``."""

    DEFAULT_KNOWLEDGE_BASE: Optional[str] = None
    """Knowledge base used by default."""
    DEFAULT_VS_TYPE: Optional[str] = None
    """Default vector store / full-text search engine type. Options: faiss, milvus (offline) & zilliz (online), pgvector, and the full-text search engine es."""
    CACHED_VS_NUM: Optional[int] = None
    """Number of cached vector stores (applies to FAISS)."""
    CACHED_MEMO_VS_NUM: Optional[int] = None
    """Number of cached temporary vector stores (applies to FAISS), used for file chat."""
    CHUNK_SIZE: Optional[int] = None
    """Length of a single text chunk in the knowledge base (not applicable to MarkdownHeaderTextSplitter)."""
    OVERLAP_SIZE: Optional[int] = None
    """Overlap length between adjacent text chunks (not applicable to MarkdownHeaderTextSplitter)."""
    VECTOR_SEARCH_TOP_K: Optional[int] = None
    """Number of vectors matched per knowledge-base query."""
    SCORE_THRESHOLD: Optional[float] = None
    """Relevance threshold for knowledge-base matching, in the range 0-1. A smaller score means higher relevance; 1 is equivalent to no filtering. Around 0.5 is recommended."""
    DEFAULT_SEARCH_ENGINE: Optional[str] = None
    """Default search engine. Options: bing, duckduckgo, metaphor."""
    SEARCH_ENGINE_TOP_K: Optional[int] = None
    """Number of results matched by the search engine."""
    ZH_TITLE_ENHANCE: Optional[bool] = None
    """Whether Chinese title enhancement is enabled, plus its related configuration."""
    PDF_OCR_THRESHOLD: Optional[Tuple[float, float]] = None
    """
    PDF OCR control: only images whose width/height exceed a given ratio of the page
    (image width / page width, image height / page height) are sent to OCR.
    This avoids interference from small images and speeds up non-scanned PDFs.
    """
    KB_INFO: Optional[Dict[str, str]] = None
    """Introductory description of each knowledge base, shown at initialization and used for Agent calls; a knowledge base without one is not called by the Agent."""
    KB_ROOT_PATH: Optional[str] = None
    """Default storage path for knowledge bases."""
    DB_ROOT_PATH: Optional[str] = None
    """Default database storage path. With sqlite, change DB_ROOT_PATH directly; with another database, change SQLALCHEMY_DATABASE_URI."""
    SQLALCHEMY_DATABASE_URI: Optional[str] = None
    """Database connection URI."""
    kbs_config: Optional[Dict[str, Dict[str, Any]]] = None
    """Available vector store types and their configuration."""
    text_splitter_dict: Optional[Dict[str, Dict[str, Any]]] = None
    """TextSplitter configuration; do not change it unless you understand what it means."""
    TEXT_SPLITTER_NAME: Optional[str] = None
    """Name of the TEXT_SPLITTER."""
    EMBEDDING_KEYWORD_FILE: Optional[str] = None
    """Vocabulary file of custom words for the Embedding model."""

    @classmethod
    def class_name(cls) -> str:
        """Return the name of this class."""
        return cls.__name__

    def __str__(self) -> str:
        """Return ``self.to_json()`` (``to_json`` is not defined in this class; presumably inherited)."""
        return self.to_json()
@dataclass
class ConfigKbFactory(core_config.ConfigFactory[ConfigKb]):
    """Factory that holds the knowledge-base defaults and builds ``ConfigKb`` objects."""

    def __init__(self):
        """Populate every setting with its default value and ensure the knowledge-base directory exists."""
        # Knowledge base used by default.
        self.DEFAULT_KNOWLEDGE_BASE = "samples"

        # Default vector store / full-text search engine type.
        # Options: faiss, milvus (offline) & zilliz (online), pgvector, full-text search engine es.
        self.DEFAULT_VS_TYPE = "faiss"

        # Number of cached vector stores (applies to FAISS).
        self.CACHED_VS_NUM = 1

        # Number of cached temporary vector stores (applies to FAISS), used for file chat.
        self.CACHED_MEMO_VS_NUM = 10

        # Length of a single text chunk (not applicable to MarkdownHeaderTextSplitter).
        self.CHUNK_SIZE = 250

        # Overlap length between adjacent chunks (not applicable to MarkdownHeaderTextSplitter).
        self.OVERLAP_SIZE = 50

        # Number of vectors matched per knowledge-base query.
        self.VECTOR_SEARCH_TOP_K = 3

        # Relevance threshold in the range 0-1. A smaller score means higher relevance;
        # 1 is equivalent to no filtering. Around 0.5 is recommended.
        self.SCORE_THRESHOLD = 1

        # Default search engine. Options: bing, duckduckgo, metaphor.
        self.DEFAULT_SEARCH_ENGINE = "duckduckgo"

        # Number of results matched by the search engine.
        self.SEARCH_ENGINE_TOP_K = 3

        # Whether Chinese title enhancement is enabled, plus its related configuration.
        # Title detection marks which texts are titles in the metadata,
        # then each text is joined with its parent title to enrich the text information.
        self.ZH_TITLE_ENHANCE = False

        # PDF OCR control: only images whose width/height exceed a given ratio of the page
        # (image width / page width, image height / page height) are sent to OCR.
        # This avoids interference from small images and speeds up non-scanned PDFs.
        self.PDF_OCR_THRESHOLD = (0.6, 0.6)

        # Introductory description of each knowledge base, shown at initialization and used
        # for Agent calls; a knowledge base without one is not called by the Agent.
        self.KB_INFO = {
            "samples": "关于本项目issue的解答",
        }

        # The settings below normally do not need to be changed.

        # Default storage path for knowledge bases.
        self.KB_ROOT_PATH = os.path.join(config_basic_workspace.get_config().DATA_PATH, "knowledge_base")
        if not os.path.exists(self.KB_ROOT_PATH):
            # makedirs also creates a missing parent data directory, and exist_ok
            # tolerates another process creating the directory in between.
            os.makedirs(self.KB_ROOT_PATH, exist_ok=True)

        # Default database storage path.
        # With sqlite, change DB_ROOT_PATH directly; with another database,
        # change SQLALCHEMY_DATABASE_URI.
        # NOTE: both values are derived from KB_ROOT_PATH once, here; nothing in this
        # class recomputes them if KB_ROOT_PATH is reassigned afterwards.
        self.DB_ROOT_PATH = os.path.join(self.KB_ROOT_PATH, "info.db")
        self.SQLALCHEMY_DATABASE_URI = f"sqlite:///{self.DB_ROOT_PATH}"

        # Available vector store types and their configuration.
        self.kbs_config = {
            "faiss": {
            },
            "milvus": {
                "host": "127.0.0.1",
                "port": "19530",
                "user": "",
                "password": "",
                "secure": False,
            },
            "zilliz": {
                "host": "in01-a7ce524e41e3935.ali-cn-hangzhou.vectordb.zilliz.com.cn",
                "port": "19530",
                "user": "",
                "password": "",
                "secure": True,
            },
            "pg": {
                "connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
            },
            "es": {
                "host": "127.0.0.1",
                "port": "9200",
                "index_name": "test_index",
                "user": "",
                "password": ""
            },
            "milvus_kwargs": {
                "search_params": {"metric_type": "L2"},  # add search_params here
                "index_params": {"metric_type": "L2", "index_type": "HNSW"}  # add index_params here
            },
            "chromadb": {}
        }

        # TextSplitter configuration; do not change it unless you understand what it means.
        self.text_splitter_dict = {
            "ChineseRecursiveTextSplitter": {
                "source": "huggingface",  # choose tiktoken to use the openai method
                "tokenizer_name_or_path": "",
            },
            "SpacyTextSplitter": {
                "source": "huggingface",
                "tokenizer_name_or_path": "gpt2",
            },
            "RecursiveCharacterTextSplitter": {
                "source": "tiktoken",
                "tokenizer_name_or_path": "cl100k_base",
            },
            "MarkdownHeaderTextSplitter": {
                "headers_to_split_on":
                    [
                        ("#", "head1"),
                        ("##", "head2"),
                        ("###", "head3"),
                        ("####", "head4"),
                    ]
            },
        }

        # Name of the TEXT_SPLITTER.
        self.TEXT_SPLITTER_NAME = "ChineseRecursiveTextSplitter"

        # Vocabulary file of custom words for the Embedding model.
        self.EMBEDDING_KEYWORD_FILE = "embedding_keywords.txt"

    def get_config(self) -> ConfigKb:
        """Build a new ``ConfigKb`` carrying every instance attribute of this factory.

        Values are assigned by reference, so mutable settings (dicts) are shared
        between the factory and the returned config.
        """
        config = ConfigKb()
        for key, value in self.__dict__.items():
            setattr(config, key, value)
        return config
class ConfigKbWorkSpace(core_config.ConfigWorkSpace[ConfigKbFactory, ConfigKb]):
    """
    Workspace configuration preset; provides the builder that produces ``ConfigKb`` instances.
    """
    config_factory_cls = ConfigKbFactory

    def __init__(self):
        super().__init__()

    def _build_config_factory(self, config_json: Any) -> ConfigKbFactory:
        """Create a factory with default values and overlay the overrides found in ``config_json``.

        A key is applied whenever it is present with a non-``None`` value, so
        legitimate falsy overrides such as ``0``, ``False`` or ``{}`` are
        honoured instead of being silently replaced by the default.

        :param config_json: mapping-like object exposing ``get``
            (presumably the stored workspace configuration; verify against the base class).
        :return: the configured ``ConfigKbFactory``.
        """
        _config_factory = self.config_factory_cls()

        overridable_keys = (
            "DEFAULT_KNOWLEDGE_BASE",
            "DEFAULT_VS_TYPE",
            "CACHED_VS_NUM",
            "CACHED_MEMO_VS_NUM",
            "CHUNK_SIZE",
            "OVERLAP_SIZE",
            "VECTOR_SEARCH_TOP_K",
            "SCORE_THRESHOLD",
            "DEFAULT_SEARCH_ENGINE",
            "SEARCH_ENGINE_TOP_K",
            "ZH_TITLE_ENHANCE",
            "PDF_OCR_THRESHOLD",
            "KB_INFO",
            "KB_ROOT_PATH",
            "DB_ROOT_PATH",
            "SQLALCHEMY_DATABASE_URI",
            "TEXT_SPLITTER_NAME",
            "EMBEDDING_KEYWORD_FILE",
        )
        for key in overridable_keys:
            value = config_json.get(key)
            if value is not None:
                setattr(_config_factory, key, value)

        return _config_factory

    @classmethod
    def get_type(cls) -> str:
        """Return the type name of the configuration handled here (``ConfigKb.class_name()``)."""
        return ConfigKb.class_name()

    def get_config(self) -> ConfigKb:
        """Return the ``ConfigKb`` built by the current factory."""
        return self._config_factory.get_config()

    def set_default_knowledge_base(self, kb_name: str):
        """Set ``DEFAULT_KNOWLEDGE_BASE`` on the factory."""
        self._config_factory.DEFAULT_KNOWLEDGE_BASE = kb_name

    def set_default_vs_type(self, vs_type: str):
        """Set ``DEFAULT_VS_TYPE`` on the factory."""
        self._config_factory.DEFAULT_VS_TYPE = vs_type

    def set_cached_vs_num(self, cached_vs_num: int):
        """Set ``CACHED_VS_NUM`` on the factory."""
        self._config_factory.CACHED_VS_NUM = cached_vs_num

    def set_cached_memo_vs_num(self, cached_memo_vs_num: int):
        """Set ``CACHED_MEMO_VS_NUM`` on the factory."""
        self._config_factory.CACHED_MEMO_VS_NUM = cached_memo_vs_num

    def set_chunk_size(self, chunk_size: int):
        """Set ``CHUNK_SIZE`` on the factory."""
        self._config_factory.CHUNK_SIZE = chunk_size

    def set_overlap_size(self, overlap_size: int):
        """Set ``OVERLAP_SIZE`` on the factory."""
        self._config_factory.OVERLAP_SIZE = overlap_size

    def set_vector_search_top_k(self, vector_search_top_k: int):
        """Set ``VECTOR_SEARCH_TOP_K`` on the factory."""
        self._config_factory.VECTOR_SEARCH_TOP_K = vector_search_top_k

    def set_score_threshold(self, score_threshold: float):
        """Set ``SCORE_THRESHOLD`` on the factory."""
        self._config_factory.SCORE_THRESHOLD = score_threshold

    def set_default_search_engine(self, default_search_engine: str):
        """Set ``DEFAULT_SEARCH_ENGINE`` on the factory."""
        self._config_factory.DEFAULT_SEARCH_ENGINE = default_search_engine

    def set_search_engine_top_k(self, search_engine_top_k: int):
        """Set ``SEARCH_ENGINE_TOP_K`` on the factory."""
        self._config_factory.SEARCH_ENGINE_TOP_K = search_engine_top_k

    def set_zh_title_enhance(self, zh_title_enhance: bool):
        """Set ``ZH_TITLE_ENHANCE`` on the factory."""
        self._config_factory.ZH_TITLE_ENHANCE = zh_title_enhance

    def set_pdf_ocr_threshold(self, pdf_ocr_threshold: Tuple[float, float]):
        """Set ``PDF_OCR_THRESHOLD`` on the factory."""
        self._config_factory.PDF_OCR_THRESHOLD = pdf_ocr_threshold

    def set_kb_info(self, kb_info: Dict[str, str]):
        """Set ``KB_INFO`` on the factory."""
        self._config_factory.KB_INFO = kb_info

    def set_kb_root_path(self, kb_root_path: str):
        """Set ``KB_ROOT_PATH`` on the factory (``DB_ROOT_PATH`` is not updated by this call)."""
        self._config_factory.KB_ROOT_PATH = kb_root_path

    def set_db_root_path(self, db_root_path: str):
        """Set ``DB_ROOT_PATH`` on the factory (``SQLALCHEMY_DATABASE_URI`` is not updated by this call)."""
        self._config_factory.DB_ROOT_PATH = db_root_path

    def set_sqlalchemy_database_uri(self, sqlalchemy_database_uri: str):
        """Set ``SQLALCHEMY_DATABASE_URI`` on the factory."""
        self._config_factory.SQLALCHEMY_DATABASE_URI = sqlalchemy_database_uri

    def set_text_splitter_name(self, text_splitter_name: str):
        """Set ``TEXT_SPLITTER_NAME`` on the factory."""
        self._config_factory.TEXT_SPLITTER_NAME = text_splitter_name

    def set_embedding_keyword_file(self, embedding_keyword_file: str):
        """Set ``EMBEDDING_KEYWORD_FILE`` on the factory."""
        self._config_factory.EMBEDDING_KEYWORD_FILE = embedding_keyword_file
# Module-level ConfigKbWorkSpace instance, created when this module is imported.
config_kb_workspace: ConfigKbWorkSpace = ConfigKbWorkSpace()

View File

@ -8,6 +8,8 @@ from chatchat.configs import (
ConfigModel, ConfigModel,
ConfigServerWorkSpace, ConfigServerWorkSpace,
ConfigServer, ConfigServer,
ConfigKbWorkSpace,
ConfigKb,
) )
import os import os
@ -118,3 +120,26 @@ def test_config_server_workspace():
assert config.WEBUI_SERVER_PORT == 8000 assert config.WEBUI_SERVER_PORT == 8000
assert config.API_SERVER_PORT == 8001 assert config.API_SERVER_PORT == 8001
config_server_workspace.clear() config_server_workspace.clear()
def test_server_config():
    """Check that the server settings exported by ``chatchat.configs`` resolve to non-None values."""
    from chatchat.configs import (
        HTTPX_DEFAULT_TIMEOUT, OPEN_CROSS_DOMAIN, DEFAULT_BIND_HOST,
        WEBUI_SERVER, API_SERVER
    )

    resolved_settings = (
        HTTPX_DEFAULT_TIMEOUT,
        OPEN_CROSS_DOMAIN,
        DEFAULT_BIND_HOST,
        WEBUI_SERVER,
        API_SERVER,
    )
    for value in resolved_settings:
        assert value is not None
def test_config_kb_workspace():
config_kb_workspace: ConfigKbWorkSpace = ConfigKbWorkSpace()
assert config_kb_workspace.get_config() is not None
config_kb_workspace.set_default_knowledge_base(kb_name="test")
config_kb_workspace.set_default_vs_type(vs_type="tes")