diff --git a/libs/chatchat-server/chatchat/config_work_space.py b/libs/chatchat-server/chatchat/config_work_space.py index c8689816..a7ea0c9b 100644 --- a/libs/chatchat-server/chatchat/config_work_space.py +++ b/libs/chatchat-server/chatchat/config_work_space.py @@ -85,5 +85,37 @@ def model(**kwargs): print(config_model_workspace.get_config()) +@main.command("server", help="服务配置") +@click.option("--httpx_default_timeout", type=int, help="httpx默认超时时间") +@click.option("--open_cross_domain", type=click.Choice(["true", "false"]), help="是否开启跨域") +@click.option("--default_bind_host", help="默认绑定host") +@click.option("--webui_server_port", type=int, help="webui服务端口") +@click.option("--api_server_port", type=int, help="api服务端口") +@click.option("--clear", is_flag=True, help="清除配置") +@click.option("--show", is_flag=True, help="显示配置") +def server(**kwargs): + + if kwargs["httpx_default_timeout"]: + config_basic_workspace.set_httpx_default_timeout(httpx_default_timeout=kwargs["httpx_default_timeout"]) + if kwargs["open_cross_domain"]: + if kwargs["open_cross_domain"].lower() == "true": + config_basic_workspace.set_open_cross_domain(True) + else: + config_basic_workspace.set_open_cross_domain(False) + if kwargs["default_bind_host"]: + config_basic_workspace.set_default_bind_host(default_bind_host=kwargs["default_bind_host"]) + + if kwargs["webui_server_port"]: + config_basic_workspace.set_webui_server_port(webui_server_port=kwargs["webui_server_port"]) + + if kwargs["api_server_port"]: + config_basic_workspace.set_api_server_port(api_server_port=kwargs["api_server_port"]) + + if kwargs["clear"]: + config_model_workspace.clear() + if kwargs["show"]: + print(config_model_workspace.get_config()) + + if __name__ == "__main__": main() diff --git a/libs/chatchat-server/chatchat/configs/__init__.py b/libs/chatchat-server/chatchat/configs/__init__.py index c04e5e33..c20d2e39 100644 --- a/libs/chatchat-server/chatchat/configs/__init__.py +++ b/libs/chatchat-server/chatchat/configs/__init__.py @@ -507,44 +507,76 @@ def _import_prompt_templates() -> Any: return PROMPT_TEMPLATES +def _import_ConfigServer() -> Any: + basic_config_load = CONFIG_IMPORTS.get("_server_config.py") + load_mod = basic_config_load.get("load_mod") + ConfigServer = load_mod(basic_config_load.get("module"), "ConfigServer") + + return ConfigServer + + +def _import_ConfigServerFactory() -> Any: + basic_config_load = CONFIG_IMPORTS.get("_server_config.py") + load_mod = basic_config_load.get("load_mod") + ConfigServerFactory = load_mod(basic_config_load.get("module"), "ConfigServerFactory") + + return ConfigServerFactory + + +def _import_ConfigServerWorkSpace() -> Any: + basic_config_load = CONFIG_IMPORTS.get("_server_config.py") + load_mod = basic_config_load.get("load_mod") + ConfigServerWorkSpace = load_mod(basic_config_load.get("module"), "ConfigServerWorkSpace") + + return ConfigServerWorkSpace + + +def _import_config_server_workspace() -> Any: + server_config_load = CONFIG_IMPORTS.get("_server_config.py") + load_mod = server_config_load.get("load_mod") + config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace") + + return config_server_workspace + + def _import_httpx_default_timeout() -> Any: server_config_load = CONFIG_IMPORTS.get("_server_config.py") load_mod = server_config_load.get("load_mod") - HTTPX_DEFAULT_TIMEOUT = load_mod(server_config_load.get("module"), "HTTPX_DEFAULT_TIMEOUT") + config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace") - return HTTPX_DEFAULT_TIMEOUT + return config_server_workspace.get_config().HTTPX_DEFAULT_TIMEOUT def _import_open_cross_domain() -> Any: server_config_load = CONFIG_IMPORTS.get("_server_config.py") load_mod = server_config_load.get("load_mod") - OPEN_CROSS_DOMAIN = load_mod(server_config_load.get("module"), "OPEN_CROSS_DOMAIN") + config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace") - return OPEN_CROSS_DOMAIN + return config_server_workspace.get_config().OPEN_CROSS_DOMAIN def _import_default_bind_host() -> Any: server_config_load = CONFIG_IMPORTS.get("_server_config.py") load_mod = server_config_load.get("load_mod") - DEFAULT_BIND_HOST = load_mod(server_config_load.get("module"), "DEFAULT_BIND_HOST") + config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace") - return DEFAULT_BIND_HOST + return config_server_workspace.get_config().DEFAULT_BIND_HOST def _import_webui_server() -> Any: server_config_load = CONFIG_IMPORTS.get("_server_config.py") load_mod = server_config_load.get("load_mod") - WEBUI_SERVER = load_mod(server_config_load.get("module"), "WEBUI_SERVER") + config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace") - return WEBUI_SERVER + return config_server_workspace.get_config().WEBUI_SERVER def _import_api_server() -> Any: server_config_load = CONFIG_IMPORTS.get("_server_config.py") load_mod = server_config_load.get("load_mod") - API_SERVER = load_mod(server_config_load.get("module"), "API_SERVER") + config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace") - return API_SERVER + return config_server_workspace.get_config().API_SERVER def __getattr__(name: str) -> Any: @@ -564,6 +596,14 @@ def __getattr__(name: str) -> Any: return _import_ConfigModelWorkSpace() elif name == "config_model_workspace": return _import_config_model_workspace() + elif name == "ConfigServer": + return _import_ConfigServer() + elif name == "ConfigServerFactory": + return _import_ConfigServerFactory() + elif name == "ConfigServerWorkSpace": + return _import_ConfigServerWorkSpace() + elif name == "config_server_workspace": + return _import_config_server_workspace() elif name == "log_verbose": return _import_log_verbose() elif name == "CHATCHAT_ROOT": @@ -724,4 +764,10 @@ __all__ = [ "config_model_workspace", + "ConfigServer", + "ConfigServerFactory", + "ConfigServerWorkSpace", + + "config_server_workspace", + ] diff --git a/libs/chatchat-server/chatchat/configs/_server_config.py b/libs/chatchat-server/chatchat/configs/_server_config.py index 40485250..d07da988 100644 --- a/libs/chatchat-server/chatchat/configs/_server_config.py +++ b/libs/chatchat-server/chatchat/configs/_server_config.py @@ -1,25 +1,141 @@ +import os +import json +from dataclasses import dataclass +from pathlib import Path import sys +import logging +from typing import Any, Optional, Dict + +sys.path.append(str(Path(__file__).parent)) +import _core_config as core_config + +logger = logging.getLogger() -# httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。 -HTTPX_DEFAULT_TIMEOUT = 300.0 +class ConfigServer(core_config.Config): + HTTPX_DEFAULT_TIMEOUT: Optional[float] = None + """httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。""" + OPEN_CROSS_DOMAIN: Optional[bool] = None + """API 是否开启跨域,默认为False,如果需要开启,请设置为True""" + DEFAULT_BIND_HOST: Optional[str] = None + """各服务器默认绑定host。如改为"0.0.0.0"需要修改下方所有XX_SERVER的host""" + WEBUI_SERVER_PORT: Optional[int] = None + """webui port""" + API_SERVER_PORT: Optional[int] = None + """api port""" + WEBUI_SERVER: Optional[Dict[str, Any]] = None + """webui.py server""" + API_SERVER: Optional[Dict[str, Any]] = None + """api.py server""" -# API 是否开启跨域,默认为False,如果需要开启,请设置为True -# is open cross domain -OPEN_CROSS_DOMAIN = True + @classmethod + def class_name(cls) -> str: + return cls.__name__ -# 各服务器默认绑定host。如改为"0.0.0.0"需要修改下方所有XX_SERVER的host -DEFAULT_BIND_HOST = "127.0.0.1" if sys.platform != "win32" else "127.0.0.1" + def __str__(self): + return self.to_json() -# webui.py server -WEBUI_SERVER = { - "host": DEFAULT_BIND_HOST, - "port": 8501, -} +@dataclass +class ConfigServerFactory(core_config.ConfigFactory[ConfigServer]): + """Server 配置工厂类""" -# api.py server -API_SERVER = { - "host": DEFAULT_BIND_HOST, - "port": 7861, -} + def __init__(self): + # httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。 + self.HTTPX_DEFAULT_TIMEOUT = 300.0 + + # API 是否开启跨域,默认为False,如果需要开启,请设置为True + # is open cross domain + self.OPEN_CROSS_DOMAIN = True + + # 各服务器默认绑定host。如改为"0.0.0.0"需要修改下方所有XX_SERVER的host + self.DEFAULT_BIND_HOST = "127.0.0.1" if sys.platform != "win32" else "127.0.0.1" + self.WEBUI_SERVER_PORT = 8501 + self.API_SERVER_PORT = 7861 + # webui.py server + self.WEBUI_SERVER = { + "host": self.DEFAULT_BIND_HOST, + "port": self.WEBUI_SERVER_PORT, + } + + # api.py server + self.API_SERVER = { + "host": self.DEFAULT_BIND_HOST, + "port": self.API_SERVER_PORT, + } + + def httpx_default_timeout(self, timeout: float): + self.HTTPX_DEFAULT_TIMEOUT = timeout + + def open_cross_domain(self, open_cross_domain: bool): + self.OPEN_CROSS_DOMAIN = open_cross_domain + + def default_bind_host(self, default_bind_host: str): + self.DEFAULT_BIND_HOST = default_bind_host + + def webui_server_port(self, webui_server_port: int): + self.WEBUI_SERVER_PORT = webui_server_port + + def api_server_port(self, api_server_port: int): + self.API_SERVER_PORT = api_server_port + + def get_config(self) -> ConfigServer: + config = ConfigServer() + config.HTTPX_DEFAULT_TIMEOUT = self.HTTPX_DEFAULT_TIMEOUT + config.OPEN_CROSS_DOMAIN = self.OPEN_CROSS_DOMAIN + config.DEFAULT_BIND_HOST = self.DEFAULT_BIND_HOST + config.WEBUI_SERVER_PORT = self.WEBUI_SERVER_PORT + config.API_SERVER_PORT = self.API_SERVER_PORT + config.WEBUI_SERVER = self.WEBUI_SERVER + config.API_SERVER = self.API_SERVER + return config + + +class ConfigServerWorkSpace(core_config.ConfigWorkSpace[ConfigServerFactory, ConfigServer]): + """ + 工作空间的配置预设,提供ConfigServer建造方法产生实例。 + """ + config_factory_cls = ConfigServerFactory + + def __init__(self): + super().__init__() + + def _build_config_factory(self, config_json: Any) -> ConfigServerFactory: + _config_factory = self.config_factory_cls() + if config_json.get("HTTPX_DEFAULT_TIMEOUT") is not None: + _config_factory.httpx_default_timeout(config_json["HTTPX_DEFAULT_TIMEOUT"]) + if config_json.get("OPEN_CROSS_DOMAIN") is not None: + _config_factory.open_cross_domain(config_json["OPEN_CROSS_DOMAIN"]) + if config_json.get("DEFAULT_BIND_HOST") is not None: + _config_factory.default_bind_host(config_json["DEFAULT_BIND_HOST"]) + if config_json.get("WEBUI_SERVER_PORT") is not None: + _config_factory.webui_server_port(config_json["WEBUI_SERVER_PORT"]) + if config_json.get("API_SERVER_PORT") is not None: + _config_factory.api_server_port(config_json["API_SERVER_PORT"]) + + return _config_factory + + @classmethod + def get_type(cls) -> str: + return ConfigServer.class_name() + + def get_config(self) -> ConfigServer: + return self._config_factory.get_config() + + def set_httpx_default_timeout(self, timeout: float): + self._config_factory.httpx_default_timeout(timeout) + + def set_open_cross_domain(self, open_cross_domain: bool): + self._config_factory.open_cross_domain(open_cross_domain) + + def set_default_bind_host(self, default_bind_host: str): + self._config_factory.default_bind_host(default_bind_host) + + def set_webui_server_port(self, webui_server_port: int): + self._config_factory.webui_server_port(webui_server_port) + + def set_api_server_port(self, api_server_port: int): + self._config_factory.api_server_port(api_server_port) + + +config_server_workspace: ConfigServerWorkSpace = ConfigServerWorkSpace() \ No newline at end of file diff --git a/libs/chatchat-server/tests/unit_tests/config/test_config.py b/libs/chatchat-server/tests/unit_tests/config/test_config.py index d540e5f6..35d79f2d 100644 --- a/libs/chatchat-server/tests/unit_tests/config/test_config.py +++ b/libs/chatchat-server/tests/unit_tests/config/test_config.py @@ -5,7 +5,9 @@ from chatchat.configs import ( ConfigBasic, ConfigBasicWorkSpace, ConfigModelWorkSpace, - ConfigModel + ConfigModel, + ConfigServerWorkSpace, + ConfigServer, ) import os @@ -95,3 +97,24 @@ def test_model_config(): assert TOOL_CONFIG is not None assert MODEL_PLATFORMS is not None assert LLM_MODEL_CONFIG is not None + + +def test_config_server_workspace(): + config_server_workspace: ConfigServerWorkSpace = ConfigServerWorkSpace() + + assert config_server_workspace.get_config() is not None + + config_server_workspace.set_httpx_default_timeout(timeout=10) + config_server_workspace.set_open_cross_domain(open_cross_domain=True) + config_server_workspace.set_default_bind_host(default_bind_host="0.0.0.0") + config_server_workspace.set_webui_server_port(webui_server_port=8000) + config_server_workspace.set_api_server_port(api_server_port=8001) + + config: ConfigServer = config_server_workspace.get_config() + + assert config.HTTPX_DEFAULT_TIMEOUT == 10 + assert config.OPEN_CROSS_DOMAIN is True + assert config.DEFAULT_BIND_HOST == "0.0.0.0" + assert config.WEBUI_SERVER_PORT == 8000 + assert config.API_SERVER_PORT == 8001 + config_server_workspace.clear() \ No newline at end of file