pytest config

This commit is contained in:
glide-the 2024-06-09 13:41:57 +08:00
parent b1c5bf9c94
commit e4de0ceabc
3 changed files with 275 additions and 2 deletions

View File

@ -1,13 +1,171 @@
import importlib import importlib
import importlib.util import importlib.util
import os import os
from pathlib import Path
from typing import Dict, Any from typing import Dict, Any
import json
import logging import logging
logger = logging.getLogger() logger = logging.getLogger()
class ConfigBasic:
log_verbose: bool
CHATCHAT_ROOT: str
DATA_PATH: str
IMG_DIR: str
NLTK_DATA_PATH: str
LOG_FORMAT: str
LOG_PATH: str
MEDIA_PATH: str
BASE_TEMP_DIR: str
class ConfigBasicFactory:
"""Basic config for ChatChat """
def __init__(self):
# 日志格式
self.LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
logging.basicConfig(format=self.LOG_FORMAT)
self.LOG_VERBOSE = False
self.CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent)
# 用户数据根目录
self.DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data")
self._DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data")
if not os.path.exists(self._DATA_PATH):
os.mkdir(self.DATA_PATH)
self._init_data_dir()
# 项目相关图片
self.IMG_DIR = os.path.join(self.CHATCHAT_ROOT, "img")
if not os.path.exists(self.IMG_DIR):
os.mkdir(self.IMG_DIR)
def log_verbose(self, verbose: bool):
self.LOG_VERBOSE = verbose
def chatchat_root(self, root: str):
self.CHATCHAT_ROOT = root
def data_path(self, path: str):
self.DATA_PATH = path
if not os.path.exists(self.DATA_PATH):
os.mkdir(self.DATA_PATH)
# 复制_DATA_PATH数据到DATA_PATH
os.system(f"cp -r {self._DATA_PATH} {self.DATA_PATH}")
self._init_data_dir()
def log_format(self, log_format: str):
self.LOG_FORMAT = log_format
logging.basicConfig(format=self.LOG_FORMAT)
def _init_data_dir(self):
logger.info(f"Init data dir: {self.DATA_PATH}")
# nltk 模型存储路径
self.NLTK_DATA_PATH = os.path.join(self.DATA_PATH, "nltk_data")
import nltk
nltk.data.path = [self.NLTK_DATA_PATH] + nltk.data.path
# 日志存储路径
self.LOG_PATH = os.path.join(self.DATA_PATH, "logs")
if not os.path.exists(self.LOG_PATH):
os.mkdir(self.LOG_PATH)
# 模型生成内容(图片、视频、音频等)保存位置
self.MEDIA_PATH = os.path.join(self.DATA_PATH, "media")
if not os.path.exists(self.MEDIA_PATH):
os.mkdir(self.MEDIA_PATH)
os.mkdir(os.path.join(self.MEDIA_PATH, "image"))
os.mkdir(os.path.join(self.MEDIA_PATH, "audio"))
os.mkdir(os.path.join(self.MEDIA_PATH, "video"))
# 临时文件目录,主要用于文件对话
self.BASE_TEMP_DIR = os.path.join(self.DATA_PATH, "temp")
if not os.path.exists(self.BASE_TEMP_DIR):
os.mkdir(self.BASE_TEMP_DIR)
logger.info(f"Init data dir: {self.DATA_PATH} success.")
def get_config(self) -> ConfigBasic:
config = ConfigBasic()
config.log_verbose = self.LOG_VERBOSE
config.CHATCHAT_ROOT = self.CHATCHAT_ROOT
config.DATA_PATH = self.DATA_PATH
config.IMG_DIR = self.IMG_DIR
config.NLTK_DATA_PATH = self.NLTK_DATA_PATH
config.LOG_FORMAT = self.LOG_FORMAT
config.LOG_PATH = self.LOG_PATH
config.MEDIA_PATH = self.MEDIA_PATH
config.BASE_TEMP_DIR = self.BASE_TEMP_DIR
return config
class ConfigWorkSpace:
"""
工作空间的配置预设提供ConfigBasic建造方法产生实例
该类的实例对象用于存储工作空间的配置信息如工作空间的路径等
工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中
注意不存在则读取默认
"""
_config_factory: ConfigBasicFactory = ConfigBasicFactory()
def __init__(self):
self.workspace = os.path.join(os.path.expanduser("~"), ".config", "chatchat/workspace")
if not os.path.exists(self.workspace):
os.makedirs(self.workspace, exist_ok=True)
self.workspace_config = os.path.join(self.workspace, "workspace_config.json")
# 初始化工作空间配置转换成json格式实现ConfigBasic的实例化
with open(self.workspace_config, "w") as f:
config_json = json.loads(f.read())
if config_json:
_config_factory = ConfigBasicFactory()
if config_json.get("log_verbose"):
_config_factory.log_verbose(config_json.get("log_verbose"))
if config_json.get("CHATCHAT_ROOT"):
_config_factory.chatchat_root(config_json.get("CHATCHAT_ROOT"))
if config_json.get("DATA_PATH"):
_config_factory.data_path(config_json.get("DATA_PATH"))
if config_json.get("LOG_FORMAT"):
_config_factory.log_format(config_json.get("LOG_FORMAT"))
self._config_factory = _config_factory
def get_config(self) -> ConfigBasic:
return self._config_factory.get_config()
def set_log_verbose(self, verbose: bool):
self._config_factory.log_verbose(verbose)
self._store_config()
def set_chatchat_root(self, root: str):
self._config_factory.chatchat_root(root)
self._store_config()
def set_data_path(self, path: str):
self._config_factory.data_path(path)
self._store_config()
def set_log_format(self, log_format: str):
self._config_factory.log_format(log_format)
self._store_config()
def _store_config(self):
with open(self.workspace_config, "w") as f:
config = self._config_factory.get_config()
config_json = {
"log_verbose": config.log_verbose,
"CHATCHAT_ROOT": config.CHATCHAT_ROOT,
"DATA_PATH": config.DATA_PATH,
"LOG_FORMAT": config.LOG_FORMAT
}
f.write(json.dumps(config_json, indent=4, ensure_ascii=False))
def _load_mod(mod, attr): def _load_mod(mod, attr):
attr_cfg = None attr_cfg = None
for name, obj in vars(mod).items(): for name, obj in vars(mod).items():
@ -38,7 +196,6 @@ def _import_config_mod_load(import_config_mod: str) -> Dict:
) )
user_import = False user_import = False
if user_import: if user_import:
# Dynamic loading {config}.py file # Dynamic loading {config}.py file
py_path = os.path.join(user_config_path, import_config_mod + ".py") py_path = os.path.join(user_config_path, import_config_mod + ".py")
spec = importlib.util.spec_from_file_location( spec = importlib.util.spec_from_file_location(
@ -69,7 +226,7 @@ def _import_config_mod_load(import_config_mod: str) -> Dict:
) )
raise RuntimeError(f"Failed to load user config from {user_config_path}") raise RuntimeError(f"Failed to load user config from {user_config_path}")
# 当前文件路径 # 当前文件路径
py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py") py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py")
spec = importlib.util.spec_from_file_location(f"*", spec = importlib.util.spec_from_file_location(f"*",
py_path) py_path)
@ -118,6 +275,7 @@ def _import_data_path() -> Any:
return DATA_PATH return DATA_PATH
def _import_img_dir() -> Any: def _import_img_dir() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod") load_mod = basic_config_load.get("load_mod")
@ -285,6 +443,7 @@ def _import_db_root_path() -> Any:
return DB_ROOT_PATH return DB_ROOT_PATH
def _import_sqlalchemy_database_uri() -> Any: def _import_sqlalchemy_database_uri() -> Any:
kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
load_mod = kb_config_load.get("load_mod") load_mod = kb_config_load.get("load_mod")
@ -627,4 +786,8 @@ __all__ = [
"API_SERVER", "API_SERVER",
"ConfigBasic",
"ConfigBasicFactory",
"ConfigWorkSpace",
] ]

View File

@ -0,0 +1,90 @@
"""Configuration for unit tests."""
import logging
from importlib import util
from typing import Dict, List, Sequence
import pytest
from pytest import Config, Function, Parser
def pytest_addoption(parser: Parser) -> None:
"""Add custom command line options to pytest."""
parser.addoption(
"--only-extended",
action="store_true",
help="Only run extended tests. Does not allow skipping any extended tests.",
)
parser.addoption(
"--only-core",
action="store_true",
help="Only run core tests. Never runs any extended tests.",
)
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
"""Add implementations for handling custom markers.
At the moment, this adds support for a custom `requires` marker.
The `requires` marker is used to denote tests that require one or more packages
to be installed to run. If the package is not installed, the test is skipped.
The `requires` marker syntax is:
.. code-block:: python
@pytest.mark.requires("package1", "package2")
def test_something():
...
"""
# Mapping from the name of a package to whether it is installed or not.
# Used to avoid repeated calls to `util.find_spec`
required_pkgs_info: Dict[str, bool] = {}
only_extended = config.getoption("--only-extended") or False
only_core = config.getoption("--only-core") or False
if only_extended and only_core:
raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
for item in items:
requires_marker = item.get_closest_marker("requires")
if requires_marker is not None:
if only_core:
item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
continue
# Iterate through the list of required packages
required_pkgs = requires_marker.args
for pkg in required_pkgs:
# If we haven't yet checked whether the pkg is installed
# let's check it and store the result.
if pkg not in required_pkgs_info:
try:
installed = util.find_spec(pkg) is not None
except Exception:
installed = False
required_pkgs_info[pkg] = installed
if not required_pkgs_info[pkg]:
if only_extended:
pytest.fail(
f"Package `{pkg}` is not installed but is required for "
f"extended tests. Please install the given package and "
f"try again.",
)
else:
# If the package is not installed, we immediately break
# and mark the test as skipped.
item.add_marker(
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
)
break
else:
if only_extended:
item.add_marker(
pytest.mark.skip(reason="Skipping not an extended test.")
)

View File

@ -0,0 +1,20 @@
from chatchat.configs import ConfigBasicFactory, ConfigBasic
import os
def test_config_factory_def():
test_config_factory = ConfigBasicFactory()
config: ConfigBasic = test_config_factory.get_config()
assert config is not None
assert config.log_verbose is False
assert config.CHATCHAT_ROOT is not None
assert config.DATA_PATH is not None
assert config.IMG_DIR is not None
assert config.NLTK_DATA_PATH is not None
assert config.LOG_FORMAT is not None
assert config.LOG_PATH is not None
assert config.MEDIA_PATH is not None
assert os.path.exists(os.path.join(config.MEDIA_PATH, "image"))
assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio"))
assert os.path.exists(os.path.join(config.MEDIA_PATH, "video"))