diff --git a/libs/chatchat-server/chatchat/configs/__init__.py b/libs/chatchat-server/chatchat/configs/__init__.py index bca795be..32e41f11 100644 --- a/libs/chatchat-server/chatchat/configs/__init__.py +++ b/libs/chatchat-server/chatchat/configs/__init__.py @@ -1,13 +1,171 @@ import importlib import importlib.util import os +from pathlib import Path from typing import Dict, Any +import json import logging logger = logging.getLogger() +class ConfigBasic: + log_verbose: bool + CHATCHAT_ROOT: str + DATA_PATH: str + IMG_DIR: str + NLTK_DATA_PATH: str + LOG_FORMAT: str + LOG_PATH: str + MEDIA_PATH: str + BASE_TEMP_DIR: str + + +class ConfigBasicFactory: + """Basic config for ChatChat """ + + def __init__(self): + # 日志格式 + self.LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" + logging.basicConfig(format=self.LOG_FORMAT) + self.LOG_VERBOSE = False + self.CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent) + # 用户数据根目录 + self.DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data") + self._DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data") + if not os.path.exists(self._DATA_PATH): + os.mkdir(self.DATA_PATH) + + self._init_data_dir() + + # 项目相关图片 + self.IMG_DIR = os.path.join(self.CHATCHAT_ROOT, "img") + if not os.path.exists(self.IMG_DIR): + os.mkdir(self.IMG_DIR) + + def log_verbose(self, verbose: bool): + self.LOG_VERBOSE = verbose + + def chatchat_root(self, root: str): + self.CHATCHAT_ROOT = root + + def data_path(self, path: str): + self.DATA_PATH = path + if not os.path.exists(self.DATA_PATH): + os.mkdir(self.DATA_PATH) + # 复制_DATA_PATH数据到DATA_PATH + os.system(f"cp -r {self._DATA_PATH} {self.DATA_PATH}") + + self._init_data_dir() + + def log_format(self, log_format: str): + self.LOG_FORMAT = log_format + logging.basicConfig(format=self.LOG_FORMAT) + + def _init_data_dir(self): + logger.info(f"Init data dir: {self.DATA_PATH}") + # nltk 模型存储路径 + self.NLTK_DATA_PATH = os.path.join(self.DATA_PATH, "nltk_data") + import nltk + nltk.data.path = [self.NLTK_DATA_PATH] + nltk.data.path + # 日志存储路径 + self.LOG_PATH = os.path.join(self.DATA_PATH, "logs") + if not os.path.exists(self.LOG_PATH): + os.mkdir(self.LOG_PATH) + + # 模型生成内容(图片、视频、音频等)保存位置 + self.MEDIA_PATH = os.path.join(self.DATA_PATH, "media") + if not os.path.exists(self.MEDIA_PATH): + os.mkdir(self.MEDIA_PATH) + os.mkdir(os.path.join(self.MEDIA_PATH, "image")) + os.mkdir(os.path.join(self.MEDIA_PATH, "audio")) + os.mkdir(os.path.join(self.MEDIA_PATH, "video")) + + # 临时文件目录,主要用于文件对话 + self.BASE_TEMP_DIR = os.path.join(self.DATA_PATH, "temp") + if not os.path.exists(self.BASE_TEMP_DIR): + os.mkdir(self.BASE_TEMP_DIR) + + logger.info(f"Init data dir: {self.DATA_PATH} success.") + + def get_config(self) -> ConfigBasic: + config = ConfigBasic() + config.log_verbose = self.LOG_VERBOSE + config.CHATCHAT_ROOT = self.CHATCHAT_ROOT + config.DATA_PATH = self.DATA_PATH + config.IMG_DIR = self.IMG_DIR + config.NLTK_DATA_PATH = self.NLTK_DATA_PATH + config.LOG_FORMAT = self.LOG_FORMAT + config.LOG_PATH = self.LOG_PATH + config.MEDIA_PATH = self.MEDIA_PATH + config.BASE_TEMP_DIR = self.BASE_TEMP_DIR + return config + + +class ConfigWorkSpace: + """ + 工作空间的配置预设,提供ConfigBasic建造方法产生实例。 + 该类的实例对象用于存储工作空间的配置信息,如工作空间的路径等 + 工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中。 + 注意:不存在则读取默认 + """ + _config_factory: ConfigBasicFactory = ConfigBasicFactory() + + def __init__(self): + self.workspace = os.path.join(os.path.expanduser("~"), ".config", "chatchat/workspace") + if not os.path.exists(self.workspace): + os.makedirs(self.workspace, exist_ok=True) + self.workspace_config = os.path.join(self.workspace, "workspace_config.json") + # 初始化工作空间配置,转换成json格式,实现ConfigBasic的实例化 + with open(self.workspace_config, "w") as f: + config_json = json.loads(f.read()) + + if config_json: + + _config_factory = ConfigBasicFactory() + if config_json.get("log_verbose"): + _config_factory.log_verbose(config_json.get("log_verbose")) + if config_json.get("CHATCHAT_ROOT"): + _config_factory.chatchat_root(config_json.get("CHATCHAT_ROOT")) + if config_json.get("DATA_PATH"): + _config_factory.data_path(config_json.get("DATA_PATH")) + if config_json.get("LOG_FORMAT"): + _config_factory.log_format(config_json.get("LOG_FORMAT")) + + self._config_factory = _config_factory + + def get_config(self) -> ConfigBasic: + return self._config_factory.get_config() + + def set_log_verbose(self, verbose: bool): + self._config_factory.log_verbose(verbose) + self._store_config() + + def set_chatchat_root(self, root: str): + self._config_factory.chatchat_root(root) + self._store_config() + + def set_data_path(self, path: str): + self._config_factory.data_path(path) + self._store_config() + + def set_log_format(self, log_format: str): + self._config_factory.log_format(log_format) + self._store_config() + + def _store_config(self): + with open(self.workspace_config, "w") as f: + config = self._config_factory.get_config() + config_json = { + "log_verbose": config.log_verbose, + "CHATCHAT_ROOT": config.CHATCHAT_ROOT, + "DATA_PATH": config.DATA_PATH, + "LOG_FORMAT": config.LOG_FORMAT + } + f.write(json.dumps(config_json, indent=4, ensure_ascii=False)) + + def _load_mod(mod, attr): attr_cfg = None for name, obj in vars(mod).items(): @@ -38,7 +196,6 @@ def _import_config_mod_load(import_config_mod: str) -> Dict: ) user_import = False if user_import: - # Dynamic loading {config}.py file py_path = os.path.join(user_config_path, import_config_mod + ".py") spec = importlib.util.spec_from_file_location( @@ -69,7 +226,7 @@ def _import_config_mod_load(import_config_mod: str) -> Dict: ) raise RuntimeError(f"Failed to load user config from {user_config_path}") # 当前文件路径 - py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py") + py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py") spec = importlib.util.spec_from_file_location(f"*", py_path) @@ -118,6 +275,7 @@ def _import_data_path() -> Any: return DATA_PATH + def _import_img_dir() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") @@ -285,6 +443,7 @@ def _import_db_root_path() -> Any: return DB_ROOT_PATH + def _import_sqlalchemy_database_uri() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") @@ -627,4 +786,8 @@ __all__ = [ "API_SERVER", + "ConfigBasic", + "ConfigBasicFactory", + "ConfigWorkSpace", + ] diff --git a/libs/chatchat-server/tests/conftest.py b/libs/chatchat-server/tests/conftest.py new file mode 100644 index 00000000..031384e8 --- /dev/null +++ b/libs/chatchat-server/tests/conftest.py @@ -0,0 +1,90 @@ +"""Configuration for unit tests.""" +import logging +from importlib import util +from typing import Dict, List, Sequence + +import pytest +from pytest import Config, Function, Parser + + +def pytest_addoption(parser: Parser) -> None: + """Add custom command line options to pytest.""" + parser.addoption( + "--only-extended", + action="store_true", + help="Only run extended tests. Does not allow skipping any extended tests.", + ) + parser.addoption( + "--only-core", + action="store_true", + help="Only run core tests. Never runs any extended tests.", + ) + + +def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None: + """Add implementations for handling custom markers. + + At the moment, this adds support for a custom `requires` marker. + + The `requires` marker is used to denote tests that require one or more packages + to be installed to run. If the package is not installed, the test is skipped. + + The `requires` marker syntax is: + + .. code-block:: python + + @pytest.mark.requires("package1", "package2") + def test_something(): + ... + """ + # Mapping from the name of a package to whether it is installed or not. + # Used to avoid repeated calls to `util.find_spec` + required_pkgs_info: Dict[str, bool] = {} + + only_extended = config.getoption("--only-extended") or False + only_core = config.getoption("--only-core") or False + + if only_extended and only_core: + raise ValueError("Cannot specify both `--only-extended` and `--only-core`.") + + for item in items: + requires_marker = item.get_closest_marker("requires") + if requires_marker is not None: + if only_core: + item.add_marker(pytest.mark.skip(reason="Skipping not a core test.")) + continue + + # Iterate through the list of required packages + required_pkgs = requires_marker.args + for pkg in required_pkgs: + # If we haven't yet checked whether the pkg is installed + # let's check it and store the result. + if pkg not in required_pkgs_info: + try: + installed = util.find_spec(pkg) is not None + except Exception: + installed = False + required_pkgs_info[pkg] = installed + + if not required_pkgs_info[pkg]: + if only_extended: + pytest.fail( + f"Package `{pkg}` is not installed but is required for " + f"extended tests. Please install the given package and " + f"try again.", + ) + + else: + # If the package is not installed, we immediately break + # and mark the test as skipped. + item.add_marker( + pytest.mark.skip(reason=f"Requires pkg: `{pkg}`") + ) + break + else: + if only_extended: + item.add_marker( + pytest.mark.skip(reason="Skipping not an extended test.") + ) + + diff --git a/libs/chatchat-server/tests/unit_tests/config/test_config.py b/libs/chatchat-server/tests/unit_tests/config/test_config.py new file mode 100644 index 00000000..83739acc --- /dev/null +++ b/libs/chatchat-server/tests/unit_tests/config/test_config.py @@ -0,0 +1,20 @@ +from chatchat.configs import ConfigBasicFactory, ConfigBasic +import os + + +def test_config_factory_def(): + test_config_factory = ConfigBasicFactory() + config: ConfigBasic = test_config_factory.get_config() + assert config is not None + assert config.log_verbose is False + assert config.CHATCHAT_ROOT is not None + assert config.DATA_PATH is not None + assert config.IMG_DIR is not None + assert config.NLTK_DATA_PATH is not None + assert config.LOG_FORMAT is not None + assert config.LOG_PATH is not None + assert config.MEDIA_PATH is not None + + assert os.path.exists(os.path.join(config.MEDIA_PATH, "image")) + assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio")) + assert os.path.exists(os.path.join(config.MEDIA_PATH, "video"))