From 1ebd1aa8cc68fc1c38aca685c9910ca49e2a3766 Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Sun, 9 Jun 2024 19:32:39 +0800 Subject: [PATCH] =?UTF-8?q?ConfigWorkSpace=20=E5=8D=95=E5=85=83=E6=B5=8B?= =?UTF-8?q?=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chatchat/configs/__init__.py | 231 ++++-------------- .../chatchat/configs/_basic_config.py | 216 ++++++++++++---- .../tests/unit_tests/config/test_config.py | 38 ++- 3 files changed, 259 insertions(+), 226 deletions(-) diff --git a/libs/chatchat-server/chatchat/configs/__init__.py b/libs/chatchat-server/chatchat/configs/__init__.py index 32e41f11..5f23f5ab 100644 --- a/libs/chatchat-server/chatchat/configs/__init__.py +++ b/libs/chatchat-server/chatchat/configs/__init__.py @@ -10,162 +10,6 @@ import logging logger = logging.getLogger() -class ConfigBasic: - log_verbose: bool - CHATCHAT_ROOT: str - DATA_PATH: str - IMG_DIR: str - NLTK_DATA_PATH: str - LOG_FORMAT: str - LOG_PATH: str - MEDIA_PATH: str - BASE_TEMP_DIR: str - - -class ConfigBasicFactory: - """Basic config for ChatChat """ - - def __init__(self): - # 日志格式 - self.LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" - logging.basicConfig(format=self.LOG_FORMAT) - self.LOG_VERBOSE = False - self.CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent) - # 用户数据根目录 - self.DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data") - self._DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data") - if not os.path.exists(self._DATA_PATH): - os.mkdir(self.DATA_PATH) - - self._init_data_dir() - - # 项目相关图片 - self.IMG_DIR = os.path.join(self.CHATCHAT_ROOT, "img") - if not os.path.exists(self.IMG_DIR): - os.mkdir(self.IMG_DIR) - - def log_verbose(self, verbose: bool): - self.LOG_VERBOSE = verbose - - def chatchat_root(self, root: str): - self.CHATCHAT_ROOT = root - - def data_path(self, path: str): - self.DATA_PATH = path - if not os.path.exists(self.DATA_PATH): - os.mkdir(self.DATA_PATH) - # 复制_DATA_PATH数据到DATA_PATH - os.system(f"cp -r {self._DATA_PATH} {self.DATA_PATH}") - - self._init_data_dir() - - def log_format(self, log_format: str): - self.LOG_FORMAT = log_format - logging.basicConfig(format=self.LOG_FORMAT) - - def _init_data_dir(self): - logger.info(f"Init data dir: {self.DATA_PATH}") - # nltk 模型存储路径 - self.NLTK_DATA_PATH = os.path.join(self.DATA_PATH, "nltk_data") - import nltk - nltk.data.path = [self.NLTK_DATA_PATH] + nltk.data.path - # 日志存储路径 - self.LOG_PATH = os.path.join(self.DATA_PATH, "logs") - if not os.path.exists(self.LOG_PATH): - os.mkdir(self.LOG_PATH) - - # 模型生成内容(图片、视频、音频等)保存位置 - self.MEDIA_PATH = os.path.join(self.DATA_PATH, "media") - if not os.path.exists(self.MEDIA_PATH): - os.mkdir(self.MEDIA_PATH) - os.mkdir(os.path.join(self.MEDIA_PATH, "image")) - os.mkdir(os.path.join(self.MEDIA_PATH, "audio")) - os.mkdir(os.path.join(self.MEDIA_PATH, "video")) - - # 临时文件目录,主要用于文件对话 - self.BASE_TEMP_DIR = os.path.join(self.DATA_PATH, "temp") - if not os.path.exists(self.BASE_TEMP_DIR): - os.mkdir(self.BASE_TEMP_DIR) - - logger.info(f"Init data dir: {self.DATA_PATH} success.") - - def get_config(self) -> ConfigBasic: - config = ConfigBasic() - config.log_verbose = self.LOG_VERBOSE - config.CHATCHAT_ROOT = self.CHATCHAT_ROOT - config.DATA_PATH = self.DATA_PATH - config.IMG_DIR = self.IMG_DIR - config.NLTK_DATA_PATH = self.NLTK_DATA_PATH - config.LOG_FORMAT = self.LOG_FORMAT - config.LOG_PATH = self.LOG_PATH - config.MEDIA_PATH = self.MEDIA_PATH - config.BASE_TEMP_DIR = self.BASE_TEMP_DIR - return config - - -class ConfigWorkSpace: - """ - 工作空间的配置预设,提供ConfigBasic建造方法产生实例。 - 该类的实例对象用于存储工作空间的配置信息,如工作空间的路径等 - 工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中。 - 注意:不存在则读取默认 - """ - _config_factory: ConfigBasicFactory = ConfigBasicFactory() - - def __init__(self): - self.workspace = os.path.join(os.path.expanduser("~"), ".config", "chatchat/workspace") - if not os.path.exists(self.workspace): - os.makedirs(self.workspace, exist_ok=True) - self.workspace_config = os.path.join(self.workspace, "workspace_config.json") - # 初始化工作空间配置,转换成json格式,实现ConfigBasic的实例化 - with open(self.workspace_config, "w") as f: - config_json = json.loads(f.read()) - - if config_json: - - _config_factory = ConfigBasicFactory() - if config_json.get("log_verbose"): - _config_factory.log_verbose(config_json.get("log_verbose")) - if config_json.get("CHATCHAT_ROOT"): - _config_factory.chatchat_root(config_json.get("CHATCHAT_ROOT")) - if config_json.get("DATA_PATH"): - _config_factory.data_path(config_json.get("DATA_PATH")) - if config_json.get("LOG_FORMAT"): - _config_factory.log_format(config_json.get("LOG_FORMAT")) - - self._config_factory = _config_factory - - def get_config(self) -> ConfigBasic: - return self._config_factory.get_config() - - def set_log_verbose(self, verbose: bool): - self._config_factory.log_verbose(verbose) - self._store_config() - - def set_chatchat_root(self, root: str): - self._config_factory.chatchat_root(root) - self._store_config() - - def set_data_path(self, path: str): - self._config_factory.data_path(path) - self._store_config() - - def set_log_format(self, log_format: str): - self._config_factory.log_format(log_format) - self._store_config() - - def _store_config(self): - with open(self.workspace_config, "w") as f: - config = self._config_factory.get_config() - config_json = { - "log_verbose": config.log_verbose, - "CHATCHAT_ROOT": config.CHATCHAT_ROOT, - "DATA_PATH": config.DATA_PATH, - "LOG_FORMAT": config.LOG_FORMAT - } - f.write(json.dumps(config_json, indent=4, ensure_ascii=False)) - - def _load_mod(mod, attr): attr_cfg = None for name, obj in vars(mod).items(): @@ -252,76 +96,102 @@ CONFIG_IMPORTS = { } +def _import_ConfigBasic() -> Any: + basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") + load_mod = basic_config_load.get("load_mod") + ConfigBasic = load_mod(basic_config_load.get("module"), "ConfigBasic") + + return ConfigBasic + + +def _import_ConfigBasicFactory() -> Any: + basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") + load_mod = basic_config_load.get("load_mod") + ConfigBasicFactory = load_mod(basic_config_load.get("module"), "ConfigBasicFactory") + + return ConfigBasicFactory + + +def _import_ConfigWorkSpace() -> Any: + basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") + load_mod = basic_config_load.get("load_mod") + ConfigWorkSpace = load_mod(basic_config_load.get("module"), "ConfigWorkSpace") + + return ConfigWorkSpace + + def _import_log_verbose() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - log_verbose = load_mod(basic_config_load.get("module"), "log_verbose") - - return log_verbose + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + return config_workspace.get_config().log_verbose def _import_chatchat_root() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - CHATCHAT_ROOT = load_mod(basic_config_load.get("module"), "CHATCHAT_ROOT") + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") - return CHATCHAT_ROOT + return config_workspace.get_config().CHATCHAT_ROOT def _import_data_path() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - DATA_PATH = load_mod(basic_config_load.get("module"), "DATA_PATH") - return DATA_PATH + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + return config_workspace.get_config().DATA_PATH def _import_img_dir() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - IMG_DIR = load_mod(basic_config_load.get("module"), "IMG_DIR") - return IMG_DIR + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + + return config_workspace.get_config().IMG_DIR def _import_nltk_data_path() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - NLTK_DATA_PATH = load_mod(basic_config_load.get("module"), "NLTK_DATA_PATH") + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") - return NLTK_DATA_PATH + return config_workspace.get_config().NLTK_DATA_PATH def _import_log_format() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - LOG_FORMAT = load_mod(basic_config_load.get("module"), "LOG_FORMAT") - return LOG_FORMAT + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + + return config_workspace.get_config().LOG_FORMAT def _import_log_path() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - LOG_PATH = load_mod(basic_config_load.get("module"), "LOG_PATH") - return LOG_PATH + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + + return config_workspace.get_config().LOG_PATH def _import_media_path() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - MEDIA_PATH = load_mod(basic_config_load.get("module"), "MEDIA_PATH") - return MEDIA_PATH + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + + return config_workspace.get_config().MEDIA_PATH def _import_base_temp_dir() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - BASE_TEMP_DIR = load_mod(basic_config_load.get("module"), "BASE_TEMP_DIR") - - return BASE_TEMP_DIR + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + return config_workspace.get_config().BASE_TEMP_DIR def _import_default_knowledge_base() -> Any: @@ -637,7 +507,13 @@ def _import_api_server() -> Any: def __getattr__(name: str) -> Any: - if name == "log_verbose": + if name == "ConfigBasic": + return _import_ConfigBasic() + elif name == "ConfigBasicFactory": + return _import_ConfigBasicFactory() + elif name == "ConfigWorkSpace": + return _import_ConfigWorkSpace() + elif name == "log_verbose": return _import_log_verbose() elif name == "CHATCHAT_ROOT": return _import_chatchat_root() @@ -785,7 +661,6 @@ __all__ = [ "WEBUI_SERVER", "API_SERVER", - "ConfigBasic", "ConfigBasicFactory", "ConfigWorkSpace", diff --git a/libs/chatchat-server/chatchat/configs/_basic_config.py b/libs/chatchat-server/chatchat/configs/_basic_config.py index 46eec939..03655ce2 100644 --- a/libs/chatchat-server/chatchat/configs/_basic_config.py +++ b/libs/chatchat-server/chatchat/configs/_basic_config.py @@ -1,55 +1,177 @@ -import logging import os +import json from pathlib import Path +import logging -import langchain - - -# 是否显示详细日志 -log_verbose = False -langchain.verbose = False - -# 通常情况下不需要更改以下内容 - -# chatchat 项目根目录 -CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent) - -# 用户数据根目录 -DATA_PATH = os.path.join(CHATCHAT_ROOT, "data") -if not os.path.exists(DATA_PATH): - os.mkdir(DATA_PATH) - -# 项目相关图片 -IMG_DIR = os.path.join(CHATCHAT_ROOT, "img") -if not os.path.exists(IMG_DIR): - os.mkdir(IMG_DIR) - -# nltk 模型存储路径 -NLTK_DATA_PATH = os.path.join(DATA_PATH, "nltk_data") -import nltk -nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path - -# 日志格式 -LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" logger = logging.getLogger() -logger.setLevel(logging.INFO) -logging.basicConfig(format=LOG_FORMAT) -# 日志存储路径 -LOG_PATH = os.path.join(DATA_PATH, "logs") -if not os.path.exists(LOG_PATH): - os.mkdir(LOG_PATH) +class ConfigBasic: + log_verbose: bool + """是否开启日志详细信息""" + CHATCHAT_ROOT: str + """项目根目录""" + DATA_PATH: str + """用户数据根目录""" + IMG_DIR: str + """项目相关图片""" + NLTK_DATA_PATH: str + """nltk 模型存储路径""" + LOG_FORMAT: str + """日志格式""" + LOG_PATH: str + """日志存储路径""" + MEDIA_PATH: str + """模型生成内容(图片、视频、音频等)保存位置""" + BASE_TEMP_DIR: str + """临时文件目录,主要用于文件对话""" -# 模型生成内容(图片、视频、音频等)保存位置 -MEDIA_PATH = os.path.join(DATA_PATH, "media") -if not os.path.exists(MEDIA_PATH): - os.mkdir(MEDIA_PATH) - os.mkdir(os.path.join(MEDIA_PATH, "image")) - os.mkdir(os.path.join(MEDIA_PATH, "audio")) - os.mkdir(os.path.join(MEDIA_PATH, "video")) -# 临时文件目录,主要用于文件对话 -BASE_TEMP_DIR = os.path.join(DATA_PATH, "temp") -if not os.path.exists(BASE_TEMP_DIR): - os.mkdir(BASE_TEMP_DIR) +class ConfigBasicFactory: + """Basic config for ChatChat """ + + def __init__(self): + # 日志格式 + self.LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" + logging.basicConfig(format=self.LOG_FORMAT) + self.LOG_VERBOSE = False + self.CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent) + # 用户数据根目录 + self.DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data") + self._DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data") + if not os.path.exists(self._DATA_PATH): + os.makedirs(self._DATA_PATH, exist_ok=True) + + self._init_data_dir() + + # 项目相关图片 + self.IMG_DIR = os.path.join(self.CHATCHAT_ROOT, "img") + if not os.path.exists(self.IMG_DIR): + os.makedirs(self.IMG_DIR, exist_ok=True) + + def log_verbose(self, verbose: bool): + self.LOG_VERBOSE = verbose + + def data_path(self, path: str): + self.DATA_PATH = path + if not os.path.exists(self.DATA_PATH): + os.makedirs(self.DATA_PATH, exist_ok=True) + # 复制_DATA_PATH数据到DATA_PATH + if self._DATA_PATH != self.DATA_PATH: + os.system(f"cp -r {self._DATA_PATH}/* {self.DATA_PATH}") + + self._init_data_dir() + + def log_format(self, log_format: str): + self.LOG_FORMAT = log_format + logging.basicConfig(format=self.LOG_FORMAT) + + def _init_data_dir(self): + logger.info(f"Init data dir: {self.DATA_PATH}") + # nltk 模型存储路径 + self.NLTK_DATA_PATH = os.path.join(self.DATA_PATH, "nltk_data") + import nltk + nltk.data.path = [self.NLTK_DATA_PATH] + nltk.data.path + # 日志存储路径 + self.LOG_PATH = os.path.join(self.DATA_PATH, "logs") + if not os.path.exists(self.LOG_PATH): + os.makedirs(self.LOG_PATH, exist_ok=True) + + # 模型生成内容(图片、视频、音频等)保存位置 + self.MEDIA_PATH = os.path.join(self.DATA_PATH, "media") + if not os.path.exists(self.MEDIA_PATH): + os.makedirs(self.MEDIA_PATH, exist_ok=True) + os.makedirs(os.path.join(self.MEDIA_PATH, "image"), exist_ok=True) + os.makedirs(os.path.join(self.MEDIA_PATH, "audio"), exist_ok=True) + os.makedirs(os.path.join(self.MEDIA_PATH, "video"), exist_ok=True) + + # 临时文件目录,主要用于文件对话 + self.BASE_TEMP_DIR = os.path.join(self.DATA_PATH, "temp") + if not os.path.exists(self.BASE_TEMP_DIR): + os.makedirs(self.BASE_TEMP_DIR, exist_ok=True) + + logger.info(f"Init data dir: {self.DATA_PATH} success.") + + def get_config(self) -> ConfigBasic: + config = ConfigBasic() + config.log_verbose = self.LOG_VERBOSE + config.CHATCHAT_ROOT = self.CHATCHAT_ROOT + config.DATA_PATH = self.DATA_PATH + config.IMG_DIR = self.IMG_DIR + config.NLTK_DATA_PATH = self.NLTK_DATA_PATH + config.LOG_FORMAT = self.LOG_FORMAT + config.LOG_PATH = self.LOG_PATH + config.MEDIA_PATH = self.MEDIA_PATH + config.BASE_TEMP_DIR = self.BASE_TEMP_DIR + return config + + +class ConfigWorkSpace: + """ + 工作空间的配置预设,提供ConfigBasic建造方法产生实例。 + 该类的实例对象用于存储工作空间的配置信息,如工作空间的路径等 + 工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中。 + 注意:不存在则读取默认 + """ + _config_factory: ConfigBasicFactory = ConfigBasicFactory() + + def __init__(self): + self.workspace = os.path.join(os.path.expanduser("~"), ".config", "chatchat/workspace") + if not os.path.exists(self.workspace): + os.makedirs(self.workspace, exist_ok=True) + self.workspace_config = os.path.join(self.workspace, "workspace_config.json") + # 初始化工作空间配置,转换成json格式,实现ConfigBasic的实例化 + + config_json = self._load_config() + + if config_json: + + _config_factory = ConfigBasicFactory() + if config_json.get("log_verbose"): + _config_factory.log_verbose(config_json.get("log_verbose")) + if config_json.get("DATA_PATH"): + _config_factory.data_path(config_json.get("DATA_PATH")) + if config_json.get("LOG_FORMAT"): + _config_factory.log_format(config_json.get("LOG_FORMAT")) + + self._config_factory = _config_factory + + def get_config(self) -> ConfigBasic: + return self._config_factory.get_config() + + def set_log_verbose(self, verbose: bool): + self._config_factory.log_verbose(verbose) + self._store_config() + + def set_data_path(self, path: str): + self._config_factory.data_path(path) + self._store_config() + + def set_log_format(self, log_format: str): + self._config_factory.log_format(log_format) + self._store_config() + + def clear(self): + logger.info("Clear workspace config.") + os.remove(self.workspace_config) + + def _load_config(self): + try: + with open(self.workspace_config, "r") as f: + return json.loads(f.read()) + except FileNotFoundError: + return None + + def _store_config(self): + with open(self.workspace_config, "w") as f: + config = self._config_factory.get_config() + config_json = { + "log_verbose": config.log_verbose, + "CHATCHAT_ROOT": config.CHATCHAT_ROOT, + "DATA_PATH": config.DATA_PATH, + "LOG_FORMAT": config.LOG_FORMAT + } + f.write(json.dumps(config_json, indent=4, ensure_ascii=False)) + + +config_workspace: ConfigWorkSpace = ConfigWorkSpace() diff --git a/libs/chatchat-server/tests/unit_tests/config/test_config.py b/libs/chatchat-server/tests/unit_tests/config/test_config.py index 83739acc..96fbd5dc 100644 --- a/libs/chatchat-server/tests/unit_tests/config/test_config.py +++ b/libs/chatchat-server/tests/unit_tests/config/test_config.py @@ -1,4 +1,6 @@ -from chatchat.configs import ConfigBasicFactory, ConfigBasic +from pathlib import Path + +from chatchat.configs import ConfigBasicFactory, ConfigBasic, ConfigWorkSpace import os @@ -18,3 +20,37 @@ def test_config_factory_def(): assert os.path.exists(os.path.join(config.MEDIA_PATH, "image")) assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio")) assert os.path.exists(os.path.join(config.MEDIA_PATH, "video")) + + +def test_workspace(): + config_workspace = ConfigWorkSpace() + assert config_workspace.get_config() is not None + base_root = os.path.join(Path(__file__).absolute().parent, "chatchat") + config_workspace.set_data_path(os.path.join(base_root, "data")) + config_workspace.set_log_verbose(True) + config_workspace.set_log_format(" %(message)s") + + config: ConfigBasic = config_workspace.get_config() + assert config.log_verbose is True + assert config.DATA_PATH == os.path.join(base_root, "data") + assert config.IMG_DIR is not None + assert config.NLTK_DATA_PATH == os.path.join(base_root, "data", "nltk_data") + assert config.LOG_FORMAT == " %(message)s" + assert config.LOG_PATH == os.path.join(base_root, "data", "logs") + assert config.MEDIA_PATH == os.path.join(base_root, "data", "media") + + assert os.path.exists(os.path.join(config.MEDIA_PATH, "image")) + assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio")) + assert os.path.exists(os.path.join(config.MEDIA_PATH, "video")) + config_workspace.clear() + + +def test_workspace_default(): + from chatchat.configs import (log_verbose, DATA_PATH, IMG_DIR, NLTK_DATA_PATH, LOG_FORMAT, LOG_PATH, MEDIA_PATH) + assert log_verbose is False + assert DATA_PATH is not None + assert IMG_DIR is not None + assert NLTK_DATA_PATH is not None + assert LOG_FORMAT is not None + assert LOG_PATH is not None + assert MEDIA_PATH is not None