mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-05 14:23:23 +08:00
配置中心服务信息单元测试
This commit is contained in:
parent
e7a5d6a528
commit
fc03142bbc
@ -507,44 +507,76 @@ def _import_prompt_templates() -> Any:
|
|||||||
return PROMPT_TEMPLATES
|
return PROMPT_TEMPLATES
|
||||||
|
|
||||||
|
|
||||||
|
def _import_ConfigServer() -> Any:
|
||||||
|
basic_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
||||||
|
load_mod = basic_config_load.get("load_mod")
|
||||||
|
ConfigServer = load_mod(basic_config_load.get("module"), "ConfigServer")
|
||||||
|
|
||||||
|
return ConfigServer
|
||||||
|
|
||||||
|
|
||||||
|
def _import_ConfigServerFactory() -> Any:
|
||||||
|
basic_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
||||||
|
load_mod = basic_config_load.get("load_mod")
|
||||||
|
ConfigServerFactory = load_mod(basic_config_load.get("module"), "ConfigServerFactory")
|
||||||
|
|
||||||
|
return ConfigServerFactory
|
||||||
|
|
||||||
|
|
||||||
|
def _import_ConfigServerWorkSpace() -> Any:
|
||||||
|
basic_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
||||||
|
load_mod = basic_config_load.get("load_mod")
|
||||||
|
ConfigServerWorkSpace = load_mod(basic_config_load.get("module"), "ConfigServerWorkSpace")
|
||||||
|
|
||||||
|
return ConfigServerWorkSpace
|
||||||
|
|
||||||
|
|
||||||
|
def _import_config_server_workspace() -> Any:
|
||||||
|
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
||||||
|
load_mod = server_config_load.get("load_mod")
|
||||||
|
config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace")
|
||||||
|
|
||||||
|
return config_server_workspace
|
||||||
|
|
||||||
|
|
||||||
def _import_httpx_default_timeout() -> Any:
|
def _import_httpx_default_timeout() -> Any:
|
||||||
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
||||||
load_mod = server_config_load.get("load_mod")
|
load_mod = server_config_load.get("load_mod")
|
||||||
HTTPX_DEFAULT_TIMEOUT = load_mod(server_config_load.get("module"), "HTTPX_DEFAULT_TIMEOUT")
|
config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace")
|
||||||
|
|
||||||
return HTTPX_DEFAULT_TIMEOUT
|
return config_server_workspace.get_config().HTTPX_DEFAULT_TIMEOUT
|
||||||
|
|
||||||
|
|
||||||
def _import_open_cross_domain() -> Any:
|
def _import_open_cross_domain() -> Any:
|
||||||
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
||||||
load_mod = server_config_load.get("load_mod")
|
load_mod = server_config_load.get("load_mod")
|
||||||
OPEN_CROSS_DOMAIN = load_mod(server_config_load.get("module"), "OPEN_CROSS_DOMAIN")
|
config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace")
|
||||||
|
|
||||||
return OPEN_CROSS_DOMAIN
|
return config_server_workspace.get_config().OPEN_CROSS_DOMAIN
|
||||||
|
|
||||||
|
|
||||||
def _import_default_bind_host() -> Any:
|
def _import_default_bind_host() -> Any:
|
||||||
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
||||||
load_mod = server_config_load.get("load_mod")
|
load_mod = server_config_load.get("load_mod")
|
||||||
DEFAULT_BIND_HOST = load_mod(server_config_load.get("module"), "DEFAULT_BIND_HOST")
|
config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace")
|
||||||
|
|
||||||
return DEFAULT_BIND_HOST
|
return config_server_workspace.get_config().DEFAULT_BIND_HOST
|
||||||
|
|
||||||
|
|
||||||
def _import_webui_server() -> Any:
|
def _import_webui_server() -> Any:
|
||||||
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
||||||
load_mod = server_config_load.get("load_mod")
|
load_mod = server_config_load.get("load_mod")
|
||||||
WEBUI_SERVER = load_mod(server_config_load.get("module"), "WEBUI_SERVER")
|
config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace")
|
||||||
|
|
||||||
return WEBUI_SERVER
|
return config_server_workspace.get_config().WEBUI_SERVER
|
||||||
|
|
||||||
|
|
||||||
def _import_api_server() -> Any:
|
def _import_api_server() -> Any:
|
||||||
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
||||||
load_mod = server_config_load.get("load_mod")
|
load_mod = server_config_load.get("load_mod")
|
||||||
API_SERVER = load_mod(server_config_load.get("module"), "API_SERVER")
|
config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace")
|
||||||
|
|
||||||
return API_SERVER
|
return config_server_workspace.get_config().API_SERVER
|
||||||
|
|
||||||
|
|
||||||
def __getattr__(name: str) -> Any:
|
def __getattr__(name: str) -> Any:
|
||||||
@ -564,6 +596,14 @@ def __getattr__(name: str) -> Any:
|
|||||||
return _import_ConfigModelWorkSpace()
|
return _import_ConfigModelWorkSpace()
|
||||||
elif name == "config_model_workspace":
|
elif name == "config_model_workspace":
|
||||||
return _import_config_model_workspace()
|
return _import_config_model_workspace()
|
||||||
|
elif name == "ConfigServer":
|
||||||
|
return _import_ConfigServer()
|
||||||
|
elif name == "ConfigServerFactory":
|
||||||
|
return _import_ConfigServerFactory()
|
||||||
|
elif name == "ConfigServerWorkSpace":
|
||||||
|
return _import_ConfigServerWorkSpace()
|
||||||
|
elif name == "config_server_workspace":
|
||||||
|
return _import_config_server_workspace()
|
||||||
elif name == "log_verbose":
|
elif name == "log_verbose":
|
||||||
return _import_log_verbose()
|
return _import_log_verbose()
|
||||||
elif name == "CHATCHAT_ROOT":
|
elif name == "CHATCHAT_ROOT":
|
||||||
@ -724,4 +764,10 @@ __all__ = [
|
|||||||
|
|
||||||
"config_model_workspace",
|
"config_model_workspace",
|
||||||
|
|
||||||
|
"ConfigServer",
|
||||||
|
"ConfigServerFactory",
|
||||||
|
"ConfigServerWorkSpace",
|
||||||
|
|
||||||
|
"config_server_workspace",
|
||||||
|
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,25 +1,141 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
import sys
|
import sys
|
||||||
|
import logging
|
||||||
|
from typing import Any, Optional, Dict
|
||||||
|
|
||||||
|
sys.path.append(str(Path(__file__).parent))
|
||||||
|
import _core_config as core_config
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
# httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。
|
class ConfigServer(core_config.Config):
|
||||||
HTTPX_DEFAULT_TIMEOUT = 300.0
|
HTTPX_DEFAULT_TIMEOUT: Optional[float] = None
|
||||||
|
"""httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。"""
|
||||||
|
OPEN_CROSS_DOMAIN: Optional[bool] = None
|
||||||
|
"""API 是否开启跨域,默认为False,如果需要开启,请设置为True"""
|
||||||
|
DEFAULT_BIND_HOST: Optional[str] = None
|
||||||
|
"""各服务器默认绑定host。如改为"0.0.0.0"需要修改下方所有XX_SERVER的host"""
|
||||||
|
WEBUI_SERVER_PORT: Optional[int] = None
|
||||||
|
"""webui port"""
|
||||||
|
API_SERVER_PORT: Optional[int] = None
|
||||||
|
"""api port"""
|
||||||
|
WEBUI_SERVER: Optional[Dict[str, Any]] = None
|
||||||
|
"""webui.py server"""
|
||||||
|
API_SERVER: Optional[Dict[str, Any]] = None
|
||||||
|
"""api.py server"""
|
||||||
|
|
||||||
# API 是否开启跨域,默认为False,如果需要开启,请设置为True
|
@classmethod
|
||||||
# is open cross domain
|
def class_name(cls) -> str:
|
||||||
OPEN_CROSS_DOMAIN = True
|
return cls.__name__
|
||||||
|
|
||||||
# 各服务器默认绑定host。如改为"0.0.0.0"需要修改下方所有XX_SERVER的host
|
def __str__(self):
|
||||||
DEFAULT_BIND_HOST = "127.0.0.1" if sys.platform != "win32" else "127.0.0.1"
|
return self.to_json()
|
||||||
|
|
||||||
|
|
||||||
# webui.py server
|
@dataclass
|
||||||
WEBUI_SERVER = {
|
class ConfigServerFactory(core_config.ConfigFactory[ConfigServer]):
|
||||||
"host": DEFAULT_BIND_HOST,
|
"""Server 配置工厂类"""
|
||||||
"port": 8501,
|
|
||||||
}
|
|
||||||
|
|
||||||
# api.py server
|
def __init__(self):
|
||||||
API_SERVER = {
|
# httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。
|
||||||
"host": DEFAULT_BIND_HOST,
|
self.HTTPX_DEFAULT_TIMEOUT = 300.0
|
||||||
"port": 7861,
|
|
||||||
}
|
# API 是否开启跨域,默认为False,如果需要开启,请设置为True
|
||||||
|
# is open cross domain
|
||||||
|
self.OPEN_CROSS_DOMAIN = True
|
||||||
|
|
||||||
|
# 各服务器默认绑定host。如改为"0.0.0.0"需要修改下方所有XX_SERVER的host
|
||||||
|
self.DEFAULT_BIND_HOST = "127.0.0.1" if sys.platform != "win32" else "127.0.0.1"
|
||||||
|
self.WEBUI_SERVER_PORT = 8501
|
||||||
|
self.API_SERVER_PORT = 7861
|
||||||
|
# webui.py server
|
||||||
|
self.WEBUI_SERVER = {
|
||||||
|
"host": self.DEFAULT_BIND_HOST,
|
||||||
|
"port": self.WEBUI_SERVER_PORT,
|
||||||
|
}
|
||||||
|
|
||||||
|
# api.py server
|
||||||
|
self.API_SERVER = {
|
||||||
|
"host": self.DEFAULT_BIND_HOST,
|
||||||
|
"port": self.API_SERVER_PORT,
|
||||||
|
}
|
||||||
|
|
||||||
|
def httpx_default_timeout(self, timeout: float):
|
||||||
|
self.HTTPX_DEFAULT_TIMEOUT = timeout
|
||||||
|
|
||||||
|
def open_cross_domain(self, open_cross_domain: bool):
|
||||||
|
self.OPEN_CROSS_DOMAIN = open_cross_domain
|
||||||
|
|
||||||
|
def default_bind_host(self, default_bind_host: str):
|
||||||
|
self.DEFAULT_BIND_HOST = default_bind_host
|
||||||
|
|
||||||
|
def webui_server_port(self, webui_server_port: int):
|
||||||
|
self.WEBUI_SERVER_PORT = webui_server_port
|
||||||
|
|
||||||
|
def api_server_port(self, api_server_port: int):
|
||||||
|
self.API_SERVER_PORT = api_server_port
|
||||||
|
|
||||||
|
def get_config(self) -> ConfigServer:
|
||||||
|
config = ConfigServer()
|
||||||
|
config.HTTPX_DEFAULT_TIMEOUT = self.HTTPX_DEFAULT_TIMEOUT
|
||||||
|
config.OPEN_CROSS_DOMAIN = self.OPEN_CROSS_DOMAIN
|
||||||
|
config.DEFAULT_BIND_HOST = self.DEFAULT_BIND_HOST
|
||||||
|
config.WEBUI_SERVER_PORT = self.WEBUI_SERVER_PORT
|
||||||
|
config.API_SERVER_PORT = self.API_SERVER_PORT
|
||||||
|
config.WEBUI_SERVER = self.WEBUI_SERVER
|
||||||
|
config.API_SERVER = self.API_SERVER
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigServerWorkSpace(core_config.ConfigWorkSpace[ConfigServerFactory, ConfigServer]):
|
||||||
|
"""
|
||||||
|
工作空间的配置预设,提供ConfigServer建造方法产生实例。
|
||||||
|
"""
|
||||||
|
config_factory_cls = ConfigServerFactory
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def _build_config_factory(self, config_json: Any) -> ConfigServerFactory:
|
||||||
|
_config_factory = self.config_factory_cls()
|
||||||
|
if config_json.get("HTTPX_DEFAULT_TIMEOUT") is not None:
|
||||||
|
_config_factory.httpx_default_timeout(config_json["HTTPX_DEFAULT_TIMEOUT"])
|
||||||
|
if config_json.get("OPEN_CROSS_DOMAIN") is not None:
|
||||||
|
_config_factory.open_cross_domain(config_json["OPEN_CROSS_DOMAIN"])
|
||||||
|
if config_json.get("DEFAULT_BIND_HOST") is not None:
|
||||||
|
_config_factory.default_bind_host(config_json["DEFAULT_BIND_HOST"])
|
||||||
|
if config_json.get("WEBUI_SERVER_PORT") is not None:
|
||||||
|
_config_factory.webui_server_port(config_json["WEBUI_SERVER_PORT"])
|
||||||
|
if config_json.get("API_SERVER_PORT") is not None:
|
||||||
|
_config_factory.api_server_port(config_json["API_SERVER_PORT"])
|
||||||
|
|
||||||
|
return _config_factory
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_type(cls) -> str:
|
||||||
|
return ConfigServer.class_name()
|
||||||
|
|
||||||
|
def get_config(self) -> ConfigServer:
|
||||||
|
return self._config_factory.get_config()
|
||||||
|
|
||||||
|
def set_httpx_default_timeout(self, timeout: float):
|
||||||
|
self._config_factory.httpx_default_timeout(timeout)
|
||||||
|
|
||||||
|
def set_open_cross_domain(self, open_cross_domain: bool):
|
||||||
|
self._config_factory.open_cross_domain(open_cross_domain)
|
||||||
|
|
||||||
|
def set_default_bind_host(self, default_bind_host: str):
|
||||||
|
self._config_factory.default_bind_host(default_bind_host)
|
||||||
|
|
||||||
|
def set_webui_server_port(self, webui_server_port: int):
|
||||||
|
self._config_factory.webui_server_port(webui_server_port)
|
||||||
|
|
||||||
|
def set_api_server_port(self, api_server_port: int):
|
||||||
|
self._config_factory.api_server_port(api_server_port)
|
||||||
|
|
||||||
|
|
||||||
|
config_server_workspace: ConfigServerWorkSpace = ConfigServerWorkSpace()
|
||||||
@ -5,7 +5,9 @@ from chatchat.configs import (
|
|||||||
ConfigBasic,
|
ConfigBasic,
|
||||||
ConfigBasicWorkSpace,
|
ConfigBasicWorkSpace,
|
||||||
ConfigModelWorkSpace,
|
ConfigModelWorkSpace,
|
||||||
ConfigModel
|
ConfigModel,
|
||||||
|
ConfigServerWorkSpace,
|
||||||
|
ConfigServer,
|
||||||
)
|
)
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@ -95,3 +97,24 @@ def test_model_config():
|
|||||||
assert TOOL_CONFIG is not None
|
assert TOOL_CONFIG is not None
|
||||||
assert MODEL_PLATFORMS is not None
|
assert MODEL_PLATFORMS is not None
|
||||||
assert LLM_MODEL_CONFIG is not None
|
assert LLM_MODEL_CONFIG is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_server_workspace():
|
||||||
|
config_server_workspace: ConfigServerWorkSpace = ConfigServerWorkSpace()
|
||||||
|
|
||||||
|
assert config_server_workspace.get_config() is not None
|
||||||
|
|
||||||
|
config_server_workspace.set_httpx_default_timeout(timeout=10)
|
||||||
|
config_server_workspace.set_open_cross_domain(open_cross_domain=True)
|
||||||
|
config_server_workspace.set_default_bind_host(default_bind_host="0.0.0.0")
|
||||||
|
config_server_workspace.set_webui_server_port(webui_server_port=8000)
|
||||||
|
config_server_workspace.set_api_server_port(api_server_port=8001)
|
||||||
|
|
||||||
|
config: ConfigServer = config_server_workspace.get_config()
|
||||||
|
|
||||||
|
assert config.HTTPX_DEFAULT_TIMEOUT == 10
|
||||||
|
assert config.OPEN_CROSS_DOMAIN is True
|
||||||
|
assert config.DEFAULT_BIND_HOST == "0.0.0.0"
|
||||||
|
assert config.WEBUI_SERVER_PORT == 8000
|
||||||
|
assert config.API_SERVER_PORT == 8001
|
||||||
|
config_server_workspace.clear()
|
||||||
Loading…
x
Reference in New Issue
Block a user