配置中心知识库信息业务代码

This commit is contained in:
glide-the 2024-06-11 17:37:02 +08:00
parent 35c2f596f6
commit 4f9d63d9f4
3 changed files with 424 additions and 157 deletions

View File

@ -203,164 +203,196 @@ def _import_base_temp_dir() -> Any:
return config_basic_workspace.get_config().BASE_TEMP_DIR
def _import_ConfigKb() -> Any:
    """Dynamically load the ConfigKb class from the _kb_config.py module."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    loader = entry.get("load_mod")
    return loader(entry.get("module"), "ConfigKb")
def _import_ConfigKbFactory() -> Any:
    """Dynamically load the ConfigKbFactory class from the _kb_config.py module."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    loader = entry.get("load_mod")
    return loader(entry.get("module"), "ConfigKbFactory")
def _import_ConfigKbWorkSpace() -> Any:
    """Dynamically load the ConfigKbWorkSpace class from the _kb_config.py module."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    loader = entry.get("load_mod")
    return loader(entry.get("module"), "ConfigKbWorkSpace")
def _import_config_kb_workspace() -> Any:
    """Dynamically load the config_kb_workspace singleton from _kb_config.py."""
    entry = CONFIG_IMPORTS.get("_kb_config.py")
    loader = entry.get("load_mod")
    return loader(entry.get("module"), "config_kb_workspace")
def _import_default_knowledge_base() -> Any:
    """Resolve DEFAULT_KNOWLEDGE_BASE via the kb config workspace.

    Fix: removed the leftover pre-refactor lines (a direct symbol load and an
    unreachable second ``return``); the value now comes from the workspace config.
    """
    kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
    load_mod = kb_config_load.get("load_mod")
    config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace")
    return config_kb_workspace.get_config().DEFAULT_KNOWLEDGE_BASE
def _import_default_vs_type() -> Any:
    """Resolve DEFAULT_VS_TYPE via the kb config workspace.

    Fix: removed dead pre-refactor load and unreachable second ``return``.
    """
    kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
    load_mod = kb_config_load.get("load_mod")
    config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace")
    return config_kb_workspace.get_config().DEFAULT_VS_TYPE
def _import_cached_vs_num() -> Any:
    """Resolve CACHED_VS_NUM via the kb config workspace.

    Fix: removed dead pre-refactor load and unreachable second ``return``.
    """
    kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
    load_mod = kb_config_load.get("load_mod")
    config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace")
    return config_kb_workspace.get_config().CACHED_VS_NUM
def _import_cached_memo_vs_num() -> Any:
    """Resolve CACHED_MEMO_VS_NUM via the kb config workspace.

    Fix: removed dead pre-refactor load and unreachable second ``return``.
    """
    kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
    load_mod = kb_config_load.get("load_mod")
    config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace")
    return config_kb_workspace.get_config().CACHED_MEMO_VS_NUM
def _import_chunk_size() -> Any:
    """Resolve CHUNK_SIZE via the kb config workspace.

    Fix: removed dead pre-refactor load and unreachable second ``return``.
    """
    kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
    load_mod = kb_config_load.get("load_mod")
    config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace")
    return config_kb_workspace.get_config().CHUNK_SIZE
def _import_overlap_size() -> Any:
    """Resolve OVERLAP_SIZE via the kb config workspace.

    Fix: removed dead pre-refactor load and unreachable second ``return``.
    """
    kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
    load_mod = kb_config_load.get("load_mod")
    config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace")
    return config_kb_workspace.get_config().OVERLAP_SIZE
def _import_vector_search_top_k() -> Any:
    """Resolve VECTOR_SEARCH_TOP_K via the kb config workspace.

    Fix: removed dead pre-refactor load and unreachable second ``return``.
    """
    kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
    load_mod = kb_config_load.get("load_mod")
    config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace")
    return config_kb_workspace.get_config().VECTOR_SEARCH_TOP_K
def _import_score_threshold() -> Any:
    """Resolve SCORE_THRESHOLD via the kb config workspace.

    Fix: removed dead pre-refactor load and unreachable second ``return``.
    """
    kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
    load_mod = kb_config_load.get("load_mod")
    config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace")
    return config_kb_workspace.get_config().SCORE_THRESHOLD
def _import_default_search_engine() -> Any:
    """Resolve DEFAULT_SEARCH_ENGINE via the kb config workspace.

    Fix: removed dead pre-refactor load and unreachable second ``return``.
    """
    kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
    load_mod = kb_config_load.get("load_mod")
    config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace")
    return config_kb_workspace.get_config().DEFAULT_SEARCH_ENGINE
def _import_search_engine_top_k() -> Any:
    """Resolve SEARCH_ENGINE_TOP_K via the kb config workspace.

    Fix: removed dead pre-refactor load and unreachable second ``return``.
    """
    kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
    load_mod = kb_config_load.get("load_mod")
    config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace")
    return config_kb_workspace.get_config().SEARCH_ENGINE_TOP_K
def _import_zh_title_enhance() -> Any:
    """Resolve ZH_TITLE_ENHANCE via the kb config workspace.

    Fix: removed dead pre-refactor load and unreachable second ``return``.
    """
    kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
    load_mod = kb_config_load.get("load_mod")
    config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace")
    return config_kb_workspace.get_config().ZH_TITLE_ENHANCE
def _import_pdf_ocr_threshold() -> Any:
    """Resolve PDF_OCR_THRESHOLD via the kb config workspace.

    Fix: removed dead pre-refactor load and unreachable second ``return``.
    """
    kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
    load_mod = kb_config_load.get("load_mod")
    config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace")
    return config_kb_workspace.get_config().PDF_OCR_THRESHOLD
def _import_kb_info() -> Any:
    """Resolve KB_INFO via the kb config workspace.

    Fix: removed dead pre-refactor load and unreachable second ``return``.
    """
    kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
    load_mod = kb_config_load.get("load_mod")
    config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace")
    return config_kb_workspace.get_config().KB_INFO
def _import_kb_root_path() -> Any:
    """Resolve KB_ROOT_PATH via the kb config workspace.

    Fix: removed dead pre-refactor load and unreachable second ``return``.
    """
    kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
    load_mod = kb_config_load.get("load_mod")
    config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace")
    return config_kb_workspace.get_config().KB_ROOT_PATH
def _import_db_root_path() -> Any:
    """Resolve DB_ROOT_PATH via the kb config workspace.

    Fix: removed dead pre-refactor load and unreachable second ``return``.
    """
    kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
    load_mod = kb_config_load.get("load_mod")
    config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace")
    return config_kb_workspace.get_config().DB_ROOT_PATH
def _import_sqlalchemy_database_uri() -> Any:
    """Resolve SQLALCHEMY_DATABASE_URI via the kb config workspace.

    Fix: removed dead pre-refactor load and unreachable second ``return``.
    """
    kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
    load_mod = kb_config_load.get("load_mod")
    config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace")
    return config_kb_workspace.get_config().SQLALCHEMY_DATABASE_URI
def _import_kbs_config() -> Any:
    """Resolve kbs_config via the kb config workspace.

    Fix: removed dead pre-refactor load and unreachable second ``return``.
    """
    kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
    load_mod = kb_config_load.get("load_mod")
    config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace")
    return config_kb_workspace.get_config().kbs_config
def _import_text_splitter_dict() -> Any:
    """Resolve text_splitter_dict via the kb config workspace.

    Fixes: removed dead pre-refactor load and unreachable second ``return``;
    the attribute is declared as ``text_splitter_dict`` on ConfigKb (and set
    as ``self.text_splitter_dict`` in ConfigKbFactory), so reading
    ``TEXT_SPLITTER_DICT`` would always miss it — corrected the attribute name.
    """
    kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
    load_mod = kb_config_load.get("load_mod")
    config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace")
    return config_kb_workspace.get_config().text_splitter_dict
def _import_text_splitter_name() -> Any:
    """Resolve TEXT_SPLITTER_NAME via the kb config workspace.

    Fix: removed dead pre-refactor load and unreachable second ``return``.
    """
    kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
    load_mod = kb_config_load.get("load_mod")
    config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace")
    return config_kb_workspace.get_config().TEXT_SPLITTER_NAME
def _import_embedding_keyword_file() -> Any:
    """Resolve EMBEDDING_KEYWORD_FILE via the kb config workspace.

    Fix: removed dead pre-refactor load and unreachable second ``return``.
    """
    kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
    load_mod = kb_config_load.get("load_mod")
    config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace")
    return config_kb_workspace.get_config().EMBEDDING_KEYWORD_FILE
def _import_ConfigModel() -> Any:
@ -563,6 +595,14 @@ def _import_default_bind_host() -> Any:
return config_server_workspace.get_config().DEFAULT_BIND_HOST
def _import_open_cross_domain() -> Any:
    """Resolve OPEN_CROSS_DOMAIN from the server config workspace."""
    entry = CONFIG_IMPORTS.get("_server_config.py")
    loader = entry.get("load_mod")
    workspace = loader(entry.get("module"), "config_server_workspace")
    return workspace.get_config().OPEN_CROSS_DOMAIN
def _import_webui_server() -> Any:
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
load_mod = server_config_load.get("load_mod")
@ -622,6 +662,15 @@ def __getattr__(name: str) -> Any:
return _import_media_path()
elif name == "BASE_TEMP_DIR":
return _import_base_temp_dir()
elif name == "ConfigKb":
return _import_ConfigKb()
elif name == "ConfigKbFactory":
return _import_ConfigKbFactory()
elif name == "ConfigKbWorkSpace":
return _import_ConfigKbWorkSpace()
elif name == "config_kb_workspace":
return _import_config_kb_workspace()
elif name == "DEFAULT_KNOWLEDGE_BASE":
return _import_default_knowledge_base()
elif name == "DEFAULT_VS_TYPE":
@ -692,8 +741,10 @@ def __getattr__(name: str) -> Any:
return _import_prompt_templates()
elif name == "HTTPX_DEFAULT_TIMEOUT":
return _import_httpx_default_timeout()
elif name == "OPEN_CROSS_DOMAIN":
elif name == "DEFAULT_BIND_HOST":
return _import_default_bind_host()
elif name == "OPEN_CROSS_DOMAIN":
return _import_open_cross_domain()
elif name == "WEBUI_SERVER":
return _import_webui_server()
elif name == "API_SERVER":
@ -748,6 +799,7 @@ __all__ = [
"TOOL_CONFIG",
"PROMPT_TEMPLATES",
"HTTPX_DEFAULT_TIMEOUT",
"DEFAULT_BIND_HOST",
"OPEN_CROSS_DOMAIN",
"WEBUI_SERVER",
"API_SERVER",
@ -764,6 +816,12 @@ __all__ = [
"config_model_workspace",
"ConfigKb",
"ConfigKbFactory",
"ConfigKbWorkSpace",
"config_model_workspace",
"ConfigServer",
"ConfigServerFactory",
"ConfigServerWorkSpace",

View File

@ -1,134 +1,318 @@
import os
import json
from dataclasses import dataclass
from pathlib import Path
import sys
import logging
from typing import Any, Optional, Dict, Tuple

# Make sibling config modules importable when this file is loaded standalone.
sys.path.append(str(Path(__file__).parent))

import _core_config as core_config
from _basic_config import config_basic_workspace

# Fix: dropped the stale module-level DATA_PATH / DEFAULT_KNOWLEDGE_BASE
# constants left over from the pre-workspace design; defaults now live in
# ConfigKbFactory, and DATA_PATH is read via config_basic_workspace there.
class ConfigKb(core_config.Config):
    """Knowledge-base configuration model.

    All fields default to None; concrete values are filled in by
    ConfigKbFactory.get_config(). (Fix: removed the stale module-level
    constant assignments that were fused into the class body.)
    """

    # Default knowledge base name.
    DEFAULT_KNOWLEDGE_BASE: Optional[str] = None
    # Default vector store / full-text engine type: faiss, milvus (offline)
    # & zilliz (online), pgvector, or the full-text engine es.
    DEFAULT_VS_TYPE: Optional[str] = None
    # Number of cached vector stores (FAISS only).
    CACHED_VS_NUM: Optional[int] = None
    # Number of cached temporary vector stores (FAISS, used for file chat).
    CACHED_MEMO_VS_NUM: Optional[int] = None
    # Length of a single text chunk in the KB (not used by MarkdownHeaderTextSplitter).
    CHUNK_SIZE: Optional[int] = None
    # Overlap length between adjacent chunks (not used by MarkdownHeaderTextSplitter).
    OVERLAP_SIZE: Optional[int] = None
    # Number of vectors returned by a KB match.
    VECTOR_SEARCH_TOP_K: Optional[int] = None
    # KB match relevance threshold in [0, 1]; smaller SCORE means more
    # relevant, 1 disables filtering; around 0.5 is recommended.
    SCORE_THRESHOLD: Optional[float] = None
    # Default search engine: bing, duckduckgo or metaphor.
    DEFAULT_SEARCH_ENGINE: Optional[str] = None
    # Number of search-engine results to keep.
    SEARCH_ENGINE_TOP_K: Optional[int] = None
    # Whether Chinese title enhancement (and its related settings) is enabled.
    ZH_TITLE_ENHANCE: Optional[bool] = None
    # PDF OCR thresholds (image width / page width, image height / page
    # height); only images exceeding these ratios are OCRed, which skips
    # small decorative images and speeds up non-scanned PDFs.
    PDF_OCR_THRESHOLD: Optional[Tuple[float, float]] = None
    # Intro text per knowledge base, shown at KB init and used for agent
    # calls; a KB without an entry has no intro and is not agent-callable.
    KB_INFO: Optional[Dict[str, str]] = None
    # Default storage path for knowledge bases.
    KB_ROOT_PATH: Optional[str] = None
    # Default database path; for sqlite edit DB_ROOT_PATH directly, for
    # other databases edit SQLALCHEMY_DATABASE_URI instead.
    DB_ROOT_PATH: Optional[str] = None
    # Database connection URI.
    SQLALCHEMY_DATABASE_URI: Optional[str] = None
    # Available vector-store types and their per-store settings.
    kbs_config: Optional[Dict[str, Dict[str, Any]]] = None
    # TextSplitter settings; do not modify unless you understand them.
    text_splitter_dict: Optional[Dict[str, Dict[str, Any]]] = None
    # Name of the TextSplitter to use.
    TEXT_SPLITTER_NAME: Optional[str] = None
    # Vocabulary file for embedding-model keyword customization.
    EMBEDDING_KEYWORD_FILE: Optional[str] = None

    @classmethod
    def class_name(cls) -> str:
        """Return the class name used as the workspace type key."""
        return cls.__name__

    def __str__(self):
        return self.to_json()
@dataclass
class ConfigKbFactory(core_config.ConfigFactory[ConfigKb]):
    """ConfigKb configuration factory.

    Holds the default KB settings as instance attributes and materializes a
    ConfigKb from them. (Fix: removed the stale module-level constant blocks
    that the diff had interleaved with the ``__init__`` body.)
    """

    def __init__(self):
        # Default knowledge base name.
        self.DEFAULT_KNOWLEDGE_BASE = "samples"
        # Default vector store / full-text engine: faiss, milvus (offline) &
        # zilliz (online), pgvector, or the full-text engine es.
        self.DEFAULT_VS_TYPE = "faiss"
        # Number of cached vector stores (FAISS only).
        self.CACHED_VS_NUM = 1
        # Number of cached temporary vector stores (FAISS, for file chat).
        self.CACHED_MEMO_VS_NUM = 10
        # Single chunk length (not used by MarkdownHeaderTextSplitter).
        self.CHUNK_SIZE = 250
        # Overlap between adjacent chunks (not used by MarkdownHeaderTextSplitter).
        self.OVERLAP_SIZE = 50
        # Number of vectors returned by a KB match.
        self.VECTOR_SEARCH_TOP_K = 3
        # Relevance threshold in [0, 1]; smaller is more relevant, 1 disables
        # filtering; around 0.5 is recommended.
        self.SCORE_THRESHOLD = 1
        # Default search engine: bing, duckduckgo or metaphor.
        self.DEFAULT_SEARCH_ENGINE = "duckduckgo"
        # Number of search-engine results to keep.
        self.SEARCH_ENGINE_TOP_K = 3
        # Chinese title enhancement: detect titles, mark them in metadata,
        # then merge text with its parent title to enrich the content.
        self.ZH_TITLE_ENHANCE = False
        # PDF OCR thresholds (image width / page width, image height / page
        # height); only larger images are OCRed, skipping small images and
        # speeding up non-scanned PDFs.
        self.PDF_OCR_THRESHOLD = (0.6, 0.6)
        # Intro text per knowledge base, shown at init and used by agents;
        # a KB without an entry has no intro and is not agent-callable.
        self.KB_INFO = {
            "samples": "关于本项目issue的解答",
        }
        # The settings below normally do not need to be changed.
        # Default storage path for knowledge bases.
        self.KB_ROOT_PATH = os.path.join(config_basic_workspace.get_config().DATA_PATH, "knowledge_base")
        if not os.path.exists(self.KB_ROOT_PATH):
            os.mkdir(self.KB_ROOT_PATH)
        # Default database path; for sqlite edit DB_ROOT_PATH, for other
        # databases edit SQLALCHEMY_DATABASE_URI instead.
        self.DB_ROOT_PATH = os.path.join(self.KB_ROOT_PATH, "info.db")
        self.SQLALCHEMY_DATABASE_URI = f"sqlite:///{self.DB_ROOT_PATH}"
        # Available vector-store types and their settings.
        self.kbs_config = {
            "faiss": {
            },
            "milvus": {
                "host": "127.0.0.1",
                "port": "19530",
                "user": "",
                "password": "",
                "secure": False,
            },
            "zilliz": {
                "host": "in01-a7ce524e41e3935.ali-cn-hangzhou.vectordb.zilliz.com.cn",
                "port": "19530",
                "user": "",
                "password": "",
                "secure": True,
            },
            "pg": {
                "connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
            },
            "es": {
                "host": "127.0.0.1",
                "port": "9200",
                "index_name": "test_index",
                "user": "",
                "password": ""
            },
            "milvus_kwargs": {
                "search_params": {"metric_type": "L2"},  # extra search_params go here
                "index_params": {"metric_type": "L2", "index_type": "HNSW"}  # extra index_params go here
            },
            "chromadb": {}
        }
        # TextSplitter settings; do not modify unless you understand them.
        self.text_splitter_dict = {
            "ChineseRecursiveTextSplitter": {
                "source": "huggingface",  # "tiktoken" would use the OpenAI tokenizer instead
                "tokenizer_name_or_path": "",
            },
            "SpacyTextSplitter": {
                "source": "huggingface",
                "tokenizer_name_or_path": "gpt2",
            },
            "RecursiveCharacterTextSplitter": {
                "source": "tiktoken",
                "tokenizer_name_or_path": "cl100k_base",
            },
            "MarkdownHeaderTextSplitter": {
                "headers_to_split_on":
                    [
                        ("#", "head1"),
                        ("##", "head2"),
                        ("###", "head3"),
                        ("####", "head4"),
                    ]
            },
        }
        # Name of the TextSplitter to use.
        self.TEXT_SPLITTER_NAME = "ChineseRecursiveTextSplitter"
        # Vocabulary file for embedding-model keyword customization.
        self.EMBEDDING_KEYWORD_FILE = "embedding_keywords.txt"

    def get_config(self) -> ConfigKb:
        """Build a ConfigKb populated from this factory's attributes."""
        config = ConfigKb()
        for key, value in self.__dict__.items():
            setattr(config, key, value)
        return config
class ConfigKbWorkSpace(core_config.ConfigWorkSpace[ConfigKbFactory, ConfigKb]):
    """Workspace preset for the KB configuration; builds ConfigKb instances."""

    config_factory_cls = ConfigKbFactory

    # Keys restored from persisted JSON into the factory (order mirrors the
    # factory's own attribute order).
    _JSON_KEYS = (
        "DEFAULT_KNOWLEDGE_BASE",
        "DEFAULT_VS_TYPE",
        "CACHED_VS_NUM",
        "CACHED_MEMO_VS_NUM",
        "CHUNK_SIZE",
        "OVERLAP_SIZE",
        "VECTOR_SEARCH_TOP_K",
        "SCORE_THRESHOLD",
        "DEFAULT_SEARCH_ENGINE",
        "SEARCH_ENGINE_TOP_K",
        "ZH_TITLE_ENHANCE",
        "PDF_OCR_THRESHOLD",
        "KB_INFO",
        "KB_ROOT_PATH",
        "DB_ROOT_PATH",
        "SQLALCHEMY_DATABASE_URI",
        "TEXT_SPLITTER_NAME",
        "EMBEDDING_KEYWORD_FILE",
    )

    def __init__(self):
        super().__init__()

    def _build_config_factory(self, config_json: Any) -> ConfigKbFactory:
        """Create a factory and overlay any truthy values found in config_json."""
        factory = self.config_factory_cls()
        for key in self._JSON_KEYS:
            value = config_json.get(key)
            # Same semantics as a chain of `if config_json.get(key):` checks:
            # only truthy values override the factory defaults.
            if value:
                setattr(factory, key, value)
        return factory

    @classmethod
    def get_type(cls) -> str:
        """Workspace type key: the ConfigKb class name."""
        return ConfigKb.class_name()

    def get_config(self) -> ConfigKb:
        """Return the ConfigKb built by the current factory."""
        return self._config_factory.get_config()

    # --- individual setters, one per configurable field ---

    def set_default_knowledge_base(self, kb_name: str):
        self._config_factory.DEFAULT_KNOWLEDGE_BASE = kb_name

    def set_default_vs_type(self, vs_type: str):
        self._config_factory.DEFAULT_VS_TYPE = vs_type

    def set_cached_vs_num(self, cached_vs_num: int):
        self._config_factory.CACHED_VS_NUM = cached_vs_num

    def set_cached_memo_vs_num(self, cached_memo_vs_num: int):
        self._config_factory.CACHED_MEMO_VS_NUM = cached_memo_vs_num

    def set_chunk_size(self, chunk_size: int):
        self._config_factory.CHUNK_SIZE = chunk_size

    def set_overlap_size(self, overlap_size: int):
        self._config_factory.OVERLAP_SIZE = overlap_size

    def set_vector_search_top_k(self, vector_search_top_k: int):
        self._config_factory.VECTOR_SEARCH_TOP_K = vector_search_top_k

    def set_score_threshold(self, score_threshold: float):
        self._config_factory.SCORE_THRESHOLD = score_threshold

    def set_default_search_engine(self, default_search_engine: str):
        self._config_factory.DEFAULT_SEARCH_ENGINE = default_search_engine

    def set_search_engine_top_k(self, search_engine_top_k: int):
        self._config_factory.SEARCH_ENGINE_TOP_K = search_engine_top_k

    def set_zh_title_enhance(self, zh_title_enhance: bool):
        self._config_factory.ZH_TITLE_ENHANCE = zh_title_enhance

    def set_pdf_ocr_threshold(self, pdf_ocr_threshold: Tuple[float, float]):
        self._config_factory.PDF_OCR_THRESHOLD = pdf_ocr_threshold

    def set_kb_info(self, kb_info: Dict[str, str]):
        self._config_factory.KB_INFO = kb_info

    def set_kb_root_path(self, kb_root_path: str):
        self._config_factory.KB_ROOT_PATH = kb_root_path

    def set_db_root_path(self, db_root_path: str):
        self._config_factory.DB_ROOT_PATH = db_root_path

    def set_sqlalchemy_database_uri(self, sqlalchemy_database_uri: str):
        self._config_factory.SQLALCHEMY_DATABASE_URI = sqlalchemy_database_uri

    def set_text_splitter_name(self, text_splitter_name: str):
        self._config_factory.TEXT_SPLITTER_NAME = text_splitter_name

    def set_embedding_keyword_file(self, embedding_keyword_file: str):
        self._config_factory.EMBEDDING_KEYWORD_FILE = embedding_keyword_file


config_kb_workspace: ConfigKbWorkSpace = ConfigKbWorkSpace()

View File

@ -8,6 +8,8 @@ from chatchat.configs import (
ConfigModel,
ConfigServerWorkSpace,
ConfigServer,
ConfigKbWorkSpace,
ConfigKb,
)
import os
@ -117,4 +119,27 @@ def test_config_server_workspace():
assert config.DEFAULT_BIND_HOST == "0.0.0.0"
assert config.WEBUI_SERVER_PORT == 8000
assert config.API_SERVER_PORT == 8001
config_server_workspace.clear()
config_server_workspace.clear()
def test_server_config():
    """Smoke test: the lazily-exported server config attributes all resolve."""
    from chatchat.configs import (
        HTTPX_DEFAULT_TIMEOUT, OPEN_CROSS_DOMAIN, DEFAULT_BIND_HOST,
        WEBUI_SERVER, API_SERVER
    )

    # Each attribute must have been populated by the config workspace.
    for attr in (HTTPX_DEFAULT_TIMEOUT, OPEN_CROSS_DOMAIN, DEFAULT_BIND_HOST,
                 WEBUI_SERVER, API_SERVER):
        assert attr is not None
def test_config_kb_workspace():
config_kb_workspace: ConfigKbWorkSpace = ConfigKbWorkSpace()
assert config_kb_workspace.get_config() is not None
config_kb_workspace.set_default_knowledge_base(kb_name="test")
config_kb_workspace.set_default_vs_type(vs_type="tes")