ConfigWorkSpace是一个配置工作空间的抽象类,提供基础的配置信息存储和读取功能。

This commit is contained in:
glide-the 2024-06-10 22:03:38 +08:00
parent 72b1cab89a
commit 30e27a5320
5 changed files with 198 additions and 116 deletions

View File

@ -112,92 +112,95 @@ def _import_ConfigBasicFactory() -> Any:
return ConfigBasicFactory return ConfigBasicFactory
def _import_ConfigWorkSpace() -> Any: def _import_ConfigBasicWorkSpace() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod") load_mod = basic_config_load.get("load_mod")
ConfigWorkSpace = load_mod(basic_config_load.get("module"), "ConfigWorkSpace") ConfigBasicWorkSpace = load_mod(basic_config_load.get("module"), "ConfigBasicWorkSpace")
return ConfigWorkSpace return ConfigBasicWorkSpace
def _import_config_workspace() -> Any: def _import_config_basic_workspace() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod") load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace")
return config_workspace return config_basic_workspace
def _import_log_verbose() -> Any: def _import_log_verbose() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod") load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace.get_config().log_verbose config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace")
return config_basic_workspace.get_config().log_verbose
def _import_chatchat_root() -> Any: def _import_chatchat_root() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod") load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace.get_config().CHATCHAT_ROOT config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace")
return config_basic_workspace.get_config().CHATCHAT_ROOT
def _import_data_path() -> Any: def _import_data_path() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod") load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace")
return config_workspace.get_config().DATA_PATH return config_basic_workspace.get_config().DATA_PATH
def _import_img_dir() -> Any: def _import_img_dir() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod") load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace")
return config_workspace.get_config().IMG_DIR return config_basic_workspace.get_config().IMG_DIR
def _import_nltk_data_path() -> Any: def _import_nltk_data_path() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod") load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace")
return config_workspace.get_config().NLTK_DATA_PATH return config_basic_workspace.get_config().NLTK_DATA_PATH
def _import_log_format() -> Any: def _import_log_format() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod") load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace")
return config_workspace.get_config().LOG_FORMAT return config_basic_workspace.get_config().LOG_FORMAT
def _import_log_path() -> Any: def _import_log_path() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod") load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace")
return config_workspace.get_config().LOG_PATH return config_basic_workspace.get_config().LOG_PATH
def _import_media_path() -> Any: def _import_media_path() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod") load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace")
return config_workspace.get_config().MEDIA_PATH return config_basic_workspace.get_config().MEDIA_PATH
def _import_base_temp_dir() -> Any: def _import_base_temp_dir() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod") load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") config_basic_workspace = load_mod(basic_config_load.get("module"), "config_basic_workspace")
return config_workspace.get_config().BASE_TEMP_DIR return config_basic_workspace.get_config().BASE_TEMP_DIR
def _import_default_knowledge_base() -> Any: def _import_default_knowledge_base() -> Any:
@ -517,10 +520,10 @@ def __getattr__(name: str) -> Any:
return _import_ConfigBasic() return _import_ConfigBasic()
elif name == "ConfigBasicFactory": elif name == "ConfigBasicFactory":
return _import_ConfigBasicFactory() return _import_ConfigBasicFactory()
elif name == "ConfigWorkSpace": elif name == "ConfigBasicWorkSpace":
return _import_ConfigWorkSpace() return _import_ConfigBasicWorkSpace()
elif name == "config_workspace": elif name == "config_basic_workspace":
return _import_config_workspace() return _import_config_basic_workspace()
elif name == "log_verbose": elif name == "log_verbose":
return _import_log_verbose() return _import_log_verbose()
elif name == "CHATCHAT_ROOT": elif name == "CHATCHAT_ROOT":
@ -621,7 +624,7 @@ VERSION = "v0.3.0-preview"
__all__ = [ __all__ = [
"VERSION", "VERSION",
"config_workspace", "config_basic_workspace",
"log_verbose", "log_verbose",
"CHATCHAT_ROOT", "CHATCHAT_ROOT",
"DATA_PATH", "DATA_PATH",
@ -672,6 +675,6 @@ __all__ = [
"ConfigBasic", "ConfigBasic",
"ConfigBasicFactory", "ConfigBasicFactory",
"ConfigWorkSpace", "ConfigBasicWorkSpace",
] ]

View File

@ -1,36 +1,49 @@
import os import os
import json import json
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
import sys
import logging import logging
from typing import Any, Optional
from chatchat.configs._core_config import CF
sys.path.append(str(Path(__file__).parent))
import _core_config as core_config
logger = logging.getLogger() logger = logging.getLogger()
class ConfigBasic: class ConfigBasic(core_config.Config):
log_verbose: bool log_verbose: Optional[bool] = None
"""是否开启日志详细信息""" """是否开启日志详细信息"""
CHATCHAT_ROOT: str CHATCHAT_ROOT: Optional[str] = None
"""项目根目录""" """项目根目录"""
DATA_PATH: str DATA_PATH: Optional[str] = None
"""用户数据根目录""" """用户数据根目录"""
IMG_DIR: str IMG_DIR: Optional[str] = None
"""项目相关图片""" """项目相关图片"""
NLTK_DATA_PATH: str NLTK_DATA_PATH: Optional[str] = None
"""nltk 模型存储路径""" """nltk 模型存储路径"""
LOG_FORMAT: str LOG_FORMAT: Optional[str] = None
"""日志格式""" """日志格式"""
LOG_PATH: str LOG_PATH: Optional[str] = None
"""日志存储路径""" """日志存储路径"""
MEDIA_PATH: str MEDIA_PATH: Optional[str] = None
"""模型生成内容(图片、视频、音频等)保存位置""" """模型生成内容(图片、视频、音频等)保存位置"""
BASE_TEMP_DIR: str BASE_TEMP_DIR: Optional[str] = None
"""临时文件目录,主要用于文件对话""" """临时文件目录,主要用于文件对话"""
@classmethod
def class_name(cls) -> str:
return cls.__name__
def __str__(self): def __str__(self):
return f"ConfigBasic(log_verbose={self.log_verbose}, CHATCHAT_ROOT={self.CHATCHAT_ROOT}, DATA_PATH={self.DATA_PATH}, IMG_DIR={self.IMG_DIR}, NLTK_DATA_PATH={self.NLTK_DATA_PATH}, LOG_FORMAT={self.LOG_FORMAT}, LOG_PATH={self.LOG_PATH}, MEDIA_PATH={self.MEDIA_PATH}, BASE_TEMP_DIR={self.BASE_TEMP_DIR})" return f"ConfigBasic(log_verbose={self.log_verbose}, CHATCHAT_ROOT={self.CHATCHAT_ROOT}, DATA_PATH={self.DATA_PATH}, IMG_DIR={self.IMG_DIR}, NLTK_DATA_PATH={self.NLTK_DATA_PATH}, LOG_FORMAT={self.LOG_FORMAT}, LOG_PATH={self.LOG_PATH}, MEDIA_PATH={self.MEDIA_PATH}, BASE_TEMP_DIR={self.BASE_TEMP_DIR})"
class ConfigBasicFactory: @dataclass
class ConfigBasicFactory(core_config.ConfigFactory[ConfigBasic]):
"""Basic config for ChatChat """ """Basic config for ChatChat """
def __init__(self): def __init__(self):
@ -109,72 +122,50 @@ class ConfigBasicFactory:
return config return config
class ConfigWorkSpace: class ConfigBasicWorkSpace(core_config.ConfigWorkSpace[ConfigBasicFactory, ConfigBasic]):
""" """
工作空间的配置预设提供ConfigBasic建造方法产生实例 工作空间的配置预设提供ConfigBasic建造方法产生实例
该类的实例对象用于存储工作空间的配置信息如工作空间的路径等
工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中
注意不存在则读取默认
""" """
_config_factory: ConfigBasicFactory = ConfigBasicFactory() config_factory_cls = ConfigBasicFactory
def _build_config_factory(self, config_json: Any) -> ConfigBasicFactory:
_config_factory = self.config_factory_cls()
if config_json.get("log_verbose"):
_config_factory.log_verbose(config_json.get("log_verbose"))
if config_json.get("DATA_PATH"):
_config_factory.data_path(config_json.get("DATA_PATH"))
if config_json.get("LOG_FORMAT"):
_config_factory.log_format(config_json.get("LOG_FORMAT"))
return _config_factory
@classmethod
def get_type(cls) -> str:
return ConfigBasic.class_name()
def __init__(self): def __init__(self):
self.workspace = os.path.join(os.path.expanduser("~"), ".config", "chatchat/workspace") super().__init__()
if not os.path.exists(self.workspace):
os.makedirs(self.workspace, exist_ok=True)
self.workspace_config = os.path.join(self.workspace, "workspace_config.json")
# 初始化工作空间配置转换成json格式实现ConfigBasic的实例化
config_json = self._load_config()
if config_json:
_config_factory = ConfigBasicFactory()
if config_json.get("log_verbose"):
_config_factory.log_verbose(config_json.get("log_verbose"))
if config_json.get("DATA_PATH"):
_config_factory.data_path(config_json.get("DATA_PATH"))
if config_json.get("LOG_FORMAT"):
_config_factory.log_format(config_json.get("LOG_FORMAT"))
self._config_factory = _config_factory
def get_config(self) -> ConfigBasic: def get_config(self) -> ConfigBasic:
return self._config_factory.get_config() return self._config_factory.get_config()
def set_log_verbose(self, verbose: bool): def set_log_verbose(self, verbose: bool):
self._config_factory.log_verbose(verbose) self._config_factory.log_verbose(verbose)
self._store_config() self.store_config()
def set_data_path(self, path: str): def set_data_path(self, path: str):
self._config_factory.data_path(path) self._config_factory.data_path(path)
self._store_config() self.store_config()
def set_log_format(self, log_format: str): def set_log_format(self, log_format: str):
self._config_factory.log_format(log_format) self._config_factory.log_format(log_format)
self._store_config() self.store_config()
def clear(self): def clear(self):
logger.info("Clear workspace config.") logger.info("Clear workspace config.")
os.remove(self.workspace_config) os.remove(self.workspace_config)
def _load_config(self):
try:
with open(self.workspace_config, "r") as f:
return json.loads(f.read())
except FileNotFoundError:
return None
def _store_config(self): config_basic_workspace: ConfigBasicWorkSpace = ConfigBasicWorkSpace()
with open(self.workspace_config, "w") as f:
config = self._config_factory.get_config()
config_json = {
"log_verbose": config.log_verbose,
"CHATCHAT_ROOT": config.CHATCHAT_ROOT,
"DATA_PATH": config.DATA_PATH,
"LOG_FORMAT": config.LOG_FORMAT
}
f.write(json.dumps(config_json, indent=4, ensure_ascii=False))
config_workspace: ConfigWorkSpace = ConfigWorkSpace()

View File

@ -0,0 +1,106 @@
import os
import json
from abc import abstractmethod, ABC
from dataclasses import dataclass
from pathlib import Path
import logging
from typing import Any, Dict, TypeVar, Generic, Optional, Type
from dataclasses_json import DataClassJsonMixin
from pydantic import BaseModel
logger = logging.getLogger()
class Config(BaseModel):
@classmethod
@abstractmethod
def class_name(cls) -> str:
"""Get class name."""
def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
data = self.dict(**kwargs)
data["class_name"] = self.class_name()
return data
def to_json(self, **kwargs: Any) -> str:
data = self.to_dict(**kwargs)
return json.dumps(data)
F = TypeVar("F", bound=Config)
@dataclass
class ConfigFactory(Generic[F], DataClassJsonMixin):
"""config for ChatChat """
@classmethod
@abstractmethod
def get_config(cls) -> F:
raise NotImplementedError
CF = TypeVar("CF", bound=ConfigFactory)
class ConfigWorkSpace(Generic[CF, F], ABC):
"""
ConfigWorkSpace是一个配置工作空间的抽象类提供基础的配置信息存储和读取功能
提供ConfigFactory建造方法产生实例
该类的实例对象用于存储工作空间的配置信息如工作空间的路径等
工作空间的配置信息存储在用户的家目录下的.chatchat/workspace/workspace_config.json文件中
注意不存在则读取默认
"""
config_factory_cls: Type[CF]
_config_factory: Optional[CF] = None
def __init__(self):
self.workspace = os.path.join(os.path.expanduser("~"), ".chatchat", "workspace")
if not os.path.exists(self.workspace):
os.makedirs(self.workspace, exist_ok=True)
self.workspace_config = os.path.join(self.workspace, "workspace_config.json")
# 初始化工作空间配置转换成json格式实现Config的实例化
config_type_json = self._load_config()
if config_type_json is None:
self._config_factory = self._build_config_factory(config_json={})
self.store_config()
else:
config_type = config_type_json.get("type", None)
if self.get_type() != config_type:
raise ValueError(f"Config type mismatch: {self.get_type()} != {config_type}")
config_json = config_type_json.get("config")
self._config_factory = self._build_config_factory(config_json)
@abstractmethod
def _build_config_factory(self, config_json: Any) -> CF:
raise NotImplementedError
@classmethod
@abstractmethod
def get_type(cls) -> str:
raise NotImplementedError
def get_config(self) -> F:
return self._config_factory.get_config()
def clear(self):
logger.info("Clear workspace config.")
os.remove(self.workspace_config)
def _load_config(self):
try:
with open(self.workspace_config, "r") as f:
return json.loads(f.read())
except FileNotFoundError:
return None
def store_config(self):
logger.info("Store workspace config.")
with open(self.workspace_config, "w") as f:
config_json = self.get_config().to_dict()
config_type_json = {"type": self.get_type(), "config": config_json}
f.write(json.dumps(config_type_json, indent=4, ensure_ascii=False))

View File

@ -4,10 +4,10 @@ from pathlib import Path
import sys import sys
sys.path.append(str(Path(__file__).parent)) sys.path.append(str(Path(__file__).parent))
from _basic_config import config_workspace from _basic_config import config_basic_workspace
# 用户数据根目录 # 用户数据根目录
DATA_PATH = config_workspace.get_config().DATA_PATH DATA_PATH = config_basic_workspace.get_config().DATA_PATH
# 默认使用的知识库 # 默认使用的知识库
DEFAULT_KNOWLEDGE_BASE = "samples" DEFAULT_KNOWLEDGE_BASE = "samples"

View File

@ -1,36 +1,18 @@
from pathlib import Path from pathlib import Path
from chatchat.configs import ConfigBasicFactory, ConfigBasic, ConfigWorkSpace from chatchat.configs import ConfigBasicFactory, ConfigBasic, ConfigBasicWorkSpace
import os import os
def test_config_factory_def(): def test_config_basic_workspace():
test_config_factory = ConfigBasicFactory() config_basic_workspace: ConfigBasicWorkSpace = ConfigBasicWorkSpace()
config: ConfigBasic = test_config_factory.get_config() assert config_basic_workspace.get_config() is not None
assert config is not None
assert config.log_verbose is False
assert config.CHATCHAT_ROOT is not None
assert config.DATA_PATH is not None
assert config.IMG_DIR is not None
assert config.NLTK_DATA_PATH is not None
assert config.LOG_FORMAT is not None
assert config.LOG_PATH is not None
assert config.MEDIA_PATH is not None
assert os.path.exists(os.path.join(config.MEDIA_PATH, "image"))
assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio"))
assert os.path.exists(os.path.join(config.MEDIA_PATH, "video"))
def test_workspace():
config_workspace = ConfigWorkSpace()
assert config_workspace.get_config() is not None
base_root = os.path.join(Path(__file__).absolute().parent, "chatchat") base_root = os.path.join(Path(__file__).absolute().parent, "chatchat")
config_workspace.set_data_path(os.path.join(base_root, "data")) config_basic_workspace.set_data_path(os.path.join(base_root, "data"))
config_workspace.set_log_verbose(True) config_basic_workspace.set_log_verbose(True)
config_workspace.set_log_format(" %(message)s") config_basic_workspace.set_log_format(" %(message)s")
config: ConfigBasic = config_workspace.get_config() config: ConfigBasic = config_basic_workspace.get_config()
assert config.log_verbose is True assert config.log_verbose is True
assert config.DATA_PATH == os.path.join(base_root, "data") assert config.DATA_PATH == os.path.join(base_root, "data")
assert config.IMG_DIR is not None assert config.IMG_DIR is not None
@ -42,7 +24,7 @@ def test_workspace():
assert os.path.exists(os.path.join(config.MEDIA_PATH, "image")) assert os.path.exists(os.path.join(config.MEDIA_PATH, "image"))
assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio")) assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio"))
assert os.path.exists(os.path.join(config.MEDIA_PATH, "video")) assert os.path.exists(os.path.join(config.MEDIA_PATH, "video"))
config_workspace.clear() config_basic_workspace.clear()
def test_workspace_default(): def test_workspace_default():