pytest config

This commit is contained in:
glide-the 2024-06-09 13:41:57 +08:00
parent b1c5bf9c94
commit e4de0ceabc
3 changed files with 275 additions and 2 deletions

View File

@ -1,13 +1,171 @@
import importlib
import importlib.util
import os
from pathlib import Path
from typing import Dict, Any
import json
import logging
logger = logging.getLogger()
class ConfigBasic:
log_verbose: bool
CHATCHAT_ROOT: str
DATA_PATH: str
IMG_DIR: str
NLTK_DATA_PATH: str
LOG_FORMAT: str
LOG_PATH: str
MEDIA_PATH: str
BASE_TEMP_DIR: str
class ConfigBasicFactory:
"""Basic config for ChatChat """
def __init__(self):
# 日志格式
self.LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
logging.basicConfig(format=self.LOG_FORMAT)
self.LOG_VERBOSE = False
self.CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent)
# 用户数据根目录
self.DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data")
self._DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data")
if not os.path.exists(self._DATA_PATH):
os.mkdir(self.DATA_PATH)
self._init_data_dir()
# 项目相关图片
self.IMG_DIR = os.path.join(self.CHATCHAT_ROOT, "img")
if not os.path.exists(self.IMG_DIR):
os.mkdir(self.IMG_DIR)
def log_verbose(self, verbose: bool):
self.LOG_VERBOSE = verbose
def chatchat_root(self, root: str):
self.CHATCHAT_ROOT = root
def data_path(self, path: str):
self.DATA_PATH = path
if not os.path.exists(self.DATA_PATH):
os.mkdir(self.DATA_PATH)
# 复制_DATA_PATH数据到DATA_PATH
os.system(f"cp -r {self._DATA_PATH} {self.DATA_PATH}")
self._init_data_dir()
def log_format(self, log_format: str):
self.LOG_FORMAT = log_format
logging.basicConfig(format=self.LOG_FORMAT)
def _init_data_dir(self):
logger.info(f"Init data dir: {self.DATA_PATH}")
# nltk 模型存储路径
self.NLTK_DATA_PATH = os.path.join(self.DATA_PATH, "nltk_data")
import nltk
nltk.data.path = [self.NLTK_DATA_PATH] + nltk.data.path
# 日志存储路径
self.LOG_PATH = os.path.join(self.DATA_PATH, "logs")
if not os.path.exists(self.LOG_PATH):
os.mkdir(self.LOG_PATH)
# 模型生成内容(图片、视频、音频等)保存位置
self.MEDIA_PATH = os.path.join(self.DATA_PATH, "media")
if not os.path.exists(self.MEDIA_PATH):
os.mkdir(self.MEDIA_PATH)
os.mkdir(os.path.join(self.MEDIA_PATH, "image"))
os.mkdir(os.path.join(self.MEDIA_PATH, "audio"))
os.mkdir(os.path.join(self.MEDIA_PATH, "video"))
# 临时文件目录,主要用于文件对话
self.BASE_TEMP_DIR = os.path.join(self.DATA_PATH, "temp")
if not os.path.exists(self.BASE_TEMP_DIR):
os.mkdir(self.BASE_TEMP_DIR)
logger.info(f"Init data dir: {self.DATA_PATH} success.")
def get_config(self) -> ConfigBasic:
config = ConfigBasic()
config.log_verbose = self.LOG_VERBOSE
config.CHATCHAT_ROOT = self.CHATCHAT_ROOT
config.DATA_PATH = self.DATA_PATH
config.IMG_DIR = self.IMG_DIR
config.NLTK_DATA_PATH = self.NLTK_DATA_PATH
config.LOG_FORMAT = self.LOG_FORMAT
config.LOG_PATH = self.LOG_PATH
config.MEDIA_PATH = self.MEDIA_PATH
config.BASE_TEMP_DIR = self.BASE_TEMP_DIR
return config
class ConfigWorkSpace:
"""
工作空间的配置预设提供ConfigBasic建造方法产生实例
该类的实例对象用于存储工作空间的配置信息如工作空间的路径等
工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中
注意不存在则读取默认
"""
_config_factory: ConfigBasicFactory = ConfigBasicFactory()
def __init__(self):
self.workspace = os.path.join(os.path.expanduser("~"), ".config", "chatchat/workspace")
if not os.path.exists(self.workspace):
os.makedirs(self.workspace, exist_ok=True)
self.workspace_config = os.path.join(self.workspace, "workspace_config.json")
# 初始化工作空间配置转换成json格式实现ConfigBasic的实例化
with open(self.workspace_config, "w") as f:
config_json = json.loads(f.read())
if config_json:
_config_factory = ConfigBasicFactory()
if config_json.get("log_verbose"):
_config_factory.log_verbose(config_json.get("log_verbose"))
if config_json.get("CHATCHAT_ROOT"):
_config_factory.chatchat_root(config_json.get("CHATCHAT_ROOT"))
if config_json.get("DATA_PATH"):
_config_factory.data_path(config_json.get("DATA_PATH"))
if config_json.get("LOG_FORMAT"):
_config_factory.log_format(config_json.get("LOG_FORMAT"))
self._config_factory = _config_factory
def get_config(self) -> ConfigBasic:
return self._config_factory.get_config()
def set_log_verbose(self, verbose: bool):
self._config_factory.log_verbose(verbose)
self._store_config()
def set_chatchat_root(self, root: str):
self._config_factory.chatchat_root(root)
self._store_config()
def set_data_path(self, path: str):
self._config_factory.data_path(path)
self._store_config()
def set_log_format(self, log_format: str):
self._config_factory.log_format(log_format)
self._store_config()
def _store_config(self):
with open(self.workspace_config, "w") as f:
config = self._config_factory.get_config()
config_json = {
"log_verbose": config.log_verbose,
"CHATCHAT_ROOT": config.CHATCHAT_ROOT,
"DATA_PATH": config.DATA_PATH,
"LOG_FORMAT": config.LOG_FORMAT
}
f.write(json.dumps(config_json, indent=4, ensure_ascii=False))
def _load_mod(mod, attr):
attr_cfg = None
for name, obj in vars(mod).items():
@ -38,7 +196,6 @@ def _import_config_mod_load(import_config_mod: str) -> Dict:
)
user_import = False
if user_import:
# Dynamic loading {config}.py file
py_path = os.path.join(user_config_path, import_config_mod + ".py")
spec = importlib.util.spec_from_file_location(
@ -69,7 +226,7 @@ def _import_config_mod_load(import_config_mod: str) -> Dict:
)
raise RuntimeError(f"Failed to load user config from {user_config_path}")
# 当前文件路径
py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py")
py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py")
spec = importlib.util.spec_from_file_location(f"*",
py_path)
@ -118,6 +275,7 @@ def _import_data_path() -> Any:
return DATA_PATH
def _import_img_dir() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
@ -285,6 +443,7 @@ def _import_db_root_path() -> Any:
return DB_ROOT_PATH
def _import_sqlalchemy_database_uri() -> Any:
kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
load_mod = kb_config_load.get("load_mod")
@ -627,4 +786,8 @@ __all__ = [
"API_SERVER",
"ConfigBasic",
"ConfigBasicFactory",
"ConfigWorkSpace",
]

View File

@ -0,0 +1,90 @@
"""Configuration for unit tests."""
import logging
from importlib import util
from typing import Dict, List, Sequence
import pytest
from pytest import Config, Function, Parser
def pytest_addoption(parser: Parser) -> None:
"""Add custom command line options to pytest."""
parser.addoption(
"--only-extended",
action="store_true",
help="Only run extended tests. Does not allow skipping any extended tests.",
)
parser.addoption(
"--only-core",
action="store_true",
help="Only run core tests. Never runs any extended tests.",
)
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
"""Add implementations for handling custom markers.
At the moment, this adds support for a custom `requires` marker.
The `requires` marker is used to denote tests that require one or more packages
to be installed to run. If the package is not installed, the test is skipped.
The `requires` marker syntax is:
.. code-block:: python
@pytest.mark.requires("package1", "package2")
def test_something():
...
"""
# Mapping from the name of a package to whether it is installed or not.
# Used to avoid repeated calls to `util.find_spec`
required_pkgs_info: Dict[str, bool] = {}
only_extended = config.getoption("--only-extended") or False
only_core = config.getoption("--only-core") or False
if only_extended and only_core:
raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
for item in items:
requires_marker = item.get_closest_marker("requires")
if requires_marker is not None:
if only_core:
item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
continue
# Iterate through the list of required packages
required_pkgs = requires_marker.args
for pkg in required_pkgs:
# If we haven't yet checked whether the pkg is installed
# let's check it and store the result.
if pkg not in required_pkgs_info:
try:
installed = util.find_spec(pkg) is not None
except Exception:
installed = False
required_pkgs_info[pkg] = installed
if not required_pkgs_info[pkg]:
if only_extended:
pytest.fail(
f"Package `{pkg}` is not installed but is required for "
f"extended tests. Please install the given package and "
f"try again.",
)
else:
# If the package is not installed, we immediately break
# and mark the test as skipped.
item.add_marker(
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
)
break
else:
if only_extended:
item.add_marker(
pytest.mark.skip(reason="Skipping not an extended test.")
)

View File

@ -0,0 +1,20 @@
from chatchat.configs import ConfigBasicFactory, ConfigBasic
import os
def test_config_factory_def():
test_config_factory = ConfigBasicFactory()
config: ConfigBasic = test_config_factory.get_config()
assert config is not None
assert config.log_verbose is False
assert config.CHATCHAT_ROOT is not None
assert config.DATA_PATH is not None
assert config.IMG_DIR is not None
assert config.NLTK_DATA_PATH is not None
assert config.LOG_FORMAT is not None
assert config.LOG_PATH is not None
assert config.MEDIA_PATH is not None
assert os.path.exists(os.path.join(config.MEDIA_PATH, "image"))
assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio"))
assert os.path.exists(os.path.join(config.MEDIA_PATH, "video"))