Reviewer note (text before the first "diff --git" line is ignored by patch tools):
this patch was recovered from a whitespace-collapsed copy. Indentation and the
position of blank context lines were inferred from the hunk line counts and
must be confirmed against the base tree. The "kb" command differs from the
collapsed copy in three reviewed fixes, each marked by a comment in the code.

diff --git a/libs/chatchat-server/chatchat/config_work_space.py b/libs/chatchat-server/chatchat/config_work_space.py
index a7ea0c9b..e113956b 100644
--- a/libs/chatchat-server/chatchat/config_work_space.py
+++ b/libs/chatchat-server/chatchat/config_work_space.py
@@ -1,8 +1,10 @@
 from chatchat.configs import (
     config_basic_workspace,
     config_model_workspace,
+    config_server_workspace,
+    config_kb_workspace,
 )
-
+import ast
 # We cannot lazy-load click here because its used via decorators.
 import click

@@ -71,7 +73,9 @@ def model(**kwargs):
         config_model_workspace.set_support_agent_models(support_agent_models=kwargs["support_agent_models"])

     if kwargs["model_providers_cfg_path_config"]:
-        config_model_workspace.set_model_providers_cfg_path_config(model_providers_cfg_path_config=kwargs["model_providers_cfg_path_config"])
+        config_model_workspace.set_model_providers_cfg_path_config(
+            model_providers_cfg_path_config=kwargs["model_providers_cfg_path_config"]
+        )
     if kwargs["model_providers_cfg_host"]:
         config_model_workspace.set_model_providers_cfg_host(model_providers_cfg_host=kwargs["model_providers_cfg_host"])

@@ -96,25 +100,106 @@ def model(**kwargs):
 def server(**kwargs):

     if kwargs["httpx_default_timeout"]:
-        config_basic_workspace.set_httpx_default_timeout(httpx_default_timeout=kwargs["httpx_default_timeout"])
+        config_server_workspace.set_httpx_default_timeout(httpx_default_timeout=kwargs["httpx_default_timeout"])

     if kwargs["open_cross_domain"]:
         if kwargs["open_cross_domain"].lower() == "true":
-            config_basic_workspace.set_open_cross_domain(True)
+            config_server_workspace.set_open_cross_domain(True)
         else:
-            config_basic_workspace.set_open_cross_domain(False)
+            config_server_workspace.set_open_cross_domain(False)

     if kwargs["default_bind_host"]:
-        config_basic_workspace.set_default_bind_host(default_bind_host=kwargs["default_bind_host"])
+        config_server_workspace.set_default_bind_host(default_bind_host=kwargs["default_bind_host"])
     if kwargs["webui_server_port"]:
-        config_basic_workspace.set_webui_server_port(webui_server_port=kwargs["webui_server_port"])
+        config_server_workspace.set_webui_server_port(webui_server_port=kwargs["webui_server_port"])
     if kwargs["api_server_port"]:
-        config_basic_workspace.set_api_server_port(api_server_port=kwargs["api_server_port"])
+        config_server_workspace.set_api_server_port(api_server_port=kwargs["api_server_port"])

     if kwargs["clear"]:
-        config_model_workspace.clear()
+        config_server_workspace.clear()
     if kwargs["show"]:
-        print(config_model_workspace.get_config())
+        print(config_server_workspace.get_config())
+
+
+@main.command("kb", help="知识库配置")
+@click.option("--set_default_knowledge_base", help="设置默认知识库")
+@click.option("--set_default_vs_type", help="设置默认vs类型")
+@click.option("--set_cached_vs_num", type=int, help="设置缓存vs数量")
+@click.option("--set_cached_memo_vs_num", type=int, help="设置缓存memo vs数量")
+@click.option("--set_chunk_size", type=int, help="设置chunk大小")
+@click.option("--set_overlap_size", type=int, help="设置overlap大小")
+@click.option("--set_vector_search_top_k", type=int, help="设置vector search top k")
+@click.option("--set_score_threshold", type=float, help="设置score阈值")
+@click.option("--set_default_search_engine", help="设置默认搜索引擎")
+@click.option("--set_search_engine_top_k", type=int, help="设置搜索引擎top k")
+@click.option("--set_zh_title_enhance", type=click.Choice(["true", "false"]), help="是否开启中文标题增强")
+@click.option('--pdf-ocr-threshold', type=(float, float), help='pdf ocr threshold')
+@click.option('--set_kb_info', type=str, help='''每个知识库的初始化介绍,用于在初始化知识库时显示和Agent调用,
+                             没写则没有介绍,不会被Agent调用。
+                             Example: \'{"samples": "关于本项目issue的解答"}\'
+                             ''')
+@click.option("--set_kb_root_path", help="设置知识库根路径")
+@click.option("--set_db_root_path", help="设置db根路径")
+@click.option("--set_sqlalchemy_database_uri", help="设置sqlalchemy数据库uri")
+@click.option("--set_text_splitter_name", help="设置text splitter名称")
+@click.option("--set_embedding_keyword_file", help="设置embedding关键词文件")
+@click.option("--clear", is_flag=True, help="清除配置")
+@click.option("--show", is_flag=True, help="显示配置")
+def kb(**kwargs):
+    """Apply knowledge-base settings passed on the command line.
+
+    Each supplied option is forwarded to the matching
+    ``config_kb_workspace`` setter; ``--clear`` and ``--show`` run last.
+    """
+
+    # Passed positionally: the unit test in this patch calls these two setters
+    # with kb_name= / vs_type=, so the default_*= keywords would likely not bind.
+    if kwargs["set_default_knowledge_base"]:
+        config_kb_workspace.set_default_knowledge_base(kwargs["set_default_knowledge_base"])
+    if kwargs["set_default_vs_type"]:
+        config_kb_workspace.set_default_vs_type(kwargs["set_default_vs_type"])
+    # Numeric options are compared with None so an explicit 0 / 0.0 is applied.
+    if kwargs["set_cached_vs_num"] is not None:
+        config_kb_workspace.set_cached_vs_num(cached_vs_num=kwargs["set_cached_vs_num"])
+    if kwargs["set_cached_memo_vs_num"] is not None:
+        config_kb_workspace.set_cached_memo_vs_num(cached_memo_vs_num=kwargs["set_cached_memo_vs_num"])
+    if kwargs["set_chunk_size"] is not None:
+        config_kb_workspace.set_chunk_size(chunk_size=kwargs["set_chunk_size"])
+    if kwargs["set_overlap_size"] is not None:
+        config_kb_workspace.set_overlap_size(overlap_size=kwargs["set_overlap_size"])
+    if kwargs["set_vector_search_top_k"] is not None:
+        config_kb_workspace.set_vector_search_top_k(vector_search_top_k=kwargs["set_vector_search_top_k"])
+    if kwargs["set_score_threshold"] is not None:
+        config_kb_workspace.set_score_threshold(score_threshold=kwargs["set_score_threshold"])
+    if kwargs["set_default_search_engine"]:
+        config_kb_workspace.set_default_search_engine(default_search_engine=kwargs["set_default_search_engine"])
+    # A kb setting: route it to config_kb_workspace, not config_model_workspace.
+    if kwargs["set_search_engine_top_k"] is not None:
+        config_kb_workspace.set_search_engine_top_k(search_engine_top_k=kwargs["set_search_engine_top_k"])
+    if kwargs["set_zh_title_enhance"]:
+        if kwargs["set_zh_title_enhance"].lower() == "true":
+            config_kb_workspace.set_zh_title_enhance(True)
+        else:
+            config_kb_workspace.set_zh_title_enhance(False)
+    if kwargs["pdf_ocr_threshold"]:
+        config_kb_workspace.set_pdf_ocr_threshold(pdf_ocr_threshold=kwargs["pdf_ocr_threshold"])
+    if kwargs["set_kb_info"]:
+        config_kb_workspace.set_kb_info(kb_info=ast.literal_eval(kwargs["set_kb_info"]))  # parses a dict literal, see --set_kb_info help
+    if kwargs["set_kb_root_path"]:
+        config_kb_workspace.set_kb_root_path(kb_root_path=kwargs["set_kb_root_path"])
+    if kwargs["set_db_root_path"]:
+        config_kb_workspace.set_db_root_path(db_root_path=kwargs["set_db_root_path"])
+    if kwargs["set_sqlalchemy_database_uri"]:
+        config_kb_workspace.set_sqlalchemy_database_uri(sqlalchemy_database_uri=kwargs["set_sqlalchemy_database_uri"])
+    if kwargs["set_text_splitter_name"]:
+        config_kb_workspace.set_text_splitter_name(text_splitter_name=kwargs["set_text_splitter_name"])
+    if kwargs["set_embedding_keyword_file"]:
+        config_kb_workspace.set_embedding_keyword_file(embedding_keyword_file=kwargs["set_embedding_keyword_file"])
+
+    if kwargs["clear"]:
+        config_kb_workspace.clear()
+    if kwargs["show"]:
+        print(config_kb_workspace.get_config())


 if __name__ == "__main__":
diff --git a/libs/chatchat-server/chatchat/configs/__init__.py b/libs/chatchat-server/chatchat/configs/__init__.py
index 6a68fb03..19b2e685 100644
--- a/libs/chatchat-server/chatchat/configs/__init__.py
+++ b/libs/chatchat-server/chatchat/configs/__init__.py
@@ -820,7 +820,7 @@ __all__ = [

     "ConfigKbFactory",
     "ConfigKbWorkSpace",
-    "config_model_workspace",
+    "config_kb_workspace",

     "ConfigServer",
     "ConfigServerFactory",
diff --git a/libs/chatchat-server/tests/unit_tests/config/test_config.py b/libs/chatchat-server/tests/unit_tests/config/test_config.py
index 062e35bb..325574e8 100644
--- a/libs/chatchat-server/tests/unit_tests/config/test_config.py
+++ b/libs/chatchat-server/tests/unit_tests/config/test_config.py
@@ -142,4 +142,46 @@ def test_config_kb_workspace():

     config_kb_workspace.set_default_knowledge_base(kb_name="test")
     config_kb_workspace.set_default_vs_type(vs_type="tes")
-    
\ No newline at end of file
+    config_kb_workspace.set_cached_vs_num(cached_vs_num=10)
+    config_kb_workspace.set_cached_memo_vs_num(cached_memo_vs_num=10)
+    config_kb_workspace.set_chunk_size(chunk_size=10)
+    config_kb_workspace.set_overlap_size(overlap_size=10)
+    config_kb_workspace.set_vector_search_top_k(vector_search_top_k=10)
+    config_kb_workspace.set_score_threshold(score_threshold=0.1)
+    config_kb_workspace.set_default_search_engine(default_search_engine="test")
+    config_kb_workspace.set_search_engine_top_k(search_engine_top_k=10)
+    config_kb_workspace.set_zh_title_enhance(zh_title_enhance=True)
+    config_kb_workspace.set_pdf_ocr_threshold(pdf_ocr_threshold=(0.1, 0.2))
+    config_kb_workspace.set_kb_info(kb_info={
+        "samples": "关于本项目issue的解答",
+    })
+    config_kb_workspace.set_kb_root_path(kb_root_path="test")
+    config_kb_workspace.set_db_root_path(db_root_path="test")
+    config_kb_workspace.set_sqlalchemy_database_uri(sqlalchemy_database_uri="test")
+    config_kb_workspace.set_text_splitter_name(text_splitter_name="test")
+    config_kb_workspace.set_embedding_keyword_file(embedding_keyword_file="test")
+
+    config: ConfigKb = config_kb_workspace.get_config()
+    assert config.DEFAULT_KNOWLEDGE_BASE == "test"
+    assert config.DEFAULT_VS_TYPE == "tes"
+    assert config.CACHED_VS_NUM == 10
+    assert config.CACHED_MEMO_VS_NUM == 10
+    assert config.CHUNK_SIZE == 10
+    assert config.OVERLAP_SIZE == 10
+    assert config.VECTOR_SEARCH_TOP_K == 10
+    assert config.SCORE_THRESHOLD == 0.1
+    assert config.DEFAULT_SEARCH_ENGINE == "test"
+    assert config.SEARCH_ENGINE_TOP_K == 10
+    assert config.ZH_TITLE_ENHANCE is True
+    assert config.PDF_OCR_THRESHOLD == (0.1, 0.2)
+    assert config.KB_INFO == {
+        "samples": "关于本项目issue的解答",
+    }
+    assert config.KB_ROOT_PATH == "test"
+    assert config.DB_ROOT_PATH == "test"
+    assert config.SQLALCHEMY_DATABASE_URI == "test"
+    assert config.TEXT_SPLITTER_NAME == "test"
+    assert config.EMBEDDING_KEYWORD_FILE == "test"
+    config_kb_workspace.clear()
+
+