ConfigWorkSpace是一个配置工作空间的抽象类,提供基础的配置信息存储和读取功能。

This commit is contained in:
glide-the 2024-06-10 22:03:38 +08:00
parent 72b1cab89a
commit 30e27a5320
5 changed files with 198 additions and 116 deletions

View File

@ -112,92 +112,95 @@ def _import_ConfigBasicFactory() -> Any:
return ConfigBasicFactory
def _import_ConfigWorkSpace() -> Any:
def _import_ConfigBasicWorkSpace() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
ConfigWorkSpace = load_mod(basic_config_load.get("module"), "ConfigWorkSpace")
ConfigBasicWorkSpace = load_mod(basic_config_load.get("module"), "ConfigBasicWorkSpace")
return ConfigWorkSpace
return ConfigBasicWorkSpace
def _import_config_workspace() -> Any:
def _import_config_basic_workspace() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace
config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace")
return config_basic_workspace
def _import_log_verbose() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace.get_config().log_verbose
config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace")
return config_basic_workspace.get_config().log_verbose
def _import_chatchat_root() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace.get_config().CHATCHAT_ROOT
config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace")
return config_basic_workspace.get_config().CHATCHAT_ROOT
def _import_data_path() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace.get_config().DATA_PATH
config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace")
return config_basic_workspace.get_config().DATA_PATH
def _import_img_dir() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace")
return config_workspace.get_config().IMG_DIR
return config_basic_workspace.get_config().IMG_DIR
def _import_nltk_data_path() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace")
return config_workspace.get_config().NLTK_DATA_PATH
return config_basic_workspace.get_config().NLTK_DATA_PATH
def _import_log_format() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace")
return config_workspace.get_config().LOG_FORMAT
return config_basic_workspace.get_config().LOG_FORMAT
def _import_log_path() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace")
return config_workspace.get_config().LOG_PATH
return config_basic_workspace.get_config().LOG_PATH
def _import_media_path() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace")
return config_workspace.get_config().MEDIA_PATH
return config_basic_workspace.get_config().MEDIA_PATH
def _import_base_temp_dir() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace.get_config().BASE_TEMP_DIR
config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace")
return config_basic_workspace.get_config().BASE_TEMP_DIR
def _import_default_knowledge_base() -> Any:
@ -517,10 +520,10 @@ def __getattr__(name: str) -> Any:
return _import_ConfigBasic()
elif name == "ConfigBasicFactory":
return _import_ConfigBasicFactory()
elif name == "ConfigWorkSpace":
return _import_ConfigWorkSpace()
elif name == "config_workspace":
return _import_config_workspace()
elif name == "ConfigBasicWorkSpace":
return _import_ConfigBasicWorkSpace()
elif name == "config_basic_workspace":
return _import_config_basic_workspace()
elif name == "log_verbose":
return _import_log_verbose()
elif name == "CHATCHAT_ROOT":
@ -621,7 +624,7 @@ VERSION = "v0.3.0-preview"
__all__ = [
"VERSION",
"config_workspace",
"config_basic_workspace",
"log_verbose",
"CHATCHAT_ROOT",
"DATA_PATH",
@ -672,6 +675,6 @@ __all__ = [
"ConfigBasic",
"ConfigBasicFactory",
"ConfigWorkSpace",
"ConfigBasicWorkSpace",
]

View File

@ -1,36 +1,49 @@
import os
import json
from dataclasses import dataclass
from pathlib import Path
import sys
import logging
from typing import Any, Optional
from chatchat.configs._core_config import CF
sys.path.append(str(Path(__file__).parent))
import _core_config as core_config
logger = logging.getLogger()
class ConfigBasic:
log_verbose: bool
class ConfigBasic(core_config.Config):
log_verbose: Optional[bool] = None
"""是否开启日志详细信息"""
CHATCHAT_ROOT: str
CHATCHAT_ROOT: Optional[str] = None
"""项目根目录"""
DATA_PATH: str
DATA_PATH: Optional[str] = None
"""用户数据根目录"""
IMG_DIR: str
IMG_DIR: Optional[str] = None
"""项目相关图片"""
NLTK_DATA_PATH: str
NLTK_DATA_PATH: Optional[str] = None
"""nltk 模型存储路径"""
LOG_FORMAT: str
LOG_FORMAT: Optional[str] = None
"""日志格式"""
LOG_PATH: str
LOG_PATH: Optional[str] = None
"""日志存储路径"""
MEDIA_PATH: str
MEDIA_PATH: Optional[str] = None
"""模型生成内容(图片、视频、音频等)保存位置"""
BASE_TEMP_DIR: str
BASE_TEMP_DIR: Optional[str] = None
"""临时文件目录,主要用于文件对话"""
@classmethod
def class_name(cls) -> str:
return cls.__name__
def __str__(self):
return f"ConfigBasic(log_verbose={self.log_verbose}, CHATCHAT_ROOT={self.CHATCHAT_ROOT}, DATA_PATH={self.DATA_PATH}, IMG_DIR={self.IMG_DIR}, NLTK_DATA_PATH={self.NLTK_DATA_PATH}, LOG_FORMAT={self.LOG_FORMAT}, LOG_PATH={self.LOG_PATH}, MEDIA_PATH={self.MEDIA_PATH}, BASE_TEMP_DIR={self.BASE_TEMP_DIR})"
class ConfigBasicFactory:
@dataclass
class ConfigBasicFactory(core_config.ConfigFactory[ConfigBasic]):
"""Basic config for ChatChat """
def __init__(self):
@ -109,72 +122,50 @@ class ConfigBasicFactory:
return config
class ConfigWorkSpace:
class ConfigBasicWorkSpace(core_config.ConfigWorkSpace[ConfigBasicFactory, ConfigBasic]):
"""
工作空间的配置预设提供ConfigBasic建造方法产生实例
该类的实例对象用于存储工作空间的配置信息如工作空间的路径等
工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中
注意不存在则读取默认
"""
_config_factory: ConfigBasicFactory = ConfigBasicFactory()
config_factory_cls = ConfigBasicFactory
def _build_config_factory(self, config_json: Any) -> ConfigBasicFactory:
_config_factory = self.config_factory_cls()
if config_json.get("log_verbose"):
_config_factory.log_verbose(config_json.get("log_verbose"))
if config_json.get("DATA_PATH"):
_config_factory.data_path(config_json.get("DATA_PATH"))
if config_json.get("LOG_FORMAT"):
_config_factory.log_format(config_json.get("LOG_FORMAT"))
return _config_factory
@classmethod
def get_type(cls) -> str:
return ConfigBasic.class_name()
def __init__(self):
self.workspace = os.path.join(os.path.expanduser("~"), ".config", "chatchat/workspace")
if not os.path.exists(self.workspace):
os.makedirs(self.workspace, exist_ok=True)
self.workspace_config = os.path.join(self.workspace, "workspace_config.json")
# 初始化工作空间配置转换成json格式实现ConfigBasic的实例化
config_json = self._load_config()
if config_json:
_config_factory = ConfigBasicFactory()
if config_json.get("log_verbose"):
_config_factory.log_verbose(config_json.get("log_verbose"))
if config_json.get("DATA_PATH"):
_config_factory.data_path(config_json.get("DATA_PATH"))
if config_json.get("LOG_FORMAT"):
_config_factory.log_format(config_json.get("LOG_FORMAT"))
self._config_factory = _config_factory
super().__init__()
def get_config(self) -> ConfigBasic:
return self._config_factory.get_config()
def set_log_verbose(self, verbose: bool):
self._config_factory.log_verbose(verbose)
self._store_config()
self.store_config()
def set_data_path(self, path: str):
self._config_factory.data_path(path)
self._store_config()
self.store_config()
def set_log_format(self, log_format: str):
self._config_factory.log_format(log_format)
self._store_config()
self.store_config()
def clear(self):
logger.info("Clear workspace config.")
os.remove(self.workspace_config)
def _load_config(self):
try:
with open(self.workspace_config, "r") as f:
return json.loads(f.read())
except FileNotFoundError:
return None
def _store_config(self):
with open(self.workspace_config, "w") as f:
config = self._config_factory.get_config()
config_json = {
"log_verbose": config.log_verbose,
"CHATCHAT_ROOT": config.CHATCHAT_ROOT,
"DATA_PATH": config.DATA_PATH,
"LOG_FORMAT": config.LOG_FORMAT
}
f.write(json.dumps(config_json, indent=4, ensure_ascii=False))
config_workspace: ConfigWorkSpace = ConfigWorkSpace()
config_basic_workspace: ConfigBasicWorkSpace = ConfigBasicWorkSpace()

View File

@ -0,0 +1,106 @@
import os
import json
from abc import abstractmethod, ABC
from dataclasses import dataclass
from pathlib import Path
import logging
from typing import Any, Dict, TypeVar, Generic, Optional, Type
from dataclasses_json import DataClassJsonMixin
from pydantic import BaseModel
logger = logging.getLogger()
class Config(BaseModel):
@classmethod
@abstractmethod
def class_name(cls) -> str:
"""Get class name."""
def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
data = self.dict(**kwargs)
data["class_name"] = self.class_name()
return data
def to_json(self, **kwargs: Any) -> str:
data = self.to_dict(**kwargs)
return json.dumps(data)
F = TypeVar("F", bound=Config)
@dataclass
class ConfigFactory(Generic[F], DataClassJsonMixin):
"""config for ChatChat """
@classmethod
@abstractmethod
def get_config(cls) -> F:
raise NotImplementedError
CF = TypeVar("CF", bound=ConfigFactory)
class ConfigWorkSpace(Generic[CF, F], ABC):
"""
ConfigWorkSpace是一个配置工作空间的抽象类提供基础的配置信息存储和读取功能
提供ConfigFactory建造方法产生实例
该类的实例对象用于存储工作空间的配置信息如工作空间的路径等
工作空间的配置信息存储在用户的家目录下的.chatchat/workspace/workspace_config.json文件中
注意不存在则读取默认
"""
config_factory_cls: Type[CF]
_config_factory: Optional[CF] = None
def __init__(self):
self.workspace = os.path.join(os.path.expanduser("~"), ".chatchat", "workspace")
if not os.path.exists(self.workspace):
os.makedirs(self.workspace, exist_ok=True)
self.workspace_config = os.path.join(self.workspace, "workspace_config.json")
# 初始化工作空间配置转换成json格式实现Config的实例化
config_type_json = self._load_config()
if config_type_json is None:
self._config_factory = self._build_config_factory(config_json={})
self.store_config()
else:
config_type = config_type_json.get("type", None)
if self.get_type() != config_type:
raise ValueError(f"Config type mismatch: {self.get_type()} != {config_type}")
config_json = config_type_json.get("config")
self._config_factory = self._build_config_factory(config_json)
@abstractmethod
def _build_config_factory(self, config_json: Any) -> CF:
raise NotImplementedError
@classmethod
@abstractmethod
def get_type(cls) -> str:
raise NotImplementedError
def get_config(self) -> F:
return self._config_factory.get_config()
def clear(self):
logger.info("Clear workspace config.")
os.remove(self.workspace_config)
def _load_config(self):
try:
with open(self.workspace_config, "r") as f:
return json.loads(f.read())
except FileNotFoundError:
return None
def store_config(self):
logger.info("Store workspace config.")
with open(self.workspace_config, "w") as f:
config_json = self.get_config().to_dict()
config_type_json = {"type": self.get_type(), "config": config_json}
f.write(json.dumps(config_type_json, indent=4, ensure_ascii=False))

View File

@ -4,10 +4,10 @@ from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent))
from _basic_config import config_workspace
from _basic_config import config_basic_workspace
# 用户数据根目录
DATA_PATH = config_workspace.get_config().DATA_PATH
DATA_PATH = config_basic_workspace.get_config().DATA_PATH
# 默认使用的知识库
DEFAULT_KNOWLEDGE_BASE = "samples"

View File

@ -1,36 +1,18 @@
from pathlib import Path
from chatchat.configs import ConfigBasicFactory, ConfigBasic, ConfigWorkSpace
from chatchat.configs import ConfigBasicFactory, ConfigBasic, ConfigBasicWorkSpace
import os
def test_config_factory_def():
test_config_factory = ConfigBasicFactory()
config: ConfigBasic = test_config_factory.get_config()
assert config is not None
assert config.log_verbose is False
assert config.CHATCHAT_ROOT is not None
assert config.DATA_PATH is not None
assert config.IMG_DIR is not None
assert config.NLTK_DATA_PATH is not None
assert config.LOG_FORMAT is not None
assert config.LOG_PATH is not None
assert config.MEDIA_PATH is not None
assert os.path.exists(os.path.join(config.MEDIA_PATH, "image"))
assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio"))
assert os.path.exists(os.path.join(config.MEDIA_PATH, "video"))
def test_workspace():
config_workspace = ConfigWorkSpace()
assert config_workspace.get_config() is not None
def test_config_basic_workspace():
config_basic_workspace: ConfigBasicWorkSpace = ConfigBasicWorkSpace()
assert config_basic_workspace.get_config() is not None
base_root = os.path.join(Path(__file__).absolute().parent, "chatchat")
config_workspace.set_data_path(os.path.join(base_root, "data"))
config_workspace.set_log_verbose(True)
config_workspace.set_log_format(" %(message)s")
config_basic_workspace.set_data_path(os.path.join(base_root, "data"))
config_basic_workspace.set_log_verbose(True)
config_basic_workspace.set_log_format(" %(message)s")
config: ConfigBasic = config_workspace.get_config()
config: ConfigBasic = config_basic_workspace.get_config()
assert config.log_verbose is True
assert config.DATA_PATH == os.path.join(base_root, "data")
assert config.IMG_DIR is not None
@ -42,7 +24,7 @@ def test_workspace():
assert os.path.exists(os.path.join(config.MEDIA_PATH, "image"))
assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio"))
assert os.path.exists(os.path.join(config.MEDIA_PATH, "video"))
config_workspace.clear()
config_basic_workspace.clear()
def test_workspace_default():