mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 13:23:16 +08:00
parent
e7a5d6a528
commit
35c2f596f6
@ -85,5 +85,37 @@ def model(**kwargs):
|
||||
print(config_model_workspace.get_config())
|
||||
|
||||
|
||||
@main.command("server", help="服务配置")
|
||||
@click.option("--httpx_default_timeout", type=int, help="httpx默认超时时间")
|
||||
@click.option("--open_cross_domain", type=click.Choice(["true", "false"]), help="是否开启跨域")
|
||||
@click.option("--default_bind_host", help="默认绑定host")
|
||||
@click.option("--webui_server_port", type=int, help="webui服务端口")
|
||||
@click.option("--api_server_port", type=int, help="api服务端口")
|
||||
@click.option("--clear", is_flag=True, help="清除配置")
|
||||
@click.option("--show", is_flag=True, help="显示配置")
|
||||
def server(**kwargs):
|
||||
|
||||
if kwargs["httpx_default_timeout"]:
|
||||
config_basic_workspace.set_httpx_default_timeout(httpx_default_timeout=kwargs["httpx_default_timeout"])
|
||||
if kwargs["open_cross_domain"]:
|
||||
if kwargs["open_cross_domain"].lower() == "true":
|
||||
config_basic_workspace.set_open_cross_domain(True)
|
||||
else:
|
||||
config_basic_workspace.set_open_cross_domain(False)
|
||||
if kwargs["default_bind_host"]:
|
||||
config_basic_workspace.set_default_bind_host(default_bind_host=kwargs["default_bind_host"])
|
||||
|
||||
if kwargs["webui_server_port"]:
|
||||
config_basic_workspace.set_webui_server_port(webui_server_port=kwargs["webui_server_port"])
|
||||
|
||||
if kwargs["api_server_port"]:
|
||||
config_basic_workspace.set_api_server_port(api_server_port=kwargs["api_server_port"])
|
||||
|
||||
if kwargs["clear"]:
|
||||
config_model_workspace.clear()
|
||||
if kwargs["show"]:
|
||||
print(config_model_workspace.get_config())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -507,44 +507,76 @@ def _import_prompt_templates() -> Any:
|
||||
return PROMPT_TEMPLATES
|
||||
|
||||
|
||||
def _import_ConfigServer() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
ConfigServer = load_mod(basic_config_load.get("module"), "ConfigServer")
|
||||
|
||||
return ConfigServer
|
||||
|
||||
|
||||
def _import_ConfigServerFactory() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
ConfigServerFactory = load_mod(basic_config_load.get("module"), "ConfigServerFactory")
|
||||
|
||||
return ConfigServerFactory
|
||||
|
||||
|
||||
def _import_ConfigServerWorkSpace() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
ConfigServerWorkSpace = load_mod(basic_config_load.get("module"), "ConfigServerWorkSpace")
|
||||
|
||||
return ConfigServerWorkSpace
|
||||
|
||||
|
||||
def _import_config_server_workspace() -> Any:
|
||||
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
||||
load_mod = server_config_load.get("load_mod")
|
||||
config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace")
|
||||
|
||||
return config_server_workspace
|
||||
|
||||
|
||||
def _import_httpx_default_timeout() -> Any:
|
||||
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
||||
load_mod = server_config_load.get("load_mod")
|
||||
HTTPX_DEFAULT_TIMEOUT = load_mod(server_config_load.get("module"), "HTTPX_DEFAULT_TIMEOUT")
|
||||
config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace")
|
||||
|
||||
return HTTPX_DEFAULT_TIMEOUT
|
||||
return config_server_workspace.get_config().HTTPX_DEFAULT_TIMEOUT
|
||||
|
||||
|
||||
def _import_open_cross_domain() -> Any:
|
||||
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
||||
load_mod = server_config_load.get("load_mod")
|
||||
OPEN_CROSS_DOMAIN = load_mod(server_config_load.get("module"), "OPEN_CROSS_DOMAIN")
|
||||
config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace")
|
||||
|
||||
return OPEN_CROSS_DOMAIN
|
||||
return config_server_workspace.get_config().OPEN_CROSS_DOMAIN
|
||||
|
||||
|
||||
def _import_default_bind_host() -> Any:
|
||||
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
||||
load_mod = server_config_load.get("load_mod")
|
||||
DEFAULT_BIND_HOST = load_mod(server_config_load.get("module"), "DEFAULT_BIND_HOST")
|
||||
config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace")
|
||||
|
||||
return DEFAULT_BIND_HOST
|
||||
return config_server_workspace.get_config().DEFAULT_BIND_HOST
|
||||
|
||||
|
||||
def _import_webui_server() -> Any:
|
||||
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
||||
load_mod = server_config_load.get("load_mod")
|
||||
WEBUI_SERVER = load_mod(server_config_load.get("module"), "WEBUI_SERVER")
|
||||
config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace")
|
||||
|
||||
return WEBUI_SERVER
|
||||
return config_server_workspace.get_config().WEBUI_SERVER
|
||||
|
||||
|
||||
def _import_api_server() -> Any:
|
||||
server_config_load = CONFIG_IMPORTS.get("_server_config.py")
|
||||
load_mod = server_config_load.get("load_mod")
|
||||
API_SERVER = load_mod(server_config_load.get("module"), "API_SERVER")
|
||||
config_server_workspace = load_mod(server_config_load.get("module"), "config_server_workspace")
|
||||
|
||||
return API_SERVER
|
||||
return config_server_workspace.get_config().API_SERVER
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
@ -564,6 +596,14 @@ def __getattr__(name: str) -> Any:
|
||||
return _import_ConfigModelWorkSpace()
|
||||
elif name == "config_model_workspace":
|
||||
return _import_config_model_workspace()
|
||||
elif name == "ConfigServer":
|
||||
return _import_ConfigServer()
|
||||
elif name == "ConfigServerFactory":
|
||||
return _import_ConfigServerFactory()
|
||||
elif name == "ConfigServerWorkSpace":
|
||||
return _import_ConfigServerWorkSpace()
|
||||
elif name == "config_server_workspace":
|
||||
return _import_config_server_workspace()
|
||||
elif name == "log_verbose":
|
||||
return _import_log_verbose()
|
||||
elif name == "CHATCHAT_ROOT":
|
||||
@ -724,4 +764,10 @@ __all__ = [
|
||||
|
||||
"config_model_workspace",
|
||||
|
||||
"ConfigServer",
|
||||
"ConfigServerFactory",
|
||||
"ConfigServerWorkSpace",
|
||||
|
||||
"config_server_workspace",
|
||||
|
||||
]
|
||||
|
||||
@ -1,25 +1,141 @@
|
||||
import os
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import logging
|
||||
from typing import Any, Optional, Dict
|
||||
|
||||
sys.path.append(str(Path(__file__).parent))
|
||||
import _core_config as core_config
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
# httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。
|
||||
HTTPX_DEFAULT_TIMEOUT = 300.0
|
||||
class ConfigServer(core_config.Config):
|
||||
HTTPX_DEFAULT_TIMEOUT: Optional[float] = None
|
||||
"""httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。"""
|
||||
OPEN_CROSS_DOMAIN: Optional[bool] = None
|
||||
"""API 是否开启跨域,默认为False,如果需要开启,请设置为True"""
|
||||
DEFAULT_BIND_HOST: Optional[str] = None
|
||||
"""各服务器默认绑定host。如改为"0.0.0.0"需要修改下方所有XX_SERVER的host"""
|
||||
WEBUI_SERVER_PORT: Optional[int] = None
|
||||
"""webui port"""
|
||||
API_SERVER_PORT: Optional[int] = None
|
||||
"""api port"""
|
||||
WEBUI_SERVER: Optional[Dict[str, Any]] = None
|
||||
"""webui.py server"""
|
||||
API_SERVER: Optional[Dict[str, Any]] = None
|
||||
"""api.py server"""
|
||||
|
||||
# API 是否开启跨域,默认为False,如果需要开启,请设置为True
|
||||
# is open cross domain
|
||||
OPEN_CROSS_DOMAIN = True
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return cls.__name__
|
||||
|
||||
# 各服务器默认绑定host。如改为"0.0.0.0"需要修改下方所有XX_SERVER的host
|
||||
DEFAULT_BIND_HOST = "127.0.0.1" if sys.platform != "win32" else "127.0.0.1"
|
||||
def __str__(self):
|
||||
return self.to_json()
|
||||
|
||||
|
||||
# webui.py server
|
||||
WEBUI_SERVER = {
|
||||
"host": DEFAULT_BIND_HOST,
|
||||
"port": 8501,
|
||||
}
|
||||
@dataclass
|
||||
class ConfigServerFactory(core_config.ConfigFactory[ConfigServer]):
|
||||
"""Server 配置工厂类"""
|
||||
|
||||
# api.py server
|
||||
API_SERVER = {
|
||||
"host": DEFAULT_BIND_HOST,
|
||||
"port": 7861,
|
||||
}
|
||||
def __init__(self):
|
||||
# httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。
|
||||
self.HTTPX_DEFAULT_TIMEOUT = 300.0
|
||||
|
||||
# API 是否开启跨域,默认为False,如果需要开启,请设置为True
|
||||
# is open cross domain
|
||||
self.OPEN_CROSS_DOMAIN = True
|
||||
|
||||
# 各服务器默认绑定host。如改为"0.0.0.0"需要修改下方所有XX_SERVER的host
|
||||
self.DEFAULT_BIND_HOST = "127.0.0.1" if sys.platform != "win32" else "127.0.0.1"
|
||||
self.WEBUI_SERVER_PORT = 8501
|
||||
self.API_SERVER_PORT = 7861
|
||||
# webui.py server
|
||||
self.WEBUI_SERVER = {
|
||||
"host": self.DEFAULT_BIND_HOST,
|
||||
"port": self.WEBUI_SERVER_PORT,
|
||||
}
|
||||
|
||||
# api.py server
|
||||
self.API_SERVER = {
|
||||
"host": self.DEFAULT_BIND_HOST,
|
||||
"port": self.API_SERVER_PORT,
|
||||
}
|
||||
|
||||
def httpx_default_timeout(self, timeout: float):
|
||||
self.HTTPX_DEFAULT_TIMEOUT = timeout
|
||||
|
||||
def open_cross_domain(self, open_cross_domain: bool):
|
||||
self.OPEN_CROSS_DOMAIN = open_cross_domain
|
||||
|
||||
def default_bind_host(self, default_bind_host: str):
|
||||
self.DEFAULT_BIND_HOST = default_bind_host
|
||||
|
||||
def webui_server_port(self, webui_server_port: int):
|
||||
self.WEBUI_SERVER_PORT = webui_server_port
|
||||
|
||||
def api_server_port(self, api_server_port: int):
|
||||
self.API_SERVER_PORT = api_server_port
|
||||
|
||||
def get_config(self) -> ConfigServer:
|
||||
config = ConfigServer()
|
||||
config.HTTPX_DEFAULT_TIMEOUT = self.HTTPX_DEFAULT_TIMEOUT
|
||||
config.OPEN_CROSS_DOMAIN = self.OPEN_CROSS_DOMAIN
|
||||
config.DEFAULT_BIND_HOST = self.DEFAULT_BIND_HOST
|
||||
config.WEBUI_SERVER_PORT = self.WEBUI_SERVER_PORT
|
||||
config.API_SERVER_PORT = self.API_SERVER_PORT
|
||||
config.WEBUI_SERVER = self.WEBUI_SERVER
|
||||
config.API_SERVER = self.API_SERVER
|
||||
return config
|
||||
|
||||
|
||||
class ConfigServerWorkSpace(core_config.ConfigWorkSpace[ConfigServerFactory, ConfigServer]):
|
||||
"""
|
||||
工作空间的配置预设,提供ConfigServer建造方法产生实例。
|
||||
"""
|
||||
config_factory_cls = ConfigServerFactory
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def _build_config_factory(self, config_json: Any) -> ConfigServerFactory:
|
||||
_config_factory = self.config_factory_cls()
|
||||
if config_json.get("HTTPX_DEFAULT_TIMEOUT") is not None:
|
||||
_config_factory.httpx_default_timeout(config_json["HTTPX_DEFAULT_TIMEOUT"])
|
||||
if config_json.get("OPEN_CROSS_DOMAIN") is not None:
|
||||
_config_factory.open_cross_domain(config_json["OPEN_CROSS_DOMAIN"])
|
||||
if config_json.get("DEFAULT_BIND_HOST") is not None:
|
||||
_config_factory.default_bind_host(config_json["DEFAULT_BIND_HOST"])
|
||||
if config_json.get("WEBUI_SERVER_PORT") is not None:
|
||||
_config_factory.webui_server_port(config_json["WEBUI_SERVER_PORT"])
|
||||
if config_json.get("API_SERVER_PORT") is not None:
|
||||
_config_factory.api_server_port(config_json["API_SERVER_PORT"])
|
||||
|
||||
return _config_factory
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> str:
|
||||
return ConfigServer.class_name()
|
||||
|
||||
def get_config(self) -> ConfigServer:
|
||||
return self._config_factory.get_config()
|
||||
|
||||
def set_httpx_default_timeout(self, timeout: float):
|
||||
self._config_factory.httpx_default_timeout(timeout)
|
||||
|
||||
def set_open_cross_domain(self, open_cross_domain: bool):
|
||||
self._config_factory.open_cross_domain(open_cross_domain)
|
||||
|
||||
def set_default_bind_host(self, default_bind_host: str):
|
||||
self._config_factory.default_bind_host(default_bind_host)
|
||||
|
||||
def set_webui_server_port(self, webui_server_port: int):
|
||||
self._config_factory.webui_server_port(webui_server_port)
|
||||
|
||||
def set_api_server_port(self, api_server_port: int):
|
||||
self._config_factory.api_server_port(api_server_port)
|
||||
|
||||
|
||||
config_server_workspace: ConfigServerWorkSpace = ConfigServerWorkSpace()
|
||||
@ -5,7 +5,9 @@ from chatchat.configs import (
|
||||
ConfigBasic,
|
||||
ConfigBasicWorkSpace,
|
||||
ConfigModelWorkSpace,
|
||||
ConfigModel
|
||||
ConfigModel,
|
||||
ConfigServerWorkSpace,
|
||||
ConfigServer,
|
||||
)
|
||||
import os
|
||||
|
||||
@ -95,3 +97,24 @@ def test_model_config():
|
||||
assert TOOL_CONFIG is not None
|
||||
assert MODEL_PLATFORMS is not None
|
||||
assert LLM_MODEL_CONFIG is not None
|
||||
|
||||
|
||||
def test_config_server_workspace():
|
||||
config_server_workspace: ConfigServerWorkSpace = ConfigServerWorkSpace()
|
||||
|
||||
assert config_server_workspace.get_config() is not None
|
||||
|
||||
config_server_workspace.set_httpx_default_timeout(timeout=10)
|
||||
config_server_workspace.set_open_cross_domain(open_cross_domain=True)
|
||||
config_server_workspace.set_default_bind_host(default_bind_host="0.0.0.0")
|
||||
config_server_workspace.set_webui_server_port(webui_server_port=8000)
|
||||
config_server_workspace.set_api_server_port(api_server_port=8001)
|
||||
|
||||
config: ConfigServer = config_server_workspace.get_config()
|
||||
|
||||
assert config.HTTPX_DEFAULT_TIMEOUT == 10
|
||||
assert config.OPEN_CROSS_DOMAIN is True
|
||||
assert config.DEFAULT_BIND_HOST == "0.0.0.0"
|
||||
assert config.WEBUI_SERVER_PORT == 8000
|
||||
assert config.API_SERVER_PORT == 8001
|
||||
config_server_workspace.clear()
|
||||
Loading…
x
Reference in New Issue
Block a user