2024-06-12 18:29:06 +08:00

152 lines
5.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import json
from dataclasses import dataclass
from pathlib import Path
import sys
import logging
from typing import Any, Optional, Dict
sys.path.append(str(Path(__file__).parent))
import _core_config as core_config
# Module-scoped logger (named after this module) instead of the root logger,
# so log levels/handlers can be configured for this module specifically.
logger = logging.getLogger(__name__)
class ConfigServer(core_config.Config):
    """Server settings: httpx timeout, CORS switch, bind host and the
    WebUI / API server endpoints.

    Every field is ``None`` at class level; concrete values are assigned by
    ``ConfigServerFactory.get_config``.
    """

    #: Default timeout (seconds) for httpx requests. Increase it if model
    #: loading or chat responses are slow and requests time out.
    HTTPX_DEFAULT_TIMEOUT: Optional[float] = None
    #: Whether the API enables cross-origin (CORS) requests.
    OPEN_CROSS_DOMAIN: Optional[bool] = None
    #: Default host the servers bind to; WEBUI_SERVER and API_SERVER take
    #: their "host" entry from it.
    DEFAULT_BIND_HOST: Optional[str] = None
    #: Port of the WebUI server.
    WEBUI_SERVER_PORT: Optional[int] = None
    #: Port of the API server.
    API_SERVER_PORT: Optional[int] = None
    #: Endpoint of the webui.py server, as a {"host": ..., "port": ...} dict.
    WEBUI_SERVER: Optional[Dict[str, Any]] = None
    #: Endpoint of the api.py server, as a {"host": ..., "port": ...} dict.
    API_SERVER: Optional[Dict[str, Any]] = None

    @classmethod
    def class_name(cls) -> str:
        """Return the name of this config class (used as the workspace type)."""
        name: str = cls.__name__
        return name

    def __str__(self):
        """Return the configuration rendered by ``self.to_json()``."""
        rendered = self.to_json()
        return rendered
@dataclass
class ConfigServerFactory(core_config.ConfigFactory[ConfigServer]):
    """Builder for :class:`ConfigServer`.

    Holds the default server settings and exposes one setter per field.
    Setters that change the bind host or a port rebuild the derived
    ``WEBUI_SERVER`` / ``API_SERVER`` dicts so they stay in sync.
    """

    def __init__(self):
        # Default timeout (seconds) for httpx requests. Increase it if model
        # loading or chat responses are slow and requests time out.
        self.HTTPX_DEFAULT_TIMEOUT = 300.0
        # Whether the API allows cross-origin (CORS) requests; enabled by default.
        self.OPEN_CROSS_DOMAIN = True
        # Default host every server binds to. WEBUI_SERVER / API_SERVER derive
        # their "host" from it, so changing it via `default_bind_host`
        # (e.g. to "0.0.0.0") updates both. The previous platform check chose
        # "127.0.0.1" on every platform, so it is a plain constant now.
        self.DEFAULT_BIND_HOST = "127.0.0.1"
        self.WEBUI_SERVER_PORT = 8501
        self.API_SERVER_PORT = 7861
        self.__init_server()

    def __init_server(self):
        """(Re)build the per-server host/port dicts from the current fields."""
        # webui.py server
        self.WEBUI_SERVER = {
            "host": self.DEFAULT_BIND_HOST,
            "port": self.WEBUI_SERVER_PORT,
        }
        # api.py server
        self.API_SERVER = {
            "host": self.DEFAULT_BIND_HOST,
            "port": self.API_SERVER_PORT,
        }

    def httpx_default_timeout(self, timeout: float):
        """Set the default httpx request timeout in seconds."""
        self.HTTPX_DEFAULT_TIMEOUT = timeout

    def open_cross_domain(self, open_cross_domain: bool):
        """Enable or disable cross-origin (CORS) requests for the API."""
        self.OPEN_CROSS_DOMAIN = open_cross_domain

    def default_bind_host(self, default_bind_host: str):
        """Set the bind host and rebuild WEBUI_SERVER / API_SERVER."""
        self.DEFAULT_BIND_HOST = default_bind_host
        self.__init_server()

    def webui_server_port(self, webui_server_port: int):
        """Set the WebUI port and rebuild WEBUI_SERVER / API_SERVER."""
        self.WEBUI_SERVER_PORT = webui_server_port
        self.__init_server()

    def api_server_port(self, api_server_port: int):
        """Set the API port and rebuild WEBUI_SERVER / API_SERVER."""
        self.API_SERVER_PORT = api_server_port
        self.__init_server()

    def get_config(self) -> ConfigServer:
        """Return a new :class:`ConfigServer` populated from this factory.

        The server endpoint dicts are shallow-copied so that mutating the
        returned config does not silently change this factory's state.
        """
        config = ConfigServer()
        config.HTTPX_DEFAULT_TIMEOUT = self.HTTPX_DEFAULT_TIMEOUT
        config.OPEN_CROSS_DOMAIN = self.OPEN_CROSS_DOMAIN
        config.DEFAULT_BIND_HOST = self.DEFAULT_BIND_HOST
        config.WEBUI_SERVER_PORT = self.WEBUI_SERVER_PORT
        config.API_SERVER_PORT = self.API_SERVER_PORT
        config.WEBUI_SERVER = dict(self.WEBUI_SERVER)
        config.API_SERVER = dict(self.API_SERVER)
        return config
class ConfigServerWorkSpace(core_config.ConfigWorkSpace[ConfigServerFactory, ConfigServer]):
    """Workspace preset for server settings.

    Builds :class:`ConfigServer` instances through a
    :class:`ConfigServerFactory`; every ``set_*`` method updates the factory
    and then calls ``store_config``.
    """

    config_factory_cls = ConfigServerFactory

    def __init__(self):
        super().__init__()

    def _build_config_factory(self, config_json: Any) -> ConfigServerFactory:
        """Create a factory and apply every non-null value found in *config_json*."""
        factory = self.config_factory_cls()
        # (json key, factory setter) pairs, applied in this order. Keys that
        # are absent or null leave the factory default untouched.
        appliers = (
            ("HTTPX_DEFAULT_TIMEOUT", factory.httpx_default_timeout),
            ("OPEN_CROSS_DOMAIN", factory.open_cross_domain),
            ("DEFAULT_BIND_HOST", factory.default_bind_host),
            ("WEBUI_SERVER_PORT", factory.webui_server_port),
            ("API_SERVER_PORT", factory.api_server_port),
        )
        for key, apply in appliers:
            value = config_json.get(key)
            if value is not None:
                apply(value)
        return factory

    @classmethod
    def get_type(cls) -> str:
        """Return the type key of the config this workspace manages."""
        type_name = ConfigServer.class_name()
        return type_name

    def get_config(self) -> ConfigServer:
        """Return a ConfigServer built from the current factory state."""
        factory = self._config_factory
        return factory.get_config()

    def set_httpx_default_timeout(self, timeout: float):
        """Update the httpx default timeout and store the config."""
        factory = self._config_factory
        factory.httpx_default_timeout(timeout)
        self.store_config()

    def set_open_cross_domain(self, open_cross_domain: bool):
        """Update the CORS switch and store the config."""
        factory = self._config_factory
        factory.open_cross_domain(open_cross_domain)
        self.store_config()

    def set_default_bind_host(self, default_bind_host: str):
        """Update the default bind host and store the config."""
        factory = self._config_factory
        factory.default_bind_host(default_bind_host)
        self.store_config()

    def set_webui_server_port(self, webui_server_port: int):
        """Update the WebUI server port and store the config."""
        factory = self._config_factory
        factory.webui_server_port(webui_server_port)
        self.store_config()

    def set_api_server_port(self, api_server_port: int):
        """Update the API server port and store the config."""
        factory = self._config_factory
        factory.api_server_port(api_server_port)
        self.store_config()
# Shared module-level workspace instance for server settings; it is created
# when this module is imported.
config_server_workspace: ConfigServerWorkSpace = ConfigServerWorkSpace()