配置中心知识库信息 子命令入口

This commit is contained in:
glide-the 2024-06-11 20:34:03 +08:00
parent 4f9d63d9f4
commit 91a345ce54
3 changed files with 130 additions and 12 deletions

View File

@ -1,8 +1,10 @@
from chatchat.configs import (
config_basic_workspace,
config_model_workspace,
config_server_workspace,
config_kb_workspace,
)
import ast
# We cannot lazy-load click here because its used via decorators.
import click
@ -71,7 +73,9 @@ def model(**kwargs):
config_model_workspace.set_support_agent_models(support_agent_models=kwargs["support_agent_models"])
if kwargs["model_providers_cfg_path_config"]:
config_model_workspace.set_model_providers_cfg_path_config(model_providers_cfg_path_config=kwargs["model_providers_cfg_path_config"])
config_model_workspace.set_model_providers_cfg_path_config(
model_providers_cfg_path_config=kwargs["model_providers_cfg_path_config"]
)
if kwargs["model_providers_cfg_host"]:
config_model_workspace.set_model_providers_cfg_host(model_providers_cfg_host=kwargs["model_providers_cfg_host"])
@ -96,25 +100,97 @@ def model(**kwargs):
def server(**kwargs):
if kwargs["httpx_default_timeout"]:
config_basic_workspace.set_httpx_default_timeout(httpx_default_timeout=kwargs["httpx_default_timeout"])
config_server_workspace.set_httpx_default_timeout(httpx_default_timeout=kwargs["httpx_default_timeout"])
if kwargs["open_cross_domain"]:
if kwargs["open_cross_domain"].lower() == "true":
config_basic_workspace.set_open_cross_domain(True)
config_server_workspace.set_open_cross_domain(True)
else:
config_basic_workspace.set_open_cross_domain(False)
config_server_workspace.set_open_cross_domain(False)
if kwargs["default_bind_host"]:
config_basic_workspace.set_default_bind_host(default_bind_host=kwargs["default_bind_host"])
config_server_workspace.set_default_bind_host(default_bind_host=kwargs["default_bind_host"])
if kwargs["webui_server_port"]:
config_basic_workspace.set_webui_server_port(webui_server_port=kwargs["webui_server_port"])
config_server_workspace.set_webui_server_port(webui_server_port=kwargs["webui_server_port"])
if kwargs["api_server_port"]:
config_basic_workspace.set_api_server_port(api_server_port=kwargs["api_server_port"])
config_server_workspace.set_api_server_port(api_server_port=kwargs["api_server_port"])
if kwargs["clear"]:
config_model_workspace.clear()
config_server_workspace.clear()
if kwargs["show"]:
print(config_model_workspace.get_config())
print(config_server_workspace.get_config())
@main.command("kb", help="知识库配置")
@click.option("--set_default_knowledge_base", help="设置默认知识库")
@click.option("--set_default_vs_type", help="设置默认vs类型")
@click.option("--set_cached_vs_num", type=int, help="设置缓存vs数量")
@click.option("--set_cached_memo_vs_num", type=int, help="设置缓存memo vs数量")
@click.option("--set_chunk_size", type=int, help="设置chunk大小")
@click.option("--set_overlap_size", type=int, help="设置overlap大小")
@click.option("--set_vector_search_top_k", type=int, help="设置vector search top k")
@click.option("--set_score_threshold", type=float, help="设置score阈值")
@click.option("--set_default_search_engine", help="设置默认搜索引擎")
@click.option("--set_search_engine_top_k", type=int, help="设置搜索引擎top k")
@click.option("--set_zh_title_enhance", type=click.Choice(["true", "false"]), help="是否开启中文标题增强")
@click.option('--pdf-ocr-threshold', type=(float, float), help='pdf ocr threshold')
@click.option('--set_kb_info', type=str, help='''每个知识库的初始化介绍用于在初始化知识库时显示和Agent调用
没写则没有介绍不会被Agent调用
Example: \'{"samples": "关于本项目issue的解答"}\'
''')
@click.option("--set_kb_root_path", help="设置知识库根路径")
@click.option("--set_db_root_path", help="设置db根路径")
@click.option("--set_sqlalchemy_database_uri", help="设置sqlalchemy数据库uri")
@click.option("--set_text_splitter_name", help="设置text splitter名称")
@click.option("--set_embedding_keyword_file", help="设置embedding关键词文件")
@click.option("--clear", is_flag=True, help="清除配置")
@click.option("--show", is_flag=True, help="显示配置")
def kb(**kwargs):
if kwargs["set_default_knowledge_base"]:
config_kb_workspace.set_default_knowledge_base(default_knowledge_base=kwargs["set_default_knowledge_base"])
if kwargs["set_default_vs_type"]:
config_kb_workspace.set_default_vs_type(default_vs_type=kwargs["set_default_vs_type"])
if kwargs["set_cached_vs_num"]:
config_kb_workspace.set_cached_vs_num(cached_vs_num=kwargs["set_cached_vs_num"])
if kwargs["set_cached_memo_vs_num"]:
config_kb_workspace.set_cached_memo_vs_num(cached_memo_vs_num=kwargs["set_cached_memo_vs_num"])
if kwargs["set_chunk_size"]:
config_kb_workspace.set_chunk_size(chunk_size=kwargs["set_chunk_size"])
if kwargs["set_overlap_size"]:
config_kb_workspace.set_overlap_size(overlap_size=kwargs["set_overlap_size"])
if kwargs["set_vector_search_top_k"]:
config_kb_workspace.set_vector_search_top_k(vector_search_top_k=kwargs["set_vector_search_top_k"])
if kwargs["set_score_threshold"]:
config_kb_workspace.set_score_threshold(score_threshold=kwargs["set_score_threshold"])
if kwargs["set_default_search_engine"]:
config_kb_workspace.set_default_search_engine(default_search_engine=kwargs["set_default_search_engine"])
if kwargs["set_search_engine_top_k"]:
config_model_workspace.set_search_engine_top_k(search_engine_top_k=kwargs["set_search_engine_top_k"])
if kwargs["set_zh_title_enhance"]:
if kwargs["set_zh_title_enhance"].lower() == "true":
config_kb_workspace.set_zh_title_enhance(True)
else:
config_kb_workspace.set_zh_title_enhance(False)
if kwargs["pdf_ocr_threshold"]:
config_kb_workspace.set_pdf_ocr_threshold(pdf_ocr_threshold=kwargs["pdf_ocr_threshold"])
if kwargs["set_kb_info"]:
config_kb_workspace.set_kb_info(kb_info=ast.literal_eval(kwargs["set_kb_info"]))
if kwargs["set_kb_root_path"]:
config_kb_workspace.set_kb_root_path(kb_root_path=kwargs["set_kb_root_path"])
if kwargs["set_db_root_path"]:
config_kb_workspace.set_db_root_path(db_root_path=kwargs["set_db_root_path"])
if kwargs["set_sqlalchemy_database_uri"]:
config_kb_workspace.set_sqlalchemy_database_uri(sqlalchemy_database_uri=kwargs["set_sqlalchemy_database_uri"])
if kwargs["set_text_splitter_name"]:
config_kb_workspace.set_text_splitter_name(text_splitter_name=kwargs["set_text_splitter_name"])
if kwargs["set_embedding_keyword_file"]:
config_kb_workspace.set_embedding_keyword_file(embedding_keyword_file=kwargs["set_embedding_keyword_file"])
if kwargs["clear"]:
config_kb_workspace.clear()
if kwargs["show"]:
print(config_kb_workspace.get_config())
if __name__ == "__main__":

View File

@ -820,7 +820,7 @@ __all__ = [
"ConfigKbFactory",
"ConfigKbWorkSpace",
"config_model_workspace",
"config_kb_workspace",
"ConfigServer",
"ConfigServerFactory",

View File

@ -142,4 +142,46 @@ def test_config_kb_workspace():
config_kb_workspace.set_default_knowledge_base(kb_name="test")
config_kb_workspace.set_default_vs_type(vs_type="tes")
config_kb_workspace.set_cached_vs_num(cached_vs_num=10)
config_kb_workspace.set_cached_memo_vs_num(cached_memo_vs_num=10)
config_kb_workspace.set_chunk_size(chunk_size=10)
config_kb_workspace.set_overlap_size(overlap_size=10)
config_kb_workspace.set_vector_search_top_k(vector_search_top_k=10)
config_kb_workspace.set_score_threshold(score_threshold=0.1)
config_kb_workspace.set_default_search_engine(default_search_engine="test")
config_kb_workspace.set_search_engine_top_k(search_engine_top_k=10)
config_kb_workspace.set_zh_title_enhance(zh_title_enhance=True)
config_kb_workspace.set_pdf_ocr_threshold(pdf_ocr_threshold=(0.1, 0.2))
config_kb_workspace.set_kb_info(kb_info={
"samples": "关于本项目issue的解答",
})
config_kb_workspace.set_kb_root_path(kb_root_path="test")
config_kb_workspace.set_db_root_path(db_root_path="test")
config_kb_workspace.set_sqlalchemy_database_uri(sqlalchemy_database_uri="test")
config_kb_workspace.set_text_splitter_name(text_splitter_name="test")
config_kb_workspace.set_embedding_keyword_file(embedding_keyword_file="test")
config: ConfigKb = config_kb_workspace.get_config()
assert config.DEFAULT_KNOWLEDGE_BASE == "test"
assert config.DEFAULT_VS_TYPE == "tes"
assert config.CACHED_VS_NUM == 10
assert config.CACHED_MEMO_VS_NUM == 10
assert config.CHUNK_SIZE == 10
assert config.OVERLAP_SIZE == 10
assert config.VECTOR_SEARCH_TOP_K == 10
assert config.SCORE_THRESHOLD == 0.1
assert config.DEFAULT_SEARCH_ENGINE == "test"
assert config.SEARCH_ENGINE_TOP_K == 10
assert config.ZH_TITLE_ENHANCE is True
assert config.PDF_OCR_THRESHOLD == (0.1, 0.2)
assert config.KB_INFO == {
"samples": "关于本项目issue的解答",
}
assert config.KB_ROOT_PATH == "test"
assert config.DB_ROOT_PATH == "test"
assert config.SQLALCHEMY_DATABASE_URI == "test"
assert config.TEXT_SPLITTER_NAME == "test"
assert config.EMBEDDING_KEYWORD_FILE == "test"
config_kb_workspace.clear()