From e7a5d6a5288062a2d4fd2294ad8aa026dd05b3fa Mon Sep 17 00:00:00 2001
From: glide-the <2533736852@qq.com>
Date: Tue, 11 Jun 2024 15:14:58 +0800
Subject: [PATCH] =?UTF-8?q?=E4=BD=BF=E7=94=A8click=E5=A2=9E=E5=8A=A0?=
 =?UTF-8?q?=E9=85=8D=E7=BD=AE=E4=B8=AD=E5=BF=83=E5=AD=90=E5=91=BD=E4=BB=A4?=
 =?UTF-8?q?=20(#4164)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Use click to add config-center subcommands to `chatchat-config`

Add ConfigModelWorkSpace. (A usage sketch for the new subcommands is
appended after the diff.)
---
 .../chatchat/config_work_space.py             | 122 ++--
 .../chatchat/configs/__init__.py              | 101 ++-
 .../chatchat/configs/_basic_config.py         |  11 +-
 .../chatchat/configs/_core_config.py          |  40 +-
 .../chatchat/configs/_model_config.py         | 664 +++++++++++-------
 libs/chatchat-server/pyproject.toml           |   2 +-
 .../tests/unit_tests/config/test_config.py    |  61 +-
 7 files changed, 678 insertions(+), 323 deletions(-)

diff --git a/libs/chatchat-server/chatchat/config_work_space.py b/libs/chatchat-server/chatchat/config_work_space.py
index 19b08d49..c8689816 100644
--- a/libs/chatchat-server/chatchat/config_work_space.py
+++ b/libs/chatchat-server/chatchat/config_work_space.py
@@ -1,52 +1,88 @@
-from chatchat.configs import config_basic_workspace as workspace
+from chatchat.configs import (
+    config_basic_workspace,
+    config_model_workspace,
+)
+
+# We cannot lazy-load click here because it's used via decorators.
+import click
 
 
+@click.group(help="指令` chatchat-config` 工作空间配置")
 def main():
-    import argparse
+    pass
 
-    parser = argparse.ArgumentParser(description="指令` chatchat-config` 工作空间配置")
-    # 只能选择true或false
-    parser.add_argument(
-        "-v",
-        "--verbose",
-        choices=["true", "false"],
-        help="是否开启详细日志"
-    )
-    parser.add_argument(
-        "-d",
-        "--data",
-        help="数据存放路径"
-    )
-    parser.add_argument(
-        "-f",
-        "--format",
-        help="日志格式"
-    )
-    parser.add_argument(
-        "--clear",
-        action="store_true",
-        help="清除配置"
-    )
-    parser.add_argument(
-        "--show",
-        action="store_true",
-        help="显示配置"
-    )
-    args = parser.parse_args()
-    if args.verbose:
-        if args.verbose.lower() == "true":
-            workspace.set_log_verbose(True)
+
+@main.command("basic", help="基础配置")
+@click.option("--verbose", type=click.Choice(["true", "false"]), help="是否开启详细日志")
+@click.option("--data", help="数据存放路径")
+@click.option("--format", help="日志格式")
+@click.option("--clear", is_flag=True, help="清除配置")
+@click.option("--show", is_flag=True, help="显示配置")
+def basic(**kwargs):
+
+    if kwargs["verbose"]:
+        if kwargs["verbose"].lower() == "true":
+            config_basic_workspace.set_log_verbose(True)
         else:
-            workspace.set_log_verbose(False)
-    if args.data:
-        workspace.set_data_path(args.data)
-    if args.format:
-        workspace.set_log_format(args.format)
-    if args.clear:
-        workspace.clear()
-    if args.show:
-        print(workspace.get_config())
+            config_basic_workspace.set_log_verbose(False)
+    if kwargs["data"]:
+        config_basic_workspace.set_data_path(kwargs["data"])
+    if kwargs["format"]:
+        config_basic_workspace.set_log_format(kwargs["format"])
+    if kwargs["clear"]:
+        config_basic_workspace.clear()
+    if kwargs["show"]:
+        print(config_basic_workspace.get_config())
+
+
+@main.command("model", help="模型配置")
+@click.option("--default_llm_model", help="默认llm模型")
+@click.option("--default_embedding_model", help="默认embedding模型")
+@click.option("--agent_model", help="agent模型")
+@click.option("--history_len", type=int, help="历史长度")
+@click.option("--max_tokens", type=int, help="最大tokens")
+@click.option("--temperature", type=float, help="温度")
+@click.option("--support_agent_models", multiple=True, help="支持的agent模型")
+@click.option("--model_providers_cfg_path_config", 
help="模型平台配置文件路径") +@click.option("--model_providers_cfg_host", help="模型平台配置服务host") +@click.option("--model_providers_cfg_port", type=int, help="模型平台配置服务port") +@click.option("--clear", is_flag=True, help="清除配置") +@click.option("--show", is_flag=True, help="显示配置") +def model(**kwargs): + + if kwargs["default_llm_model"]: + config_model_workspace.set_default_llm_model(llm_model=kwargs["default_llm_model"]) + if kwargs["default_embedding_model"]: + config_model_workspace.set_default_embedding_model(embedding_model=kwargs["default_embedding_model"]) + + if kwargs["agent_model"]: + config_model_workspace.set_agent_model(agent_model=kwargs["agent_model"]) + + if kwargs["history_len"]: + config_model_workspace.set_history_len(history_len=kwargs["history_len"]) + + if kwargs["max_tokens"]: + config_model_workspace.set_max_tokens(max_tokens=kwargs["max_tokens"]) + + if kwargs["temperature"]: + config_model_workspace.set_temperature(temperature=kwargs["temperature"]) + + if kwargs["support_agent_models"]: + config_model_workspace.set_support_agent_models(support_agent_models=kwargs["support_agent_models"]) + + if kwargs["model_providers_cfg_path_config"]: + config_model_workspace.set_model_providers_cfg_path_config(model_providers_cfg_path_config=kwargs["model_providers_cfg_path_config"]) + + if kwargs["model_providers_cfg_host"]: + config_model_workspace.set_model_providers_cfg_host(model_providers_cfg_host=kwargs["model_providers_cfg_host"]) + + if kwargs["model_providers_cfg_port"]: + config_model_workspace.set_model_providers_cfg_port(model_providers_cfg_port=kwargs["model_providers_cfg_port"]) + + if kwargs["clear"]: + config_model_workspace.clear() + if kwargs["show"]: + print(config_model_workspace.get_config()) if __name__ == "__main__": diff --git a/libs/chatchat-server/chatchat/configs/__init__.py b/libs/chatchat-server/chatchat/configs/__init__.py index b2408456..c04e5e33 100644 --- a/libs/chatchat-server/chatchat/configs/__init__.py +++ b/libs/chatchat-server/chatchat/configs/__init__.py @@ -363,108 +363,140 @@ def _import_embedding_keyword_file() -> Any: return EMBEDDING_KEYWORD_FILE +def _import_ConfigModel() -> Any: + basic_config_load = CONFIG_IMPORTS.get("_model_config.py") + load_mod = basic_config_load.get("load_mod") + ConfigModel = load_mod(basic_config_load.get("module"), "ConfigModel") + + return ConfigModel + + +def _import_ConfigModelFactory() -> Any: + basic_config_load = CONFIG_IMPORTS.get("_model_config.py") + load_mod = basic_config_load.get("load_mod") + ConfigModelFactory = load_mod(basic_config_load.get("module"), "ConfigModelFactory") + + return ConfigModelFactory + + +def _import_ConfigModelWorkSpace() -> Any: + basic_config_load = CONFIG_IMPORTS.get("_model_config.py") + load_mod = basic_config_load.get("load_mod") + ConfigModelWorkSpace = load_mod(basic_config_load.get("module"), "ConfigModelWorkSpace") + + return ConfigModelWorkSpace + + +def _import_config_model_workspace() -> Any: + model_config_load = CONFIG_IMPORTS.get("_model_config.py") + load_mod = model_config_load.get("load_mod") + config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") + return config_model_workspace + + def _import_default_llm_model() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - DEFAULT_LLM_MODEL = load_mod(model_config_load.get("module"), "DEFAULT_LLM_MODEL") + config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") - return DEFAULT_LLM_MODEL 
+ return config_model_workspace.get_config().DEFAULT_LLM_MODEL def _import_default_embedding_model() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - DEFAULT_EMBEDDING_MODEL = load_mod(model_config_load.get("module"), "DEFAULT_EMBEDDING_MODEL") - return DEFAULT_EMBEDDING_MODEL + config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") + + return config_model_workspace.get_config().DEFAULT_EMBEDDING_MODEL def _import_agent_model() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - Agent_MODEL = load_mod(model_config_load.get("module"), "Agent_MODEL") + config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") - return Agent_MODEL + return config_model_workspace.get_config().Agent_MODEL def _import_history_len() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - HISTORY_LEN = load_mod(model_config_load.get("module"), "HISTORY_LEN") + config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") - return HISTORY_LEN + return config_model_workspace.get_config().HISTORY_LEN def _import_max_tokens() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - MAX_TOKENS = load_mod(model_config_load.get("module"), "MAX_TOKENS") + config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") - return MAX_TOKENS + return config_model_workspace.get_config().MAX_TOKENS def _import_temperature() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - TEMPERATURE = load_mod(model_config_load.get("module"), "TEMPERATURE") + config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") - return TEMPERATURE + return config_model_workspace.get_config().TEMPERATURE def _import_support_agent_models() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - SUPPORT_AGENT_MODELS = load_mod(model_config_load.get("module"), "SUPPORT_AGENT_MODELS") + config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") - return SUPPORT_AGENT_MODELS + return config_model_workspace.get_config().SUPPORT_AGENT_MODELS def _import_llm_model_config() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - LLM_MODEL_CONFIG = load_mod(model_config_load.get("module"), "LLM_MODEL_CONFIG") + config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") - return LLM_MODEL_CONFIG + return config_model_workspace.get_config().LLM_MODEL_CONFIG def _import_model_platforms() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - MODEL_PLATFORMS = load_mod(model_config_load.get("module"), "MODEL_PLATFORMS") + config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") - return MODEL_PLATFORMS + return config_model_workspace.get_config().MODEL_PLATFORMS def _import_model_providers_cfg_path() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - MODEL_PROVIDERS_CFG_PATH_CONFIG = load_mod(model_config_load.get("module"), "MODEL_PROVIDERS_CFG_PATH_CONFIG") + 
config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") - return MODEL_PROVIDERS_CFG_PATH_CONFIG + return config_model_workspace.get_config().MODEL_PROVIDERS_CFG_PATH_CONFIG def _import_model_providers_cfg_host() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - MODEL_PROVIDERS_CFG_HOST = load_mod(model_config_load.get("module"), "MODEL_PROVIDERS_CFG_HOST") + config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") - return MODEL_PROVIDERS_CFG_HOST + return config_model_workspace.get_config().MODEL_PROVIDERS_CFG_HOST def _import_model_providers_cfg_port() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - MODEL_PROVIDERS_CFG_PORT = load_mod(model_config_load.get("module"), "MODEL_PROVIDERS_CFG_PORT") + config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") - return MODEL_PROVIDERS_CFG_PORT + return config_model_workspace.get_config().MODEL_PROVIDERS_CFG_PORT def _import_tool_config() -> Any: model_config_load = CONFIG_IMPORTS.get("_model_config.py") load_mod = model_config_load.get("load_mod") - TOOL_CONFIG = load_mod(model_config_load.get("module"), "TOOL_CONFIG") + config_model_workspace = load_mod(model_config_load.get("module"), "config_model_workspace") - return TOOL_CONFIG + return config_model_workspace.get_config().TOOL_CONFIG def _import_prompt_templates() -> Any: @@ -524,6 +556,14 @@ def __getattr__(name: str) -> Any: return _import_ConfigBasicWorkSpace() elif name == "config_basic_workspace": return _import_config_basic_workspace() + elif name == "ConfigModel": + return _import_ConfigModel() + elif name == "ConfigModelFactory": + return _import_ConfigModelFactory() + elif name == "ConfigModelWorkSpace": + return _import_ConfigModelWorkSpace() + elif name == "config_model_workspace": + return _import_config_model_workspace() elif name == "log_verbose": return _import_log_verbose() elif name == "CHATCHAT_ROOT": @@ -624,7 +664,6 @@ VERSION = "v0.3.0-preview" __all__ = [ "VERSION", - "config_basic_workspace", "log_verbose", "CHATCHAT_ROOT", "DATA_PATH", @@ -677,4 +716,12 @@ __all__ = [ "ConfigBasicFactory", "ConfigBasicWorkSpace", + "config_basic_workspace", + + "ConfigModel", + "ConfigModelFactory", + "ConfigModelWorkSpace", + + "config_model_workspace", + ] diff --git a/libs/chatchat-server/chatchat/configs/_basic_config.py b/libs/chatchat-server/chatchat/configs/_basic_config.py index f6c587c1..f53c20db 100644 --- a/libs/chatchat-server/chatchat/configs/_basic_config.py +++ b/libs/chatchat-server/chatchat/configs/_basic_config.py @@ -6,7 +6,6 @@ import sys import logging from typing import Any, Optional -from chatchat.configs._core_config import CF sys.path.append(str(Path(__file__).parent)) import _core_config as core_config @@ -128,6 +127,9 @@ class ConfigBasicWorkSpace(core_config.ConfigWorkSpace[ConfigBasicFactory, Confi """ config_factory_cls = ConfigBasicFactory + def __init__(self): + super().__init__() + def _build_config_factory(self, config_json: Any) -> ConfigBasicFactory: _config_factory = self.config_factory_cls() @@ -145,9 +147,6 @@ class ConfigBasicWorkSpace(core_config.ConfigWorkSpace[ConfigBasicFactory, Confi def get_type(cls) -> str: return ConfigBasic.class_name() - def __init__(self): - super().__init__() - def get_config(self) -> ConfigBasic: return self._config_factory.get_config() @@ -163,9 +162,5 @@ class 
ConfigBasicWorkSpace(core_config.ConfigWorkSpace[ConfigBasicFactory, Confi self._config_factory.log_format(log_format) self.store_config() - def clear(self): - logger.info("Clear workspace config.") - os.remove(self.workspace_config) - config_basic_workspace: ConfigBasicWorkSpace = ConfigBasicWorkSpace() diff --git a/libs/chatchat-server/chatchat/configs/_core_config.py b/libs/chatchat-server/chatchat/configs/_core_config.py index 5059ff3a..339c3dfa 100644 --- a/libs/chatchat-server/chatchat/configs/_core_config.py +++ b/libs/chatchat-server/chatchat/configs/_core_config.py @@ -62,15 +62,13 @@ class ConfigWorkSpace(Generic[CF, F], ABC): self.workspace_config = os.path.join(self.workspace, "workspace_config.json") # 初始化工作空间配置,转换成json格式,实现Config的实例化 - config_type_json = self._load_config() - if config_type_json is None: + _load_config = self._load_config() + if _load_config is None: self._config_factory = self._build_config_factory(config_json={}) self.store_config() else: - config_type = config_type_json.get("type", None) - if self.get_type() != config_type: - raise ValueError(f"Config type mismatch: {self.get_type()} != {config_type}") + config_type_json = self.get_config_by_type(self.get_type()) config_json = config_type_json.get("config") self._config_factory = self._build_config_factory(config_json) @@ -98,9 +96,39 @@ class ConfigWorkSpace(Generic[CF, F], ABC): except FileNotFoundError: return None + @staticmethod + def _get_store_cfg_index_by_type(store_cfg, store_cfg_type) -> int: + if store_cfg is None: + raise RuntimeError("store_cfg is None.") + for cfg in store_cfg: + if cfg.get("type") == store_cfg_type: + return store_cfg.index(cfg) + + return -1 + + def get_config_by_type(self, cfg_type) -> Dict[str, Any]: + store_cfg = self._load_config() + if store_cfg is None: + raise RuntimeError("store_cfg is None.") + + get_lambda = lambda store_cfg_type: store_cfg[self._get_store_cfg_index_by_type(store_cfg, store_cfg_type)] + return get_lambda(cfg_type) + def store_config(self): logger.info("Store workspace config.") + _load_config = self._load_config() with open(self.workspace_config, "w") as f: config_json = self.get_config().to_dict() + + if _load_config is None: + _load_config = [] + config_json_index = self._get_store_cfg_index_by_type( + store_cfg=_load_config, + store_cfg_type=self.get_type() + ) config_type_json = {"type": self.get_type(), "config": config_json} - f.write(json.dumps(config_type_json, indent=4, ensure_ascii=False)) + if config_json_index == -1: + _load_config.append(config_type_json) + else: + _load_config[config_json_index] = config_type_json + f.write(json.dumps(_load_config, indent=4, ensure_ascii=False)) diff --git a/libs/chatchat-server/chatchat/configs/_model_config.py b/libs/chatchat-server/chatchat/configs/_model_config.py index afb0e8ea..c7bbdc54 100644 --- a/libs/chatchat-server/chatchat/configs/_model_config.py +++ b/libs/chatchat-server/chatchat/configs/_model_config.py @@ -1,260 +1,450 @@ import os +import logging +import sys +from pathlib import Path +from typing import Any, Optional, List, Dict -# 默认选用的 LLM 名称 -DEFAULT_LLM_MODEL = "chatglm3-6b" +from dataclasses import dataclass -# 默认选用的 Embedding 名称 -DEFAULT_EMBEDDING_MODEL = "bge-large-zh-v1.5" +sys.path.append(str(Path(__file__).parent)) +import _core_config as core_config + +logger = logging.getLogger() -# AgentLM模型的名称 (可以不指定,指定之后就锁定进入Agent之后的Chain的模型,不指定就是LLM_MODELS[0]) -Agent_MODEL = None +class ConfigModel(core_config.Config): + DEFAULT_LLM_MODEL: Optional[str] = None + """默认选用的 LLM 名称""" + 
DEFAULT_EMBEDDING_MODEL: Optional[str] = None + """默认选用的 Embedding 名称""" + Agent_MODEL: Optional[str] = None + """AgentLM模型的名称 (可以不指定,指定之后就锁定进入Agent之后的Chain的模型,不指定就是LLM_MODELS[0])""" + HISTORY_LEN: Optional[int] = None + """历史对话轮数""" + MAX_TOKENS: Optional[int] = None + """大模型最长支持的长度,如果不填写,则使用模型默认的最大长度,如果填写,则为用户设定的最大长度""" + TEMPERATURE: Optional[float] = None + """LLM通用对话参数""" + SUPPORT_AGENT_MODELS: Optional[List[str]] = None + """支持的Agent模型""" + LLM_MODEL_CONFIG: Optional[Dict[str, Dict[str, Any]]] = None + """LLM模型配置,包括了不同模态初始化参数""" + MODEL_PLATFORMS: Optional[List[Dict[str, Any]]] = None + """模型平台配置""" + MODEL_PROVIDERS_CFG_PATH_CONFIG: Optional[str] = None + """模型平台配置文件路径""" + MODEL_PROVIDERS_CFG_HOST: Optional[str] = None + """模型平台配置服务host""" + MODEL_PROVIDERS_CFG_PORT: Optional[int] = None + """模型平台配置服务port""" + TOOL_CONFIG: Optional[Dict[str, Any]] = None + """工具配置项""" -# 历史对话轮数 -HISTORY_LEN = 3 + @classmethod + def class_name(cls) -> str: + return cls.__name__ -# 大模型最长支持的长度,如果不填写,则使用模型默认的最大长度,如果填写,则为用户设定的最大长度 -MAX_TOKENS = None - -# LLM通用对话参数 -TEMPERATURE = 0.7 -# TOP_P = 0.95 # ChatOpenAI暂不支持该参数 - -SUPPORT_AGENT_MODELS = [ - "chatglm3-6b", - "openai-api", - "Qwen-14B-Chat", - "Qwen-7B-Chat", - "qwen-turbo", -] + def __str__(self): + return self.to_json() -LLM_MODEL_CONFIG = { - # 意图识别不需要输出,模型后台知道就行 - "preprocess_model": { - DEFAULT_LLM_MODEL: { - "temperature": 0.05, - "max_tokens": 4096, - "history_len": 100, - "prompt_name": "default", - "callbacks": False - }, - }, - "llm_model": { - DEFAULT_LLM_MODEL: { - "temperature": 0.9, - "max_tokens": 4096, - "history_len": 10, - "prompt_name": "default", - "callbacks": True - }, - }, - "action_model": { - DEFAULT_LLM_MODEL: { - "temperature": 0.01, - "max_tokens": 4096, - "prompt_name": "ChatGLM3", - "callbacks": True - }, - }, - "postprocess_model": { - DEFAULT_LLM_MODEL: { - "temperature": 0.01, - "max_tokens": 4096, - "prompt_name": "default", - "callbacks": True - } - }, - "image_model": { - "sd-turbo": { - "size": "256*256", - } - } -} +@dataclass +class ConfigModelFactory(core_config.ConfigFactory[ConfigModel]): + """ConfigModel工厂类""" -# 可以通过 model_providers 提供转换不同平台的接口为openai endpoint的能力,启动后下面变量会自动增加相应的平台 -# ### 如果您已经有了一个openai endpoint的能力的地址,可以在这里直接配置 -# - platform_name 可以任意填写,不要重复即可 -# - platform_type 以后可能根据平台类型做一些功能区分,与platform_name一致即可 -# - 将框架部署的模型填写到对应列表即可。不同框架可以加载同名模型,项目会自动做负载均衡。 + def __init__(self): + # 默认选用的 LLM 名称 + self.DEFAULT_LLM_MODEL = "chatglm3-6b" + # 默认选用的 Embedding 名称 + self.DEFAULT_EMBEDDING_MODEL = "bge-large-zh-v1.5" -# 创建一个全局的共享字典 -MODEL_PLATFORMS = [ + # AgentLM模型的名称 (可以不指定,指定之后就锁定进入Agent之后的Chain的模型,不指定就是LLM_MODELS[0]) + self.Agent_MODEL = None - { - "platform_name": "oneapi", - "platform_type": "oneapi", - "api_base_url": "http://127.0.0.1:3000/v1", - "api_key": "sk-", - "api_concurrencies": 5, - "llm_models": [ - # 智谱 API - "chatglm_pro", - "chatglm_turbo", - "chatglm_std", - "chatglm_lite", - # 千问 API + # 历史对话轮数 + self.HISTORY_LEN = 3 + + # 大模型最长支持的长度,如果不填写,则使用模型默认的最大长度,如果填写,则为用户设定的最大长度 + self.MAX_TOKENS = None + + # LLM通用对话参数 + self.TEMPERATURE = 0.7 + # TOP_P = 0.95 # ChatOpenAI暂不支持该参数 + + self.SUPPORT_AGENT_MODELS = [ + "chatglm3-6b", + "openai-api", + "Qwen-14B-Chat", + "Qwen-7B-Chat", "qwen-turbo", - "qwen-plus", - "qwen-max", - "qwen-max-longcontext", - # 千帆 API - "ERNIE-Bot", - "ERNIE-Bot-turbo", - "ERNIE-Bot-4", - # 星火 API - "SparkDesk", - ], - "embed_models": [ - # 千问 API - "text-embedding-v1", - # 千帆 API - "Embedding-V1", - ], - "image_models": [], - "reranking_models": [], - 
"speech2text_models": [], - "tts_models": [], - }, + ] - { - "platform_name": "xinference", - "platform_type": "xinference", - "api_base_url": "http://127.0.0.1:9997/v1", - "api_key": "EMPTY", - "api_concurrencies": 5, - "llm_models": [ - "glm-4", - "qwen2-instruct", - "qwen1.5-chat", - ], - "embed_models": [ - "bge-large-zh-v1.5", - ], - "image_models": [], - "reranking_models": [], - "speech2text_models": [], - "tts_models": [], - }, + self.MODEL_PROVIDERS_CFG_PATH_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), + "model_providers.yaml") + self.MODEL_PROVIDERS_CFG_HOST = "127.0.0.1" -] + self.MODEL_PROVIDERS_CFG_PORT = 20000 -MODEL_PROVIDERS_CFG_PATH_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model_providers.yaml") -MODEL_PROVIDERS_CFG_HOST = "127.0.0.1" + self._init_llm_work_config() -MODEL_PROVIDERS_CFG_PORT = 20000 -# 工具配置项 -TOOL_CONFIG = { - "search_local_knowledgebase": { - "use": False, - "top_k": 3, - "score_threshold": 1.0, - "conclude_prompt": { - "with_result": - '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题",' - '不允许在答案中添加编造成分,答案请使用中文。 \n' - '<已知信息>{{ context }}\n' - '<问题>{{ question }}\n', - "without_result": - '请你根据我的提问回答我的问题:\n' - '{{ question }}\n' - '请注意,你必须在回答结束后强调,你的回答是根据你的经验回答而不是参考资料回答的。\n', - } - }, - "search_internet": { - "use": False, - "search_engine_name": "bing", - "search_engine_config": - { - "bing": { - "result_len": 3, - "bing_search_url": "https://api.bing.microsoft.com/v7.0/search", - "bing_key": "", + def _init_llm_work_config(self): + """初始化知识库runtime的一些配置""" + + self.LLM_MODEL_CONFIG = { + # 意图识别不需要输出,模型后台知道就行 + "preprocess_model": { + self.DEFAULT_LLM_MODEL: { + "temperature": 0.05, + "max_tokens": 4096, + "history_len": 100, + "prompt_name": "default", + "callbacks": False }, - "metaphor": { - "result_len": 3, - "metaphor_api_key": "", - "split_result": False, - "chunk_size": 500, - "chunk_overlap": 0, + }, + "llm_model": { + self.DEFAULT_LLM_MODEL: { + "temperature": 0.9, + "max_tokens": 4096, + "history_len": 10, + "prompt_name": "default", + "callbacks": True }, - "duckduckgo": { - "result_len": 3 + }, + "action_model": { + self.DEFAULT_LLM_MODEL: { + "temperature": 0.01, + "max_tokens": 4096, + "prompt_name": "ChatGLM3", + "callbacks": True + }, + }, + "postprocess_model": { + self.DEFAULT_LLM_MODEL: { + "temperature": 0.01, + "max_tokens": 4096, + "prompt_name": "default", + "callbacks": True } }, - "top_k": 10, - "verbose": "Origin", - "conclude_prompt": - "<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 " - "\n<已知信息>{{ context }}\n" - "<问题>\n" - "{{ question }}\n" - "\n" - }, - "arxiv": { - "use": False, - }, - "shell": { - "use": False, - }, - "weather_check": { - "use": False, - "api_key": "S8vrB4U_-c5mvAMiK", - }, - "search_youtube": { - "use": False, - }, - "wolfram": { - "use": False, - "appid": "", - }, - "calculate": { - "use": False, - }, - "vqa_processor": { - "use": False, - "model_path": "your model path", - "tokenizer_path": "your tokenizer path", - "device": "cuda:1" - }, - "aqa_processor": { - "use": False, - "model_path": "your model path", - "tokenizer_path": "yout tokenizer path", - "device": "cuda:2" - }, - "text2images": { - "use": False, - }, - # text2sql使用建议 - # 1、因大模型生成的sql可能与预期有偏差,请务必在测试环境中进行充分测试、评估; - # 2、生产环境中,对于查询操作,由于不确定查询效率,推荐数据库采用主从数据库架构,让text2sql连接从数据库,防止可能的慢查询影响主业务; - # 3、对于写操作应保持谨慎,如不需要写操作,设置read_only为True,最好再从数据库层面收回数据库用户的写权限,防止用户通过自然语言对数据库进行修改操作; - # 4、text2sql与大模型在意图理解、sql转换等方面的能力有关,可切换不同大模型进行测试; - # 
5、数据库表名、字段名应与其实际作用保持一致、容易理解,且应对数据库表名、字段进行详细的备注说明,帮助大模型更好理解数据库结构; - # 6、若现有数据库表名难于让大模型理解,可配置下面table_comments字段,补充说明某些表的作用。 - "text2sql": { - "use": False, - # SQLAlchemy连接字符串,支持的数据库有: - # crate、duckdb、googlesql、mssql、mysql、mariadb、oracle、postgresql、sqlite、clickhouse、prestodb - # 不同的数据库请查询SQLAlchemy,修改sqlalchemy_connect_str,配置对应的数据库连接,如sqlite为sqlite:///数据库文件路径,下面示例为mysql - # 如提示缺少对应数据库的驱动,请自行通过poetry安装 - "sqlalchemy_connect_str": "mysql+pymysql://用户名:密码@主机地址/数据库名称e", - # 务必评估是否需要开启read_only,开启后会对sql语句进行检查,请确认text2sql.py中的intercept_sql拦截器是否满足你使用的数据库只读要求 - # 优先推荐从数据库层面对用户权限进行限制 - "read_only": False, - #限定返回的行数 - "top_k":50, - #是否返回中间步骤 - "return_intermediate_steps": True, - #如果想指定特定表,请填写表名称,如["sys_user","sys_dept"],不填写走智能判断应该使用哪些表 - "table_names":[], - #对表名进行额外说明,辅助大模型更好的判断应该使用哪些表,尤其是SQLDatabaseSequentialChain模式下,是根据表名做的预测,很容易误判。 - "table_comments":{ - # 如果出现大模型选错表的情况,可尝试根据实际情况填写表名和说明 - # "tableA":"这是一个用户表,存储了用户的基本信息", - # "tanleB":"角色表", + "image_model": { + "sd-turbo": { + "size": "256*256", + } + } } - }, -} + + # 可以通过 model_providers 提供转换不同平台的接口为openai endpoint的能力,启动后下面变量会自动增加相应的平台 + # ### 如果您已经有了一个openai endpoint的能力的地址,可以在这里直接配置 + # - platform_name 可以任意填写,不要重复即可 + # - platform_type 以后可能根据平台类型做一些功能区分,与platform_name一致即可 + # - 将框架部署的模型填写到对应列表即可。不同框架可以加载同名模型,项目会自动做负载均衡。 + + # 创建一个全局的共享字典 + self.MODEL_PLATFORMS = [ + + { + "platform_name": "oneapi", + "platform_type": "oneapi", + "api_base_url": "http://127.0.0.1:3000/v1", + "api_key": "sk-", + "api_concurrencies": 5, + "llm_models": [ + # 智谱 API + "chatglm_pro", + "chatglm_turbo", + "chatglm_std", + "chatglm_lite", + # 千问 API + "qwen-turbo", + "qwen-plus", + "qwen-max", + "qwen-max-longcontext", + # 千帆 API + "ERNIE-Bot", + "ERNIE-Bot-turbo", + "ERNIE-Bot-4", + # 星火 API + "SparkDesk", + ], + "embed_models": [ + # 千问 API + "text-embedding-v1", + # 千帆 API + "Embedding-V1", + ], + "image_models": [], + "reranking_models": [], + "speech2text_models": [], + "tts_models": [], + }, + + { + "platform_name": "xinference", + "platform_type": "xinference", + "api_base_url": "http://127.0.0.1:9997/v1", + "api_key": "EMPTY", + "api_concurrencies": 5, + "llm_models": [ + "glm-4", + "qwen2-instruct", + "qwen1.5-chat", + ], + "embed_models": [ + "bge-large-zh-v1.5", + ], + "image_models": [], + "reranking_models": [], + "speech2text_models": [], + "tts_models": [], + }, + + ] + # 工具配置项 + self.TOOL_CONFIG = { + "search_local_knowledgebase": { + "use": False, + "top_k": 3, + "score_threshold": 1.0, + "conclude_prompt": { + "with_result": + '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题",' + '不允许在答案中添加编造成分,答案请使用中文。 \n' + '<已知信息>{{ context }}\n' + '<问题>{{ question }}\n', + "without_result": + '请你根据我的提问回答我的问题:\n' + '{{ question }}\n' + '请注意,你必须在回答结束后强调,你的回答是根据你的经验回答而不是参考资料回答的。\n', + } + }, + "search_internet": { + "use": False, + "search_engine_name": "bing", + "search_engine_config": + { + "bing": { + "result_len": 3, + "bing_search_url": "https://api.bing.microsoft.com/v7.0/search", + "bing_key": "", + }, + "metaphor": { + "result_len": 3, + "metaphor_api_key": "", + "split_result": False, + "chunk_size": 500, + "chunk_overlap": 0, + }, + "duckduckgo": { + "result_len": 3 + } + }, + "top_k": 10, + "verbose": "Origin", + "conclude_prompt": + "<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 " + "\n<已知信息>{{ context }}\n" + "<问题>\n" + "{{ question }}\n" + "\n" + }, + "arxiv": { + "use": False, + }, + "shell": { + "use": False, + }, + "weather_check": { + "use": False, + "api_key": "S8vrB4U_-c5mvAMiK", + }, + "search_youtube": { + "use": 
False, + }, + "wolfram": { + "use": False, + "appid": "", + }, + "calculate": { + "use": False, + }, + "vqa_processor": { + "use": False, + "model_path": "your model path", + "tokenizer_path": "your tokenizer path", + "device": "cuda:1" + }, + "aqa_processor": { + "use": False, + "model_path": "your model path", + "tokenizer_path": "yout tokenizer path", + "device": "cuda:2" + }, + "text2images": { + "use": False, + }, + # text2sql使用建议 + # 1、因大模型生成的sql可能与预期有偏差,请务必在测试环境中进行充分测试、评估; + # 2、生产环境中,对于查询操作,由于不确定查询效率,推荐数据库采用主从数据库架构,让text2sql连接从数据库,防止可能的慢查询影响主业务; + # 3、对于写操作应保持谨慎,如不需要写操作,设置read_only为True,最好再从数据库层面收回数据库用户的写权限,防止用户通过自然语言对数据库进行修改操作; + # 4、text2sql与大模型在意图理解、sql转换等方面的能力有关,可切换不同大模型进行测试; + # 5、数据库表名、字段名应与其实际作用保持一致、容易理解,且应对数据库表名、字段进行详细的备注说明,帮助大模型更好理解数据库结构; + # 6、若现有数据库表名难于让大模型理解,可配置下面table_comments字段,补充说明某些表的作用。 + "text2sql": { + "use": False, + # SQLAlchemy连接字符串,支持的数据库有: + # crate、duckdb、googlesql、mssql、mysql、mariadb、oracle、postgresql、sqlite、clickhouse、prestodb + # 不同的数据库请查询SQLAlchemy,修改sqlalchemy_connect_str,配置对应的数据库连接,如sqlite为sqlite:///数据库文件路径,下面示例为mysql + # 如提示缺少对应数据库的驱动,请自行通过poetry安装 + "sqlalchemy_connect_str": "mysql+pymysql://用户名:密码@主机地址/数据库名称e", + # 务必评估是否需要开启read_only,开启后会对sql语句进行检查,请确认text2sql.py中的intercept_sql拦截器是否满足你使用的数据库只读要求 + # 优先推荐从数据库层面对用户权限进行限制 + "read_only": False, + # 限定返回的行数 + "top_k": 50, + # 是否返回中间步骤 + "return_intermediate_steps": True, + # 如果想指定特定表,请填写表名称,如["sys_user","sys_dept"],不填写走智能判断应该使用哪些表 + "table_names": [], + # 对表名进行额外说明,辅助大模型更好的判断应该使用哪些表,尤其是SQLDatabaseSequentialChain模式下,是根据表名做的预测,很容易误判。 + "table_comments": { + # 如果出现大模型选错表的情况,可尝试根据实际情况填写表名和说明 + # "tableA":"这是一个用户表,存储了用户的基本信息", + # "tanleB":"角色表", + } + }, + } + + def default_llm_model(self, llm_model: str): + self.DEFAULT_LLM_MODEL = llm_model + + def default_embedding_model(self, embedding_model: str): + self.DEFAULT_EMBEDDING_MODEL = embedding_model + + def agent_model(self, agent_model: str): + self.Agent_MODEL = agent_model + + def history_len(self, history_len: int): + self.HISTORY_LEN = history_len + + def max_tokens(self, max_tokens: int): + self.MAX_TOKENS = max_tokens + + def temperature(self, temperature: float): + self.TEMPERATURE = temperature + + def support_agent_models(self, support_agent_models: List[str]): + self.SUPPORT_AGENT_MODELS = support_agent_models + + def model_providers_cfg_path_config(self, model_providers_cfg_path_config: str): + self.MODEL_PROVIDERS_CFG_PATH_CONFIG = model_providers_cfg_path_config + + def model_providers_cfg_host(self, model_providers_cfg_host: str): + self.MODEL_PROVIDERS_CFG_HOST = model_providers_cfg_host + + def model_providers_cfg_port(self, model_providers_cfg_port: int): + self.MODEL_PROVIDERS_CFG_PORT = model_providers_cfg_port + + def get_config(self) -> ConfigModel: + config = ConfigModel() + config.DEFAULT_LLM_MODEL = self.DEFAULT_LLM_MODEL + config.DEFAULT_EMBEDDING_MODEL = self.DEFAULT_EMBEDDING_MODEL + config.Agent_MODEL = self.Agent_MODEL + config.HISTORY_LEN = self.HISTORY_LEN + config.MAX_TOKENS = self.MAX_TOKENS + config.TEMPERATURE = self.TEMPERATURE + config.SUPPORT_AGENT_MODELS = self.SUPPORT_AGENT_MODELS + config.LLM_MODEL_CONFIG = self.LLM_MODEL_CONFIG + config.MODEL_PLATFORMS = self.MODEL_PLATFORMS + config.MODEL_PROVIDERS_CFG_PATH_CONFIG = self.MODEL_PROVIDERS_CFG_PATH_CONFIG + config.MODEL_PROVIDERS_CFG_HOST = self.MODEL_PROVIDERS_CFG_HOST + config.MODEL_PROVIDERS_CFG_PORT = self.MODEL_PROVIDERS_CFG_PORT + config.TOOL_CONFIG = self.TOOL_CONFIG + + return config + + +class 
ConfigModelWorkSpace(core_config.ConfigWorkSpace[ConfigModelFactory, ConfigModel]): + """ + 工作空间的配置预设, 提供ConfigModel建造方法产生实例。 + """ + config_factory_cls = ConfigModelFactory + + def __init__(self): + super().__init__() + + def _build_config_factory(self, config_json: Any) -> ConfigModelFactory: + + _config_factory = self.config_factory_cls() + if config_json.get("DEFAULT_LLM_MODEL"): + _config_factory.default_llm_model(config_json.get("DEFAULT_LLM_MODEL")) + if config_json.get("DEFAULT_EMBEDDING_MODEL"): + _config_factory.default_embedding_model(config_json.get("DEFAULT_EMBEDDING_MODEL")) + if config_json.get("Agent_MODEL"): + _config_factory.agent_model(config_json.get("Agent_MODEL")) + if config_json.get("HISTORY_LEN"): + _config_factory.history_len(config_json.get("HISTORY_LEN")) + if config_json.get("MAX_TOKENS"): + _config_factory.max_tokens(config_json.get("MAX_TOKENS")) + if config_json.get("TEMPERATURE"): + _config_factory.temperature(config_json.get("TEMPERATURE")) + if config_json.get("SUPPORT_AGENT_MODELS"): + _config_factory.support_agent_models(config_json.get("SUPPORT_AGENT_MODELS")) + if config_json.get("MODEL_PROVIDERS_CFG_PATH_CONFIG"): + _config_factory.model_providers_cfg_path_config(config_json.get("MODEL_PROVIDERS_CFG_PATH_CONFIG")) + if config_json.get("MODEL_PROVIDERS_CFG_HOST"): + _config_factory.model_providers_cfg_host(config_json.get("MODEL_PROVIDERS_CFG_HOST")) + if config_json.get("MODEL_PROVIDERS_CFG_PORT"): + _config_factory.model_providers_cfg_port(config_json.get("MODEL_PROVIDERS_CFG_PORT")) + + return _config_factory + + @classmethod + def get_type(cls) -> str: + return ConfigModel.class_name() + + def get_config(self) -> ConfigModel: + return self._config_factory.get_config() + + def set_default_llm_model(self, llm_model: str): + self._config_factory.default_llm_model(llm_model) + self.store_config() + + def set_default_embedding_model(self, embedding_model: str): + self._config_factory.default_embedding_model(embedding_model) + self.store_config() + + def set_agent_model(self, agent_model: str): + self._config_factory.agent_model(agent_model) + self.store_config() + + def set_history_len(self, history_len: int): + self._config_factory.history_len(history_len) + self.store_config() + + def set_max_tokens(self, max_tokens: int): + self._config_factory.max_tokens(max_tokens) + self.store_config() + + def set_temperature(self, temperature: float): + self._config_factory.temperature(temperature) + self.store_config() + + def set_support_agent_models(self, support_agent_models: List[str]): + self._config_factory.support_agent_models(support_agent_models) + self.store_config() + + def set_model_providers_cfg_path_config(self, model_providers_cfg_path_config: str): + self._config_factory.model_providers_cfg_path_config(model_providers_cfg_path_config) + self.store_config() + + def set_model_providers_cfg_host(self, model_providers_cfg_host: str): + self._config_factory.model_providers_cfg_host(model_providers_cfg_host) + self.store_config() + + def set_model_providers_cfg_port(self, model_providers_cfg_port: int): + self._config_factory.model_providers_cfg_port(model_providers_cfg_port) + self.store_config() + + +config_model_workspace: ConfigModelWorkSpace = ConfigModelWorkSpace() \ No newline at end of file diff --git a/libs/chatchat-server/pyproject.toml b/libs/chatchat-server/pyproject.toml index 00e986d7..e05f00c6 100644 --- a/libs/chatchat-server/pyproject.toml +++ b/libs/chatchat-server/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = 
"langchain-chatchat" -version = "0.3.0.20240610.1" +version = "0.3.0.20240611" description = "" authors = ["chatchat"] readme = "README.md" diff --git a/libs/chatchat-server/tests/unit_tests/config/test_config.py b/libs/chatchat-server/tests/unit_tests/config/test_config.py index c748b827..d540e5f6 100644 --- a/libs/chatchat-server/tests/unit_tests/config/test_config.py +++ b/libs/chatchat-server/tests/unit_tests/config/test_config.py @@ -1,6 +1,12 @@ from pathlib import Path -from chatchat.configs import ConfigBasicFactory, ConfigBasic, ConfigBasicWorkSpace +from chatchat.configs import ( + ConfigBasicFactory, + ConfigBasic, + ConfigBasicWorkSpace, + ConfigModelWorkSpace, + ConfigModel +) import os @@ -36,3 +42,56 @@ def test_workspace_default(): assert LOG_FORMAT is not None assert LOG_PATH is not None assert MEDIA_PATH is not None + + +def test_config_model_workspace(): + + config_model_workspace: ConfigModelWorkSpace = ConfigModelWorkSpace() + + assert config_model_workspace.get_config() is not None + + config_model_workspace.set_default_llm_model(llm_model="glm4") + config_model_workspace.set_default_embedding_model(embedding_model="text1") + config_model_workspace.set_agent_model(agent_model="agent") + config_model_workspace.set_history_len(history_len=1) + config_model_workspace.set_max_tokens(max_tokens=1000) + config_model_workspace.set_temperature(temperature=0.1) + config_model_workspace.set_support_agent_models(support_agent_models=["glm4"]) + config_model_workspace.set_model_providers_cfg_path_config(model_providers_cfg_path_config="model_providers.yaml") + config_model_workspace.set_model_providers_cfg_host(model_providers_cfg_host="127.0.0.1") + config_model_workspace.set_model_providers_cfg_port(model_providers_cfg_port=8000) + + config: ConfigModel = config_model_workspace.get_config() + + assert config.DEFAULT_LLM_MODEL == "glm4" + assert config.DEFAULT_EMBEDDING_MODEL == "text1" + assert config.Agent_MODEL == "agent" + assert config.HISTORY_LEN == 1 + assert config.MAX_TOKENS == 1000 + assert config.TEMPERATURE == 0.1 + assert config.SUPPORT_AGENT_MODELS == ["glm4"] + assert config.MODEL_PROVIDERS_CFG_PATH_CONFIG == "model_providers.yaml" + assert config.MODEL_PROVIDERS_CFG_HOST == "127.0.0.1" + assert config.MODEL_PROVIDERS_CFG_PORT == 8000 + config_model_workspace.clear() + + +def test_model_config(): + from chatchat.configs import ( + DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL, Agent_MODEL, HISTORY_LEN, MAX_TOKENS, TEMPERATURE, + SUPPORT_AGENT_MODELS, MODEL_PROVIDERS_CFG_PATH_CONFIG, MODEL_PROVIDERS_CFG_HOST, MODEL_PROVIDERS_CFG_PORT, + TOOL_CONFIG, MODEL_PLATFORMS, LLM_MODEL_CONFIG + ) + assert DEFAULT_LLM_MODEL is not None + assert DEFAULT_EMBEDDING_MODEL is not None + assert Agent_MODEL is None + assert HISTORY_LEN is not None + assert MAX_TOKENS is None + assert TEMPERATURE is not None + assert SUPPORT_AGENT_MODELS is not None + assert MODEL_PROVIDERS_CFG_PATH_CONFIG is not None + assert MODEL_PROVIDERS_CFG_HOST is not None + assert MODEL_PROVIDERS_CFG_PORT is not None + assert TOOL_CONFIG is not None + assert MODEL_PLATFORMS is not None + assert LLM_MODEL_CONFIG is not None