配置中心服务信息 子命令入口 (#4166)

* 配置中心服务信息单元测试

* 配置中心服务信息 子命令入口
This commit is contained in:
glide-the 2024-06-11 16:49:16 +08:00 committed by GitHub
parent e7a5d6a528
commit 35c2f596f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 245 additions and 28 deletions

View File

@ -85,5 +85,37 @@ def model(**kwargs):
print(config_model_workspace.get_config())
@main.command("server", help="服务配置")
@click.option("--httpx_default_timeout", type=int, help="httpx默认超时时间")
@click.option("--open_cross_domain", type=click.Choice(["true", "false"]), help="是否开启跨域")
@click.option("--default_bind_host", help="默认绑定host")
@click.option("--webui_server_port", type=int, help="webui服务端口")
@click.option("--api_server_port", type=int, help="api服务端口")
@click.option("--clear", is_flag=True, help="清除配置")
@click.option("--show", is_flag=True, help="显示配置")
def server(**kwargs):
if kwargs["httpx_default_timeout"]:
config_basic_workspace.set_httpx_default_timeout(httpx_default_timeout=kwargs["httpx_default_timeout"])
if kwargs["open_cross_domain"]:
if kwargs["open_cross_domain"].lower() == "true":
config_basic_workspace.set_open_cross_domain(True)
else:
config_basic_workspace.set_open_cross_domain(False)
if kwargs["default_bind_host"]:
config_basic_workspace.set_default_bind_host(default_bind_host=kwargs["default_bind_host"])
if kwargs["webui_server_port"]:
config_basic_workspace.set_webui_server_port(webui_server_port=kwargs["webui_server_port"])
if kwargs["api_server_port"]:
config_basic_workspace.set_api_server_port(api_server_port=kwargs["api_server_port"])
if kwargs["clear"]:
config_model_workspace.clear()
if kwargs["show"]:
print(config_model_workspace.get_config())
if __name__ == "__main__":
main()

View File

@ -507,44 +507,76 @@ def _import_prompt_templates() -> Any:
return PROMPT_TEMPLATES
def _import_ConfigServer() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_server_config.py")
load_mod = basic_config_load.get("load_mod")
ConfigServer = load_mod(basic_config_load.get("module"), "ConfigServer")
return ConfigServer
def _import_ConfigServerFactory() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_server_config.py")
load_mod = basic_config_load.get("load_mod")
ConfigServerFactory = load_mod(basic_config_load.get("module"), "ConfigServerFactory")
return ConfigServerFactory
def _import_ConfigServerWorkSpace() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_server_config.py")
load_mod = basic_config_load.get("load_mod")
ConfigServerWorkSpace = load_mod(basic_config_load.get("module"), "ConfigServerWorkSpace")
return ConfigServerWorkSpace
def _import_config_server_workspace() -> Any:
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
load_mod = server_config_load.get("load_mod")
config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace")
return config_server_workspace
def _import_httpx_default_timeout() -> Any:
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
load_mod = server_config_load.get("load_mod")
HTTPX_DEFAULT_TIMEOUT = load_mod(server_config_load.get("module"), "HTTPX_DEFAULT_TIMEOUT")
config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace")
return HTTPX_DEFAULT_TIMEOUT
return config_server_workspace.get_config().HTTPX_DEFAULT_TIMEOUT
def _import_open_cross_domain() -> Any:
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
load_mod = server_config_load.get("load_mod")
OPEN_CROSS_DOMAIN = load_mod(server_config_load.get("module"), "OPEN_CROSS_DOMAIN")
config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace")
return OPEN_CROSS_DOMAIN
return config_server_workspace.get_config().OPEN_CROSS_DOMAIN
def _import_default_bind_host() -> Any:
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
load_mod = server_config_load.get("load_mod")
DEFAULT_BIND_HOST = load_mod(server_config_load.get("module"), "DEFAULT_BIND_HOST")
config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace")
return DEFAULT_BIND_HOST
return config_server_workspace.get_config().DEFAULT_BIND_HOST
def _import_webui_server() -> Any:
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
load_mod = server_config_load.get("load_mod")
WEBUI_SERVER = load_mod(server_config_load.get("module"), "WEBUI_SERVER")
config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace")
return WEBUI_SERVER
return config_server_workspace.get_config().WEBUI_SERVER
def _import_api_server() -> Any:
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
load_mod = server_config_load.get("load_mod")
API_SERVER = load_mod(server_config_load.get("module"), "API_SERVER")
config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace")
return API_SERVER
return config_server_workspace.get_config().API_SERVER
def __getattr__(name: str) -> Any:
@ -564,6 +596,14 @@ def __getattr__(name: str) -> Any:
return _import_ConfigModelWorkSpace()
elif name == "config_model_workspace":
return _import_config_model_workspace()
elif name == "ConfigServer":
return _import_ConfigServer()
elif name == "ConfigServerFactory":
return _import_ConfigServerFactory()
elif name == "ConfigServerWorkSpace":
return _import_ConfigServerWorkSpace()
elif name == "config_server_workspace":
return _import_config_server_workspace()
elif name == "log_verbose":
return _import_log_verbose()
elif name == "CHATCHAT_ROOT":
@ -724,4 +764,10 @@ __all__ = [
"config_model_workspace",
"ConfigServer",
"ConfigServerFactory",
"ConfigServerWorkSpace",
"config_server_workspace",
]

View File

@ -1,25 +1,141 @@
import os
import json
from dataclasses import dataclass
from pathlib import Path
import sys
import logging
from typing import Any, Optional, Dict
sys.path.append(str(Path(__file__).parent))
import _core_config as core_config
logger = logging.getLogger()
# httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。
HTTPX_DEFAULT_TIMEOUT = 300.0
class ConfigServer(core_config.Config):
HTTPX_DEFAULT_TIMEOUT: Optional[float] = None
"""httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。"""
OPEN_CROSS_DOMAIN: Optional[bool] = None
"""API 是否开启跨域默认为False如果需要开启请设置为True"""
DEFAULT_BIND_HOST: Optional[str] = None
"""各服务器默认绑定host。如改为"0.0.0.0"需要修改下方所有XX_SERVER的host"""
WEBUI_SERVER_PORT: Optional[int] = None
"""webui port"""
API_SERVER_PORT: Optional[int] = None
"""api port"""
WEBUI_SERVER: Optional[Dict[str, Any]] = None
"""webui.py server"""
API_SERVER: Optional[Dict[str, Any]] = None
"""api.py server"""
# API 是否开启跨域默认为False如果需要开启请设置为True
# is open cross domain
OPEN_CROSS_DOMAIN = True
@classmethod
def class_name(cls) -> str:
return cls.__name__
# 各服务器默认绑定host。如改为"0.0.0.0"需要修改下方所有XX_SERVER的host
DEFAULT_BIND_HOST = "127.0.0.1" if sys.platform != "win32" else "127.0.0.1"
def __str__(self):
return self.to_json()
# webui.py server
WEBUI_SERVER = {
"host": DEFAULT_BIND_HOST,
"port": 8501,
}
@dataclass
class ConfigServerFactory(core_config.ConfigFactory[ConfigServer]):
"""Server 配置工厂类"""
# api.py server
API_SERVER = {
"host": DEFAULT_BIND_HOST,
"port": 7861,
}
def __init__(self):
# httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。
self.HTTPX_DEFAULT_TIMEOUT = 300.0
# API 是否开启跨域默认为False如果需要开启请设置为True
# is open cross domain
self.OPEN_CROSS_DOMAIN = True
# 各服务器默认绑定host。如改为"0.0.0.0"需要修改下方所有XX_SERVER的host
self.DEFAULT_BIND_HOST = "127.0.0.1" if sys.platform != "win32" else "127.0.0.1"
self.WEBUI_SERVER_PORT = 8501
self.API_SERVER_PORT = 7861
# webui.py server
self.WEBUI_SERVER = {
"host": self.DEFAULT_BIND_HOST,
"port": self.WEBUI_SERVER_PORT,
}
# api.py server
self.API_SERVER = {
"host": self.DEFAULT_BIND_HOST,
"port": self.API_SERVER_PORT,
}
def httpx_default_timeout(self, timeout: float):
self.HTTPX_DEFAULT_TIMEOUT = timeout
def open_cross_domain(self, open_cross_domain: bool):
self.OPEN_CROSS_DOMAIN = open_cross_domain
def default_bind_host(self, default_bind_host: str):
self.DEFAULT_BIND_HOST = default_bind_host
def webui_server_port(self, webui_server_port: int):
self.WEBUI_SERVER_PORT = webui_server_port
def api_server_port(self, api_server_port: int):
self.API_SERVER_PORT = api_server_port
def get_config(self) -> ConfigServer:
config = ConfigServer()
config.HTTPX_DEFAULT_TIMEOUT = self.HTTPX_DEFAULT_TIMEOUT
config.OPEN_CROSS_DOMAIN = self.OPEN_CROSS_DOMAIN
config.DEFAULT_BIND_HOST = self.DEFAULT_BIND_HOST
config.WEBUI_SERVER_PORT = self.WEBUI_SERVER_PORT
config.API_SERVER_PORT = self.API_SERVER_PORT
config.WEBUI_SERVER = self.WEBUI_SERVER
config.API_SERVER = self.API_SERVER
return config
class ConfigServerWorkSpace(core_config.ConfigWorkSpace[ConfigServerFactory, ConfigServer]):
"""
工作空间的配置预设提供ConfigServer建造方法产生实例
"""
config_factory_cls = ConfigServerFactory
def __init__(self):
super().__init__()
def _build_config_factory(self, config_json: Any) -> ConfigServerFactory:
_config_factory = self.config_factory_cls()
if config_json.get("HTTPX_DEFAULT_TIMEOUT") is not None:
_config_factory.httpx_default_timeout(config_json["HTTPX_DEFAULT_TIMEOUT"])
if config_json.get("OPEN_CROSS_DOMAIN") is not None:
_config_factory.open_cross_domain(config_json["OPEN_CROSS_DOMAIN"])
if config_json.get("DEFAULT_BIND_HOST") is not None:
_config_factory.default_bind_host(config_json["DEFAULT_BIND_HOST"])
if config_json.get("WEBUI_SERVER_PORT") is not None:
_config_factory.webui_server_port(config_json["WEBUI_SERVER_PORT"])
if config_json.get("API_SERVER_PORT") is not None:
_config_factory.api_server_port(config_json["API_SERVER_PORT"])
return _config_factory
@classmethod
def get_type(cls) -> str:
return ConfigServer.class_name()
def get_config(self) -> ConfigServer:
return self._config_factory.get_config()
def set_httpx_default_timeout(self, timeout: float):
self._config_factory.httpx_default_timeout(timeout)
def set_open_cross_domain(self, open_cross_domain: bool):
self._config_factory.open_cross_domain(open_cross_domain)
def set_default_bind_host(self, default_bind_host: str):
self._config_factory.default_bind_host(default_bind_host)
def set_webui_server_port(self, webui_server_port: int):
self._config_factory.webui_server_port(webui_server_port)
def set_api_server_port(self, api_server_port: int):
self._config_factory.api_server_port(api_server_port)
config_server_workspace: ConfigServerWorkSpace = ConfigServerWorkSpace()

View File

@ -5,7 +5,9 @@ from chatchat.configs import (
ConfigBasic,
ConfigBasicWorkSpace,
ConfigModelWorkSpace,
ConfigModel
ConfigModel,
ConfigServerWorkSpace,
ConfigServer,
)
import os
@ -95,3 +97,24 @@ def test_model_config():
assert TOOL_CONFIG is not None
assert MODEL_PLATFORMS is not None
assert LLM_MODEL_CONFIG is not None
def test_config_server_workspace():
config_server_workspace: ConfigServerWorkSpace = ConfigServerWorkSpace()
assert config_server_workspace.get_config() is not None
config_server_workspace.set_httpx_default_timeout(timeout=10)
config_server_workspace.set_open_cross_domain(open_cross_domain=True)
config_server_workspace.set_default_bind_host(default_bind_host="0.0.0.0")
config_server_workspace.set_webui_server_port(webui_server_port=8000)
config_server_workspace.set_api_server_port(api_server_port=8001)
config: ConfigServer = config_server_workspace.get_config()
assert config.HTTPX_DEFAULT_TIMEOUT == 10
assert config.OPEN_CROSS_DOMAIN is True
assert config.DEFAULT_BIND_HOST == "0.0.0.0"
assert config.WEBUI_SERVER_PORT == 8000
assert config.API_SERVER_PORT == 8001
config_server_workspace.clear()