diff --git a/libs/chatchat-server/chatchat/configs/__init__.py b/libs/chatchat-server/chatchat/configs/__init__.py index c1191b08..b2408456 100644 --- a/libs/chatchat-server/chatchat/configs/__init__.py +++ b/libs/chatchat-server/chatchat/configs/__init__.py @@ -112,92 +112,95 @@ def _import_ConfigBasicFactory() -> Any: return ConfigBasicFactory -def _import_ConfigWorkSpace() -> Any: +def _import_ConfigBasicWorkSpace() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - ConfigWorkSpace = load_mod(basic_config_load.get("module"), "ConfigWorkSpace") + ConfigBasicWorkSpace = load_mod(basic_config_load.get("module"), "ConfigBasicWorkSpace") - return ConfigWorkSpace + return ConfigBasicWorkSpace -def _import_config_workspace() -> Any: +def _import_config_basic_workspace() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") - return config_workspace + config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace") + return config_basic_workspace + def _import_log_verbose() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") - return config_workspace.get_config().log_verbose + + config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace") + return config_basic_workspace.get_config().log_verbose def _import_chatchat_root() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") - return config_workspace.get_config().CHATCHAT_ROOT + config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace") + + return config_basic_workspace.get_config().CHATCHAT_ROOT def _import_data_path() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") - return config_workspace.get_config().DATA_PATH + config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace") + return config_basic_workspace.get_config().DATA_PATH def _import_img_dir() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace") - return config_workspace.get_config().IMG_DIR + return config_basic_workspace.get_config().IMG_DIR def _import_nltk_data_path() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace") - return config_workspace.get_config().NLTK_DATA_PATH + return config_basic_workspace.get_config().NLTK_DATA_PATH def _import_log_format() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace") - return config_workspace.get_config().LOG_FORMAT + return config_basic_workspace.get_config().LOG_FORMAT def _import_log_path() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace") - return config_workspace.get_config().LOG_PATH + return config_basic_workspace.get_config().LOG_PATH def _import_media_path() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace") - return config_workspace.get_config().MEDIA_PATH + return config_basic_workspace.get_config().MEDIA_PATH def _import_base_temp_dir() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") - return config_workspace.get_config().BASE_TEMP_DIR + config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace") + return config_basic_workspace.get_config().BASE_TEMP_DIR def _import_default_knowledge_base() -> Any: @@ -517,10 +520,10 @@ def __getattr__(name: str) -> Any: return _import_ConfigBasic() elif name == "ConfigBasicFactory": return _import_ConfigBasicFactory() - elif name == "ConfigWorkSpace": - return _import_ConfigWorkSpace() - elif name == "config_workspace": - return _import_config_workspace() + elif name == "ConfigBasicWorkSpace": + return _import_ConfigBasicWorkSpace() + elif name == "config_basic_workspace": + return _import_config_basic_workspace() elif name == "log_verbose": return _import_log_verbose() elif name == "CHATCHAT_ROOT": @@ -621,7 +624,7 @@ VERSION = "v0.3.0-preview" __all__ = [ "VERSION", - "config_workspace", + "config_basic_workspace", "log_verbose", "CHATCHAT_ROOT", "DATA_PATH", @@ -672,6 +675,6 @@ __all__ = [ "ConfigBasic", "ConfigBasicFactory", - "ConfigWorkSpace", + "ConfigBasicWorkSpace", ] diff --git a/libs/chatchat-server/chatchat/configs/_basic_config.py b/libs/chatchat-server/chatchat/configs/_basic_config.py index 61108bde..fa5c150e 100644 --- a/libs/chatchat-server/chatchat/configs/_basic_config.py +++ b/libs/chatchat-server/chatchat/configs/_basic_config.py @@ -1,36 +1,49 @@ import os import json +from dataclasses import dataclass from pathlib import Path +import sys import logging +from typing import Any, Optional + +from chatchat.configs._core_config import CF + +sys.path.append(str(Path(__file__).parent)) +import _core_config as core_config logger = logging.getLogger() -class ConfigBasic: - log_verbose: bool +class ConfigBasic(core_config.Config): + log_verbose: Optional[bool] = None """是否开启日志详细信息""" - CHATCHAT_ROOT: str + CHATCHAT_ROOT: Optional[str] = None """项目根目录""" - DATA_PATH: str + DATA_PATH: Optional[str] = None """用户数据根目录""" - IMG_DIR: str + IMG_DIR: Optional[str] = None """项目相关图片""" - NLTK_DATA_PATH: str + NLTK_DATA_PATH: Optional[str] = None """nltk 模型存储路径""" - LOG_FORMAT: str + LOG_FORMAT: Optional[str] = None """日志格式""" - LOG_PATH: str + LOG_PATH: Optional[str] = None """日志存储路径""" - MEDIA_PATH: str + MEDIA_PATH: Optional[str] = None """模型生成内容(图片、视频、音频等)保存位置""" - BASE_TEMP_DIR: str + BASE_TEMP_DIR: Optional[str] = None """临时文件目录,主要用于文件对话""" + @classmethod + def class_name(cls) -> str: + return cls.__name__ + def __str__(self): return f"ConfigBasic(log_verbose={self.log_verbose}, CHATCHAT_ROOT={self.CHATCHAT_ROOT}, DATA_PATH={self.DATA_PATH}, IMG_DIR={self.IMG_DIR}, NLTK_DATA_PATH={self.NLTK_DATA_PATH}, LOG_FORMAT={self.LOG_FORMAT}, LOG_PATH={self.LOG_PATH}, MEDIA_PATH={self.MEDIA_PATH}, BASE_TEMP_DIR={self.BASE_TEMP_DIR})" -class ConfigBasicFactory: +@dataclass +class ConfigBasicFactory(core_config.ConfigFactory[ConfigBasic]): """Basic config for ChatChat """ def __init__(self): @@ -109,72 +122,50 @@ class ConfigBasicFactory: return config -class ConfigWorkSpace: +class ConfigBasicWorkSpace(core_config.ConfigWorkSpace[ConfigBasicFactory, ConfigBasic]): """ 工作空间的配置预设,提供ConfigBasic建造方法产生实例。 - 该类的实例对象用于存储工作空间的配置信息,如工作空间的路径等 - 工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中。 - 注意:不存在则读取默认 """ - _config_factory: ConfigBasicFactory = ConfigBasicFactory() + config_factory_cls = ConfigBasicFactory + + def _build_config_factory(self, config_json: Any) -> ConfigBasicFactory: + + _config_factory = self.config_factory_cls() + + if config_json.get("log_verbose"): + _config_factory.log_verbose(config_json.get("log_verbose")) + if config_json.get("DATA_PATH"): + _config_factory.data_path(config_json.get("DATA_PATH")) + if config_json.get("LOG_FORMAT"): + _config_factory.log_format(config_json.get("LOG_FORMAT")) + + return _config_factory + + @classmethod + def get_type(cls) -> str: + return ConfigBasic.class_name() def __init__(self): - self.workspace = os.path.join(os.path.expanduser("~"), ".config", "chatchat/workspace") - if not os.path.exists(self.workspace): - os.makedirs(self.workspace, exist_ok=True) - self.workspace_config = os.path.join(self.workspace, "workspace_config.json") - # 初始化工作空间配置,转换成json格式,实现ConfigBasic的实例化 - - config_json = self._load_config() - - if config_json: - - _config_factory = ConfigBasicFactory() - if config_json.get("log_verbose"): - _config_factory.log_verbose(config_json.get("log_verbose")) - if config_json.get("DATA_PATH"): - _config_factory.data_path(config_json.get("DATA_PATH")) - if config_json.get("LOG_FORMAT"): - _config_factory.log_format(config_json.get("LOG_FORMAT")) - - self._config_factory = _config_factory + super().__init__() def get_config(self) -> ConfigBasic: return self._config_factory.get_config() def set_log_verbose(self, verbose: bool): self._config_factory.log_verbose(verbose) - self._store_config() + self.store_config() def set_data_path(self, path: str): self._config_factory.data_path(path) - self._store_config() + self.store_config() def set_log_format(self, log_format: str): self._config_factory.log_format(log_format) - self._store_config() + self.store_config() def clear(self): logger.info("Clear workspace config.") os.remove(self.workspace_config) - def _load_config(self): - try: - with open(self.workspace_config, "r") as f: - return json.loads(f.read()) - except FileNotFoundError: - return None - def _store_config(self): - with open(self.workspace_config, "w") as f: - config = self._config_factory.get_config() - config_json = { - "log_verbose": config.log_verbose, - "CHATCHAT_ROOT": config.CHATCHAT_ROOT, - "DATA_PATH": config.DATA_PATH, - "LOG_FORMAT": config.LOG_FORMAT - } - f.write(json.dumps(config_json, indent=4, ensure_ascii=False)) - - -config_workspace: ConfigWorkSpace = ConfigWorkSpace() +config_basic_workspace: ConfigBasicWorkSpace = ConfigBasicWorkSpace() diff --git a/libs/chatchat-server/chatchat/configs/_core_config.py b/libs/chatchat-server/chatchat/configs/_core_config.py new file mode 100644 index 00000000..3977af2e --- /dev/null +++ b/libs/chatchat-server/chatchat/configs/_core_config.py @@ -0,0 +1,106 @@ +import os +import json +from abc import abstractmethod, ABC +from dataclasses import dataclass +from pathlib import Path +import logging +from typing import Any, Dict, TypeVar, Generic, Optional, Type + +from dataclasses_json import DataClassJsonMixin +from pydantic import BaseModel + +logger = logging.getLogger() + + +class Config(BaseModel): + @classmethod + @abstractmethod + def class_name(cls) -> str: + """Get class name.""" + + def to_dict(self, **kwargs: Any) -> Dict[str, Any]: + data = self.dict(**kwargs) + data["class_name"] = self.class_name() + return data + + def to_json(self, **kwargs: Any) -> str: + data = self.to_dict(**kwargs) + return json.dumps(data) + + +F = TypeVar("F", bound=Config) + + +@dataclass +class ConfigFactory(Generic[F], DataClassJsonMixin): + """config for ChatChat """ + + @classmethod + @abstractmethod + def get_config(cls) -> F: + raise NotImplementedError + + +CF = TypeVar("CF", bound=ConfigFactory) + + +class ConfigWorkSpace(Generic[CF, F], ABC): + """ + ConfigWorkSpace是一个配置工作空间的抽象类,提供基础的配置信息存储和读取功能。 + 提供ConfigFactory建造方法产生实例。 + 该类的实例对象用于存储工作空间的配置信息,如工作空间的路径等 + 工作空间的配置信息存储在用户的家目录下的.chatchat/workspace/workspace_config.json文件中。 + 注意:不存在则读取默认 + """ + config_factory_cls: Type[CF] + _config_factory: Optional[CF] = None + + def __init__(self): + self.workspace = os.path.join(os.path.expanduser("~"), ".chatchat", "workspace") + if not os.path.exists(self.workspace): + os.makedirs(self.workspace, exist_ok=True) + self.workspace_config = os.path.join(self.workspace, "workspace_config.json") + # 初始化工作空间配置,转换成json格式,实现Config的实例化 + + config_type_json = self._load_config() + if config_type_json is None: + self._config_factory = self._build_config_factory(config_json={}) + self.store_config() + + else: + config_type = config_type_json.get("type", None) + if self.get_type() != config_type: + raise ValueError(f"Config type mismatch: {self.get_type()} != {config_type}") + + config_json = config_type_json.get("config") + self._config_factory = self._build_config_factory(config_json) + + @abstractmethod + def _build_config_factory(self, config_json: Any) -> CF: + raise NotImplementedError + + @classmethod + @abstractmethod + def get_type(cls) -> str: + raise NotImplementedError + + def get_config(self) -> F: + return self._config_factory.get_config() + + def clear(self): + logger.info("Clear workspace config.") + os.remove(self.workspace_config) + + def _load_config(self): + try: + with open(self.workspace_config, "r") as f: + return json.loads(f.read()) + except FileNotFoundError: + return None + + def store_config(self): + logger.info("Store workspace config.") + with open(self.workspace_config, "w") as f: + config_json = self.get_config().to_dict() + config_type_json = {"type": self.get_type(), "config": config_json} + f.write(json.dumps(config_type_json, indent=4, ensure_ascii=False)) diff --git a/libs/chatchat-server/chatchat/configs/_kb_config.py b/libs/chatchat-server/chatchat/configs/_kb_config.py index 40b4626b..a13eb87b 100644 --- a/libs/chatchat-server/chatchat/configs/_kb_config.py +++ b/libs/chatchat-server/chatchat/configs/_kb_config.py @@ -4,10 +4,10 @@ from pathlib import Path import sys sys.path.append(str(Path(__file__).parent)) -from _basic_config import config_workspace +from _basic_config import config_basic_workspace # 用户数据根目录 -DATA_PATH = config_workspace.get_config().DATA_PATH +DATA_PATH = config_basic_workspace.get_config().DATA_PATH # 默认使用的知识库 DEFAULT_KNOWLEDGE_BASE = "samples" diff --git a/libs/chatchat-server/tests/unit_tests/config/test_config.py b/libs/chatchat-server/tests/unit_tests/config/test_config.py index 96fbd5dc..c748b827 100644 --- a/libs/chatchat-server/tests/unit_tests/config/test_config.py +++ b/libs/chatchat-server/tests/unit_tests/config/test_config.py @@ -1,36 +1,18 @@ from pathlib import Path -from chatchat.configs import ConfigBasicFactory, ConfigBasic, ConfigWorkSpace +from chatchat.configs import ConfigBasicFactory, ConfigBasic, ConfigBasicWorkSpace import os -def test_config_factory_def(): - test_config_factory = ConfigBasicFactory() - config: ConfigBasic = test_config_factory.get_config() - assert config is not None - assert config.log_verbose is False - assert config.CHATCHAT_ROOT is not None - assert config.DATA_PATH is not None - assert config.IMG_DIR is not None - assert config.NLTK_DATA_PATH is not None - assert config.LOG_FORMAT is not None - assert config.LOG_PATH is not None - assert config.MEDIA_PATH is not None - - assert os.path.exists(os.path.join(config.MEDIA_PATH, "image")) - assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio")) - assert os.path.exists(os.path.join(config.MEDIA_PATH, "video")) - - -def test_workspace(): - config_workspace = ConfigWorkSpace() - assert config_workspace.get_config() is not None +def test_config_basic_workspace(): + config_basic_workspace: ConfigBasicWorkSpace = ConfigBasicWorkSpace() + assert config_basic_workspace.get_config() is not None base_root = os.path.join(Path(__file__).absolute().parent, "chatchat") - config_workspace.set_data_path(os.path.join(base_root, "data")) - config_workspace.set_log_verbose(True) - config_workspace.set_log_format(" %(message)s") + config_basic_workspace.set_data_path(os.path.join(base_root, "data")) + config_basic_workspace.set_log_verbose(True) + config_basic_workspace.set_log_format(" %(message)s") - config: ConfigBasic = config_workspace.get_config() + config: ConfigBasic = config_basic_workspace.get_config() assert config.log_verbose is True assert config.DATA_PATH == os.path.join(base_root, "data") assert config.IMG_DIR is not None @@ -42,7 +24,7 @@ def test_workspace(): assert os.path.exists(os.path.join(config.MEDIA_PATH, "image")) assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio")) assert os.path.exists(os.path.join(config.MEDIA_PATH, "video")) - config_workspace.clear() + config_basic_workspace.clear() def test_workspace_default():