# Mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
# Synced 2026-01-26 16:53:36 +08:00 (152 lines, 5.4 KiB, Python)
import os
|
||
import json
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
import sys
|
||
import logging
|
||
from typing import Any, Optional, Dict
|
||
|
||
sys.path.append(str(Path(__file__).parent))
|
||
import _core_config as core_config
|
||
|
||
# Module-scoped logger: records carry this module's name and still propagate
# to whatever handlers the application installs on the root logger.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class ConfigServer(core_config.Config):
    """Server settings: HTTP client timeout, CORS switch, bind host and ports.

    Every field defaults to ``None`` at class level.
    """

    # Default timeout, in seconds, for httpx requests. Increase it if model
    # loading or chat responses are slow enough to cause timeout errors.
    HTTPX_DEFAULT_TIMEOUT: Optional[float] = None

    # Whether the API allows cross-origin requests.
    OPEN_CROSS_DOMAIN: Optional[bool] = None

    # Default host the servers bind to; the per-server entries below carry
    # their own "host" value as well.
    DEFAULT_BIND_HOST: Optional[str] = None

    # Port of the web UI.
    WEBUI_SERVER_PORT: Optional[int] = None

    # Port of the API service.
    API_SERVER_PORT: Optional[int] = None

    # Settings for the webui.py server.
    WEBUI_SERVER: Optional[Dict[str, Any]] = None

    # Settings for the api.py server.
    API_SERVER: Optional[Dict[str, Any]] = None

    @classmethod
    def class_name(cls) -> str:
        """Return the name of this class."""
        return cls.__name__

    def __str__(self) -> str:
        """Return the result of ``to_json()`` (not defined in this class)."""
        return self.to_json()
|
||
|
||
|
||
@dataclass
class ConfigServerFactory(core_config.ConfigFactory[ConfigServer]):
    """Factory that accumulates server settings and builds ``ConfigServer`` objects.

    ``WEBUI_SERVER`` and ``API_SERVER`` are derived from ``DEFAULT_BIND_HOST``
    and the two port settings; they are rebuilt whenever one of those inputs
    is changed through its setter.

    NOTE(review): this class declares no annotated fields of its own and
    defines ``__init__`` explicitly, so ``@dataclass`` appears to contribute
    only generated dunder methods here -- confirm it is still wanted.
    """

    def __init__(self):
        """Populate the factory with the built-in default values."""
        # Default timeout (seconds) for httpx requests. Increase it if model
        # loading or chat responses are slow enough to cause timeout errors.
        self.HTTPX_DEFAULT_TIMEOUT = 300.0

        # Whether the API allows cross-origin requests; enabled by default here.
        self.OPEN_CROSS_DOMAIN = True

        # Default bind host for every server. The previous expression chose
        # between two identical literals depending on sys.platform, so the
        # value is the same on all platforms.
        self.DEFAULT_BIND_HOST = "127.0.0.1"
        self.WEBUI_SERVER_PORT = 8501
        self.API_SERVER_PORT = 7861
        self.__init_server()

    def __init_server(self):
        """Rebuild the per-server mappings from the current host and ports."""
        host = self.DEFAULT_BIND_HOST
        # webui.py server
        self.WEBUI_SERVER = {"host": host, "port": self.WEBUI_SERVER_PORT}
        # api.py server
        self.API_SERVER = {"host": host, "port": self.API_SERVER_PORT}

    def httpx_default_timeout(self, timeout: float):
        """Set the default httpx timeout in seconds."""
        self.HTTPX_DEFAULT_TIMEOUT = timeout

    def open_cross_domain(self, open_cross_domain: bool):
        """Set whether the API allows cross-origin requests."""
        self.OPEN_CROSS_DOMAIN = open_cross_domain

    def default_bind_host(self, default_bind_host: str):
        """Set the default bind host and refresh both server mappings."""
        self.DEFAULT_BIND_HOST = default_bind_host
        self.__init_server()

    def webui_server_port(self, webui_server_port: int):
        """Set the web UI port and refresh both server mappings."""
        self.WEBUI_SERVER_PORT = webui_server_port
        self.__init_server()

    def api_server_port(self, api_server_port: int):
        """Set the API port and refresh both server mappings."""
        self.API_SERVER_PORT = api_server_port
        self.__init_server()

    def get_config(self) -> ConfigServer:
        """Build a ``ConfigServer`` carrying the factory's current values.

        The server mappings are copied so that mutating the returned config
        cannot alter the factory's state or other configs built from it.
        """
        config = ConfigServer()
        config.HTTPX_DEFAULT_TIMEOUT = self.HTTPX_DEFAULT_TIMEOUT
        config.OPEN_CROSS_DOMAIN = self.OPEN_CROSS_DOMAIN
        config.DEFAULT_BIND_HOST = self.DEFAULT_BIND_HOST
        config.WEBUI_SERVER_PORT = self.WEBUI_SERVER_PORT
        config.API_SERVER_PORT = self.API_SERVER_PORT
        config.WEBUI_SERVER = dict(self.WEBUI_SERVER)
        config.API_SERVER = dict(self.API_SERVER)
        return config
|
||
|
||
|
||
class ConfigServerWorkSpace(core_config.ConfigWorkSpace[ConfigServerFactory, ConfigServer]):
    """Workspace preset for server configuration.

    Produces ``ConfigServer`` instances through ``ConfigServerFactory``.
    """

    config_factory_cls = ConfigServerFactory

    def __init__(self):
        """Delegate initialisation to the base workspace class."""
        super().__init__()

    def _build_config_factory(self, config_json: Any) -> ConfigServerFactory:
        """Create a factory and apply every non-null setting from ``config_json``.

        Keys that are absent or map to ``None`` leave the factory default
        untouched.
        """
        factory = self.config_factory_cls()
        # (key, setter) pairs, applied in this order.
        overrides = (
            ("HTTPX_DEFAULT_TIMEOUT", factory.httpx_default_timeout),
            ("OPEN_CROSS_DOMAIN", factory.open_cross_domain),
            ("DEFAULT_BIND_HOST", factory.default_bind_host),
            ("WEBUI_SERVER_PORT", factory.webui_server_port),
            ("API_SERVER_PORT", factory.api_server_port),
        )
        for key, apply_setting in overrides:
            if config_json.get(key) is not None:
                apply_setting(config_json[key])

        return factory

    @classmethod
    def get_type(cls) -> str:
        """Return the class name of ``ConfigServer``."""
        return ConfigServer.class_name()

    def get_config(self) -> ConfigServer:
        """Return a ``ConfigServer`` built by the workspace's factory."""
        factory = self._config_factory
        return factory.get_config()

    def set_httpx_default_timeout(self, timeout: float):
        """Update the httpx timeout on the factory, then call ``store_config()``."""
        factory = self._config_factory
        factory.httpx_default_timeout(timeout)
        self.store_config()

    def set_open_cross_domain(self, open_cross_domain: bool):
        """Update the cross-domain switch on the factory, then call ``store_config()``."""
        factory = self._config_factory
        factory.open_cross_domain(open_cross_domain)
        self.store_config()

    def set_default_bind_host(self, default_bind_host: str):
        """Update the default bind host on the factory, then call ``store_config()``."""
        factory = self._config_factory
        factory.default_bind_host(default_bind_host)
        self.store_config()

    def set_webui_server_port(self, webui_server_port: int):
        """Update the web UI port on the factory, then call ``store_config()``."""
        factory = self._config_factory
        factory.webui_server_port(webui_server_port)
        self.store_config()

    def set_api_server_port(self, api_server_port: int):
        """Update the API port on the factory, then call ``store_config()``."""
        factory = self._config_factory
        factory.api_server_port(api_server_port)
        self.store_config()
|
||
|
||
|
||
# Module-level workspace instance, created when this module is imported.
config_server_workspace: ConfigServerWorkSpace = ConfigServerWorkSpace()