diff --git a/libs/chatchat-server/chatchat/configs/__init__.py b/libs/chatchat-server/chatchat/configs/__init__.py index c20d2e39..6a68fb03 100644 --- a/libs/chatchat-server/chatchat/configs/__init__.py +++ b/libs/chatchat-server/chatchat/configs/__init__.py @@ -203,164 +203,196 @@ def _import_base_temp_dir() -> Any: return config_basic_workspace.get_config().BASE_TEMP_DIR +def _import_ConfigKb() -> Any: + basic_config_load = CONFIG_IMPORTS.get("_kb_config.py") + load_mod = basic_config_load.get("load_mod") + ConfigKb = load_mod(basic_config_load.get("module"), "ConfigKb") + + return ConfigKb + + +def _import_ConfigKbFactory() -> Any: + basic_config_load = CONFIG_IMPORTS.get("_kb_config.py") + load_mod = basic_config_load.get("load_mod") + ConfigKbFactory = load_mod(basic_config_load.get("module"), "ConfigKbFactory") + + return ConfigKbFactory + + +def _import_ConfigKbWorkSpace() -> Any: + basic_config_load = CONFIG_IMPORTS.get("_kb_config.py") + load_mod = basic_config_load.get("load_mod") + ConfigKbWorkSpace = load_mod(basic_config_load.get("module"), "ConfigKbWorkSpace") + + return ConfigKbWorkSpace + + +def _import_config_kb_workspace() -> Any: + kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") + load_mod = kb_config_load.get("load_mod") + config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace") + + return config_kb_workspace + + def _import_default_knowledge_base() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") - DEFAULT_KNOWLEDGE_BASE = load_mod(kb_config_load.get("module"), "DEFAULT_KNOWLEDGE_BASE") + config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace") - return DEFAULT_KNOWLEDGE_BASE + return config_kb_workspace.get_config().DEFAULT_KNOWLEDGE_BASE def _import_default_vs_type() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") - DEFAULT_VS_TYPE = load_mod(kb_config_load.get("module"), 
"DEFAULT_VS_TYPE") + config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace") - return DEFAULT_VS_TYPE + return config_kb_workspace.get_config().DEFAULT_VS_TYPE def _import_cached_vs_num() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") - CACHED_VS_NUM = load_mod(kb_config_load.get("module"), "CACHED_VS_NUM") + config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace") - return CACHED_VS_NUM + return config_kb_workspace.get_config().CACHED_VS_NUM def _import_cached_memo_vs_num() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") - CACHED_MEMO_VS_NUM = load_mod(kb_config_load.get("module"), "CACHED_MEMO_VS_NUM") + config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace") - return CACHED_MEMO_VS_NUM + return config_kb_workspace.get_config().CACHED_MEMO_VS_NUM def _import_chunk_size() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") - CHUNK_SIZE = load_mod(kb_config_load.get("module"), "CHUNK_SIZE") + config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace") - return CHUNK_SIZE + return config_kb_workspace.get_config().CHUNK_SIZE def _import_overlap_size() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") - OVERLAP_SIZE = load_mod(kb_config_load.get("module"), "OVERLAP_SIZE") + config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace") - return OVERLAP_SIZE + return config_kb_workspace.get_config().OVERLAP_SIZE def _import_vector_search_top_k() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") - VECTOR_SEARCH_TOP_K = load_mod(kb_config_load.get("module"), "VECTOR_SEARCH_TOP_K") + config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace") - return 
VECTOR_SEARCH_TOP_K + return config_kb_workspace.get_config().VECTOR_SEARCH_TOP_K def _import_score_threshold() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") - SCORE_THRESHOLD = load_mod(kb_config_load.get("module"), "SCORE_THRESHOLD") + config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace") - return SCORE_THRESHOLD + return config_kb_workspace.get_config().SCORE_THRESHOLD def _import_default_search_engine() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") - DEFAULT_SEARCH_ENGINE = load_mod(kb_config_load.get("module"), "DEFAULT_SEARCH_ENGINE") + config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace") - return DEFAULT_SEARCH_ENGINE + return config_kb_workspace.get_config().DEFAULT_SEARCH_ENGINE def _import_search_engine_top_k() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") - SEARCH_ENGINE_TOP_K = load_mod(kb_config_load.get("module"), "SEARCH_ENGINE_TOP_K") + config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace") - return SEARCH_ENGINE_TOP_K + return config_kb_workspace.get_config().SEARCH_ENGINE_TOP_K def _import_zh_title_enhance() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") - ZH_TITLE_ENHANCE = load_mod(kb_config_load.get("module"), "ZH_TITLE_ENHANCE") + config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace") - return ZH_TITLE_ENHANCE + return config_kb_workspace.get_config().ZH_TITLE_ENHANCE def _import_pdf_ocr_threshold() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") - PDF_OCR_THRESHOLD = load_mod(kb_config_load.get("module"), "PDF_OCR_THRESHOLD") + config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace") - return PDF_OCR_THRESHOLD + return 
config_kb_workspace.get_config().PDF_OCR_THRESHOLD def _import_kb_info() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") - KB_INFO = load_mod(kb_config_load.get("module"), "KB_INFO") + config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace") - return KB_INFO + return config_kb_workspace.get_config().KB_INFO def _import_kb_root_path() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") - KB_ROOT_PATH = load_mod(kb_config_load.get("module"), "KB_ROOT_PATH") + config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace") - return KB_ROOT_PATH + return config_kb_workspace.get_config().KB_ROOT_PATH def _import_db_root_path() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") - DB_ROOT_PATH = load_mod(kb_config_load.get("module"), "DB_ROOT_PATH") + config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace") - return DB_ROOT_PATH + return config_kb_workspace.get_config().DB_ROOT_PATH def _import_sqlalchemy_database_uri() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") - SQLALCHEMY_DATABASE_URI = load_mod(kb_config_load.get("module"), "SQLALCHEMY_DATABASE_URI") + config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace") - return SQLALCHEMY_DATABASE_URI + return config_kb_workspace.get_config().SQLALCHEMY_DATABASE_URI def _import_kbs_config() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") - kbs_config = load_mod(kb_config_load.get("module"), "kbs_config") + config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace") - return kbs_config + return config_kb_workspace.get_config().kbs_config def _import_text_splitter_dict() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod 
= kb_config_load.get("load_mod") - text_splitter_dict = load_mod(kb_config_load.get("module"), "text_splitter_dict") + config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace") - return text_splitter_dict + return config_kb_workspace.get_config().text_splitter_dict  # ConfigKb declares this field in lowercase, like kbs_config def _import_text_splitter_name() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") - TEXT_SPLITTER_NAME = load_mod(kb_config_load.get("module"), "TEXT_SPLITTER_NAME") + config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace") - return TEXT_SPLITTER_NAME + return config_kb_workspace.get_config().TEXT_SPLITTER_NAME def _import_embedding_keyword_file() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") - EMBEDDING_KEYWORD_FILE = load_mod(kb_config_load.get("module"), "EMBEDDING_KEYWORD_FILE") + config_kb_workspace = load_mod(kb_config_load.get("module"), "config_kb_workspace") - return EMBEDDING_KEYWORD_FILE + return config_kb_workspace.get_config().EMBEDDING_KEYWORD_FILE def _import_ConfigModel() -> Any: @@ -563,6 +595,14 @@ def _import_default_bind_host() -> Any: return config_server_workspace.get_config().DEFAULT_BIND_HOST +def _import_open_cross_domain() -> Any: + server_config_load = CONFIG_IMPORTS.get("_server_config.py") + load_mod = server_config_load.get("load_mod") + config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace") + + return config_server_workspace.get_config().OPEN_CROSS_DOMAIN + + def _import_webui_server() -> Any: server_config_load = CONFIG_IMPORTS.get("_server_config.py") load_mod = server_config_load.get("load_mod") @@ -622,6 +662,15 @@ def __getattr__(name: str) -> Any: return _import_media_path() elif name == "BASE_TEMP_DIR": return _import_base_temp_dir() + + elif name == "ConfigKb": + return _import_ConfigKb() + elif name == "ConfigKbFactory": + return _import_ConfigKbFactory() + elif 
name == "ConfigKbWorkSpace": + return _import_ConfigKbWorkSpace() + elif name == "config_kb_workspace": + return _import_config_kb_workspace() elif name == "DEFAULT_KNOWLEDGE_BASE": return _import_default_knowledge_base() elif name == "DEFAULT_VS_TYPE": @@ -692,8 +741,10 @@ def __getattr__(name: str) -> Any: return _import_prompt_templates() elif name == "HTTPX_DEFAULT_TIMEOUT": return _import_httpx_default_timeout() - elif name == "OPEN_CROSS_DOMAIN": + elif name == "DEFAULT_BIND_HOST": return _import_default_bind_host() + elif name == "OPEN_CROSS_DOMAIN": + return _import_open_cross_domain() elif name == "WEBUI_SERVER": return _import_webui_server() elif name == "API_SERVER": @@ -748,6 +799,7 @@ __all__ = [ "TOOL_CONFIG", "PROMPT_TEMPLATES", "HTTPX_DEFAULT_TIMEOUT", + "DEFAULT_BIND_HOST", "OPEN_CROSS_DOMAIN", "WEBUI_SERVER", "API_SERVER", @@ -764,6 +816,12 @@ __all__ = [ "config_model_workspace", + "ConfigKb", + "ConfigKbFactory", + "ConfigKbWorkSpace", + + "config_kb_workspace", + "ConfigServer", "ConfigServerFactory", "ConfigServerWorkSpace", diff --git a/libs/chatchat-server/chatchat/configs/_kb_config.py b/libs/chatchat-server/chatchat/configs/_kb_config.py index a13eb87b..3990f3e3 100644 --- a/libs/chatchat-server/chatchat/configs/_kb_config.py +++ b/libs/chatchat-server/chatchat/configs/_kb_config.py @@ -1,134 +1,318 @@ import os +import json +from dataclasses import dataclass from pathlib import Path - import sys +import logging +from typing import Any, Optional, Dict, Tuple + sys.path.append(str(Path(__file__).parent)) +import _core_config as core_config from _basic_config import config_basic_workspace -# 用户数据根目录 -DATA_PATH = config_basic_workspace.get_config().DATA_PATH -# 默认使用的知识库 -DEFAULT_KNOWLEDGE_BASE = "samples" +class ConfigKb(core_config.Config): + DEFAULT_KNOWLEDGE_BASE: Optional[str] = None + """默认使用的知识库""" + DEFAULT_VS_TYPE: Optional[str] = None + """默认向量库/全文检索引擎类型。可选:faiss, milvus(离线) & zilliz(在线), pgvector,全文检索引擎es""" + CACHED_VS_NUM: 
Optional[int] = None + """缓存向量库数量(针对FAISS)""" + CACHED_MEMO_VS_NUM: Optional[int] = None + """缓存临时向量库数量(针对FAISS),用于文件对话""" + CHUNK_SIZE: Optional[int] = None + """知识库中单段文本长度(不适用MarkdownHeaderTextSplitter)""" + OVERLAP_SIZE: Optional[int] = None + """知识库中相邻文本重合长度(不适用MarkdownHeaderTextSplitter)""" + VECTOR_SEARCH_TOP_K: Optional[int] = None + """知识库匹配向量数量""" + SCORE_THRESHOLD: Optional[float] = None + """知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右""" + DEFAULT_SEARCH_ENGINE: Optional[str] = None + """默认搜索引擎。可选:bing, duckduckgo, metaphor""" + SEARCH_ENGINE_TOP_K: Optional[int] = None + """搜索引擎匹配结题数量""" + ZH_TITLE_ENHANCE: Optional[bool] = None + """是否开启中文标题加强,以及标题增强的相关配置""" + PDF_OCR_THRESHOLD: Optional[Tuple[float, float]] = None + """ + PDF OCR 控制:只对宽高超过页面一定比例(图片宽/页面宽,图片高/页面高)的图片进行 OCR。 + 这样可以避免 PDF 中一些小图片的干扰,提高非扫描版 PDF 处理速度 + """ + KB_INFO: Optional[Dict[str, str]] = None + """每个知识库的初始化介绍,用于在初始化知识库时显示和Agent调用,没写则没有介绍,不会被Agent调用。""" + KB_ROOT_PATH: Optional[str] = None + """知识库默认存储路径""" + DB_ROOT_PATH: Optional[str] = None + """数据库默认存储路径。如果使用sqlite,可以直接修改DB_ROOT_PATH;如果使用其它数据库,请直接修改SQLALCHEMY_DATABASE_URI。""" + SQLALCHEMY_DATABASE_URI: Optional[str] = None + """数据库连接URI""" + kbs_config: Optional[Dict[str, Dict[str, Any]]] = None + """可选向量库类型及对应配置""" + text_splitter_dict: Optional[Dict[str, Dict[str, Any]]] = None + """TextSplitter配置项,如果你不明白其中的含义,就不要修改。""" + TEXT_SPLITTER_NAME: Optional[str] = None + """TEXT_SPLITTER 名称""" + EMBEDDING_KEYWORD_FILE: Optional[str] = None + """Embedding模型定制词语的词表文件""" -# 默认向量库/全文检索引擎类型。可选:faiss, milvus(离线) & zilliz(在线), pgvector,全文检索引擎es -DEFAULT_VS_TYPE = "faiss" + @classmethod + def class_name(cls) -> str: + return cls.__name__ -# 缓存向量库数量(针对FAISS) -CACHED_VS_NUM = 1 - -# 缓存临时向量库数量(针对FAISS),用于文件对话 -CACHED_MEMO_VS_NUM = 10 - -# 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter) -CHUNK_SIZE = 250 - -# 知识库中相邻文本重合长度(不适用MarkdownHeaderTextSplitter) -OVERLAP_SIZE = 50 - -# 知识库匹配向量数量 -VECTOR_SEARCH_TOP_K = 3 - -# 
知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右 -SCORE_THRESHOLD = 1 - -# 默认搜索引擎。可选:bing, duckduckgo, metaphor -DEFAULT_SEARCH_ENGINE = "duckduckgo" - -# 搜索引擎匹配结题数量 -SEARCH_ENGINE_TOP_K = 3 - -# 是否开启中文标题加强,以及标题增强的相关配置 -# 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记; -# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。 -ZH_TITLE_ENHANCE = False - -# PDF OCR 控制:只对宽高超过页面一定比例(图片宽/页面宽,图片高/页面高)的图片进行 OCR。 -# 这样可以避免 PDF 中一些小图片的干扰,提高非扫描版 PDF 处理速度 -PDF_OCR_THRESHOLD = (0.6, 0.6) - -# 每个知识库的初始化介绍,用于在初始化知识库时显示和Agent调用,没写则没有介绍,不会被Agent调用。 -KB_INFO = { - "samples": "关于本项目issue的解答", -} + def __str__(self): + return self.to_json() -# 通常情况下不需要更改以下内容 +@dataclass +class ConfigKbFactory(core_config.ConfigFactory[ConfigKb]): + """ConfigKb 配置工厂类""" -# 知识库默认存储路径 -KB_ROOT_PATH = os.path.join(DATA_PATH, "knowledge_base") -if not os.path.exists(KB_ROOT_PATH): - os.mkdir(KB_ROOT_PATH) + def __init__(self): + # 默认使用的知识库 + self.DEFAULT_KNOWLEDGE_BASE = "samples" -# 数据库默认存储路径。 -# 如果使用sqlite,可以直接修改DB_ROOT_PATH;如果使用其它数据库,请直接修改SQLALCHEMY_DATABASE_URI。 -DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db") -SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}" + # 默认向量库/全文检索引擎类型。可选:faiss, milvus(离线) & zilliz(在线), pgvector,全文检索引擎es + self.DEFAULT_VS_TYPE = "faiss" -# 可选向量库类型及对应配置 -kbs_config = { - "faiss": { - }, - "milvus": { - "host": "127.0.0.1", - "port": "19530", - "user": "", - "password": "", - "secure": False, - }, - "zilliz": { - "host": "in01-a7ce524e41e3935.ali-cn-hangzhou.vectordb.zilliz.com.cn", - "port": "19530", - "user": "", - "password": "", - "secure": True, - }, - "pg": { - "connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat", - }, + # 缓存向量库数量(针对FAISS) + self.CACHED_VS_NUM = 1 - "es": { - "host": "127.0.0.1", - "port": "9200", - "index_name": "test_index", - "user": "", - "password": "" - }, - "milvus_kwargs":{ - "search_params":{"metric_type": "L2"}, #在此处增加search_params - "index_params":{"metric_type": "L2","index_type": "HNSW"} # 在此处增加index_params - }, - "chromadb": {} 
-} + # 缓存临时向量库数量(针对FAISS),用于文件对话 + self.CACHED_MEMO_VS_NUM = 10 -# TextSplitter配置项,如果你不明白其中的含义,就不要修改。 -text_splitter_dict = { - "ChineseRecursiveTextSplitter": { - "source": "huggingface", # 选择tiktoken则使用openai的方法 - "tokenizer_name_or_path": "", - }, - "SpacyTextSplitter": { - "source": "huggingface", - "tokenizer_name_or_path": "gpt2", - }, - "RecursiveCharacterTextSplitter": { - "source": "tiktoken", - "tokenizer_name_or_path": "cl100k_base", - }, - "MarkdownHeaderTextSplitter": { - "headers_to_split_on": - [ - ("#", "head1"), - ("##", "head2"), - ("###", "head3"), - ("####", "head4"), - ] - }, -} + # 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter) + self.CHUNK_SIZE = 250 -# TEXT_SPLITTER 名称 -TEXT_SPLITTER_NAME = "ChineseRecursiveTextSplitter" + # 知识库中相邻文本重合长度(不适用MarkdownHeaderTextSplitter) + self.OVERLAP_SIZE = 50 -# Embedding模型定制词语的词表文件 -EMBEDDING_KEYWORD_FILE = "embedding_keywords.txt" + # 知识库匹配向量数量 + self.VECTOR_SEARCH_TOP_K = 3 + + # 知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右 + self.SCORE_THRESHOLD = 1 + + # 默认搜索引擎。可选:bing, duckduckgo, metaphor + self.DEFAULT_SEARCH_ENGINE = "duckduckgo" + + # 搜索引擎匹配结题数量 + self.SEARCH_ENGINE_TOP_K = 3 + + # 是否开启中文标题加强,以及标题增强的相关配置 + # 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记; + # 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。 + self.ZH_TITLE_ENHANCE = False + + # PDF OCR 控制:只对宽高超过页面一定比例(图片宽/页面宽,图片高/页面高)的图片进行 OCR。 + # 这样可以避免 PDF 中一些小图片的干扰,提高非扫描版 PDF 处理速度 + self.PDF_OCR_THRESHOLD = (0.6, 0.6) + + # 每个知识库的初始化介绍,用于在初始化知识库时显示和Agent调用,没写则没有介绍,不会被Agent调用。 + self.KB_INFO = { + "samples": "关于本项目issue的解答", + } + + # 通常情况下不需要更改以下内容 + + # 知识库默认存储路径 + self.KB_ROOT_PATH = os.path.join(config_basic_workspace.get_config().DATA_PATH, "knowledge_base") + if not os.path.exists(self.KB_ROOT_PATH): + os.makedirs(self.KB_ROOT_PATH, exist_ok=True)  # creates missing parents too; exist_ok covers a race after the check + + # 数据库默认存储路径。 + # 如果使用sqlite,可以直接修改DB_ROOT_PATH;如果使用其它数据库,请直接修改SQLALCHEMY_DATABASE_URI。 + self.DB_ROOT_PATH = os.path.join(self.KB_ROOT_PATH, "info.db") + self.SQLALCHEMY_DATABASE_URI = f"sqlite:///{self.DB_ROOT_PATH}" + + # 
可选向量库类型及对应配置 + self.kbs_config = { + "faiss": { + }, + "milvus": { + "host": "127.0.0.1", + "port": "19530", + "user": "", + "password": "", + "secure": False, + }, + "zilliz": { + "host": "in01-a7ce524e41e3935.ali-cn-hangzhou.vectordb.zilliz.com.cn", + "port": "19530", + "user": "", + "password": "", + "secure": True, + }, + "pg": { + "connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat", + }, + + "es": { + "host": "127.0.0.1", + "port": "9200", + "index_name": "test_index", + "user": "", + "password": "" + }, + "milvus_kwargs": { + "search_params": {"metric_type": "L2"}, #在此处增加search_params + "index_params": {"metric_type": "L2", "index_type": "HNSW"} # 在此处增加index_params + }, + "chromadb": {} + } + + # TextSplitter配置项,如果你不明白其中的含义,就不要修改。 + self.text_splitter_dict = { + "ChineseRecursiveTextSplitter": { + "source": "huggingface", # 选择tiktoken则使用openai的方法 + "tokenizer_name_or_path": "", + }, + "SpacyTextSplitter": { + "source": "huggingface", + "tokenizer_name_or_path": "gpt2", + }, + "RecursiveCharacterTextSplitter": { + "source": "tiktoken", + "tokenizer_name_or_path": "cl100k_base", + }, + "MarkdownHeaderTextSplitter": { + "headers_to_split_on": + [ + ("#", "head1"), + ("##", "head2"), + ("###", "head3"), + ("####", "head4"), + ] + }, + } + + # TEXT_SPLITTER 名称 + self.TEXT_SPLITTER_NAME = "ChineseRecursiveTextSplitter" + + # Embedding模型定制词语的词表文件 + self.EMBEDDING_KEYWORD_FILE = "embedding_keywords.txt" + + def get_config(self) -> ConfigKb: + config = ConfigKb() + for key, value in self.__dict__.items(): + setattr(config, key, value) + + return config + + +class ConfigKbWorkSpace(core_config.ConfigWorkSpace[ConfigKbFactory, ConfigKb]): + """ + 工作空间的配置预设,提供ConfigKb建造方法产生实例。 + """ + config_factory_cls = ConfigKbFactory + + def __init__(self): + super().__init__() + + def _build_config_factory(self, config_json: Any) -> ConfigKbFactory: + _config_factory = self.config_factory_cls() + if config_json.get("DEFAULT_KNOWLEDGE_BASE"): + 
_config_factory.DEFAULT_KNOWLEDGE_BASE = config_json.get("DEFAULT_KNOWLEDGE_BASE") + if config_json.get("DEFAULT_VS_TYPE"): + _config_factory.DEFAULT_VS_TYPE = config_json.get("DEFAULT_VS_TYPE") + if config_json.get("CACHED_VS_NUM"): + _config_factory.CACHED_VS_NUM = config_json.get("CACHED_VS_NUM") + if config_json.get("CACHED_MEMO_VS_NUM"): + _config_factory.CACHED_MEMO_VS_NUM = config_json.get("CACHED_MEMO_VS_NUM") + if config_json.get("CHUNK_SIZE"): + _config_factory.CHUNK_SIZE = config_json.get("CHUNK_SIZE") + if config_json.get("OVERLAP_SIZE") is not None:  # 0 (no overlap) is a valid setting + _config_factory.OVERLAP_SIZE = config_json.get("OVERLAP_SIZE") + if config_json.get("VECTOR_SEARCH_TOP_K"): + _config_factory.VECTOR_SEARCH_TOP_K = config_json.get("VECTOR_SEARCH_TOP_K") + if config_json.get("SCORE_THRESHOLD") is not None:  # 0 lies inside the documented 0-1 range + _config_factory.SCORE_THRESHOLD = config_json.get("SCORE_THRESHOLD") + if config_json.get("DEFAULT_SEARCH_ENGINE"): + _config_factory.DEFAULT_SEARCH_ENGINE = config_json.get("DEFAULT_SEARCH_ENGINE") + if config_json.get("SEARCH_ENGINE_TOP_K"): + _config_factory.SEARCH_ENGINE_TOP_K = config_json.get("SEARCH_ENGINE_TOP_K") + if config_json.get("ZH_TITLE_ENHANCE") is not None:  # an explicit false must be honoured + _config_factory.ZH_TITLE_ENHANCE = config_json.get("ZH_TITLE_ENHANCE") + if config_json.get("PDF_OCR_THRESHOLD"): + _config_factory.PDF_OCR_THRESHOLD = config_json.get("PDF_OCR_THRESHOLD") + if config_json.get("KB_INFO") is not None:  # an empty mapping (no descriptions) is valid + _config_factory.KB_INFO = config_json.get("KB_INFO") + if config_json.get("KB_ROOT_PATH"): + _config_factory.KB_ROOT_PATH = config_json.get("KB_ROOT_PATH") + if config_json.get("DB_ROOT_PATH"): + _config_factory.DB_ROOT_PATH = config_json.get("DB_ROOT_PATH") + if config_json.get("SQLALCHEMY_DATABASE_URI"): + _config_factory.SQLALCHEMY_DATABASE_URI = config_json.get("SQLALCHEMY_DATABASE_URI") + + if config_json.get("TEXT_SPLITTER_NAME"): + _config_factory.TEXT_SPLITTER_NAME = config_json.get("TEXT_SPLITTER_NAME") + + if config_json.get("EMBEDDING_KEYWORD_FILE"): + _config_factory.EMBEDDING_KEYWORD_FILE = 
config_json.get("EMBEDDING_KEYWORD_FILE") + + return _config_factory + + @classmethod + def get_type(cls) -> str: + return ConfigKb.class_name() + + def get_config(self) -> ConfigKb: + return self._config_factory.get_config() + + def set_default_knowledge_base(self, kb_name: str): + self._config_factory.DEFAULT_KNOWLEDGE_BASE = kb_name + + def set_default_vs_type(self, vs_type: str): + self._config_factory.DEFAULT_VS_TYPE = vs_type + + def set_cached_vs_num(self, cached_vs_num: int): + self._config_factory.CACHED_VS_NUM = cached_vs_num + + def set_cached_memo_vs_num(self, cached_memo_vs_num: int): + self._config_factory.CACHED_MEMO_VS_NUM = cached_memo_vs_num + + def set_chunk_size(self, chunk_size: int): + self._config_factory.CHUNK_SIZE = chunk_size + + def set_overlap_size(self, overlap_size: int): + self._config_factory.OVERLAP_SIZE = overlap_size + + def set_vector_search_top_k(self, vector_search_top_k: int): + self._config_factory.VECTOR_SEARCH_TOP_K = vector_search_top_k + + def set_score_threshold(self, score_threshold: float): + self._config_factory.SCORE_THRESHOLD = score_threshold + + def set_default_search_engine(self, default_search_engine: str): + self._config_factory.DEFAULT_SEARCH_ENGINE = default_search_engine + + def set_search_engine_top_k(self, search_engine_top_k: int): + self._config_factory.SEARCH_ENGINE_TOP_K = search_engine_top_k + + def set_zh_title_enhance(self, zh_title_enhance: bool): + self._config_factory.ZH_TITLE_ENHANCE = zh_title_enhance + + def set_pdf_ocr_threshold(self, pdf_ocr_threshold: Tuple[float, float]): + self._config_factory.PDF_OCR_THRESHOLD = pdf_ocr_threshold + + def set_kb_info(self, kb_info: Dict[str, str]): + self._config_factory.KB_INFO = kb_info + + def set_kb_root_path(self, kb_root_path: str): + self._config_factory.KB_ROOT_PATH = kb_root_path + + def set_db_root_path(self, db_root_path: str): + self._config_factory.DB_ROOT_PATH = db_root_path + + def set_sqlalchemy_database_uri(self, sqlalchemy_database_uri: 
str): + self._config_factory.SQLALCHEMY_DATABASE_URI = sqlalchemy_database_uri + + def set_text_splitter_name(self, text_splitter_name: str): + self._config_factory.TEXT_SPLITTER_NAME = text_splitter_name + + def set_embedding_keyword_file(self, embedding_keyword_file: str): + self._config_factory.EMBEDDING_KEYWORD_FILE = embedding_keyword_file + + +config_kb_workspace: ConfigKbWorkSpace = ConfigKbWorkSpace() diff --git a/libs/chatchat-server/tests/unit_tests/config/test_config.py b/libs/chatchat-server/tests/unit_tests/config/test_config.py index 35d79f2d..062e35bb 100644 --- a/libs/chatchat-server/tests/unit_tests/config/test_config.py +++ b/libs/chatchat-server/tests/unit_tests/config/test_config.py @@ -8,6 +8,8 @@ from chatchat.configs import ( ConfigModel, ConfigServerWorkSpace, ConfigServer, + ConfigKbWorkSpace, + ConfigKb, ) import os @@ -117,4 +119,27 @@ def test_config_server_workspace(): assert config.DEFAULT_BIND_HOST == "0.0.0.0" assert config.WEBUI_SERVER_PORT == 8000 assert config.API_SERVER_PORT == 8001 - config_server_workspace.clear() \ No newline at end of file + config_server_workspace.clear() + + +def test_server_config(): + from chatchat.configs import ( + HTTPX_DEFAULT_TIMEOUT, OPEN_CROSS_DOMAIN, DEFAULT_BIND_HOST, + WEBUI_SERVER, API_SERVER + ) + assert HTTPX_DEFAULT_TIMEOUT is not None + assert OPEN_CROSS_DOMAIN is not None + assert DEFAULT_BIND_HOST is not None + assert WEBUI_SERVER is not None + assert API_SERVER is not None + + +def test_config_kb_workspace(): + + config_kb_workspace: ConfigKbWorkSpace = ConfigKbWorkSpace() + + assert config_kb_workspace.get_config() is not None + + config_kb_workspace.set_default_knowledge_base(kb_name="test") + config_kb_workspace.set_default_vs_type(vs_type="tes") + \ No newline at end of file