mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 13:23:16 +08:00
pytest config
This commit is contained in:
parent
b1c5bf9c94
commit
e4de0ceabc
@ -1,13 +1,171 @@
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class ConfigBasic:
|
||||
log_verbose: bool
|
||||
CHATCHAT_ROOT: str
|
||||
DATA_PATH: str
|
||||
IMG_DIR: str
|
||||
NLTK_DATA_PATH: str
|
||||
LOG_FORMAT: str
|
||||
LOG_PATH: str
|
||||
MEDIA_PATH: str
|
||||
BASE_TEMP_DIR: str
|
||||
|
||||
|
||||
class ConfigBasicFactory:
|
||||
"""Basic config for ChatChat """
|
||||
|
||||
def __init__(self):
|
||||
# 日志格式
|
||||
self.LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
|
||||
logging.basicConfig(format=self.LOG_FORMAT)
|
||||
self.LOG_VERBOSE = False
|
||||
self.CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent)
|
||||
# 用户数据根目录
|
||||
self.DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data")
|
||||
self._DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data")
|
||||
if not os.path.exists(self._DATA_PATH):
|
||||
os.mkdir(self.DATA_PATH)
|
||||
|
||||
self._init_data_dir()
|
||||
|
||||
# 项目相关图片
|
||||
self.IMG_DIR = os.path.join(self.CHATCHAT_ROOT, "img")
|
||||
if not os.path.exists(self.IMG_DIR):
|
||||
os.mkdir(self.IMG_DIR)
|
||||
|
||||
def log_verbose(self, verbose: bool):
|
||||
self.LOG_VERBOSE = verbose
|
||||
|
||||
def chatchat_root(self, root: str):
|
||||
self.CHATCHAT_ROOT = root
|
||||
|
||||
def data_path(self, path: str):
|
||||
self.DATA_PATH = path
|
||||
if not os.path.exists(self.DATA_PATH):
|
||||
os.mkdir(self.DATA_PATH)
|
||||
# 复制_DATA_PATH数据到DATA_PATH
|
||||
os.system(f"cp -r {self._DATA_PATH} {self.DATA_PATH}")
|
||||
|
||||
self._init_data_dir()
|
||||
|
||||
def log_format(self, log_format: str):
|
||||
self.LOG_FORMAT = log_format
|
||||
logging.basicConfig(format=self.LOG_FORMAT)
|
||||
|
||||
def _init_data_dir(self):
|
||||
logger.info(f"Init data dir: {self.DATA_PATH}")
|
||||
# nltk 模型存储路径
|
||||
self.NLTK_DATA_PATH = os.path.join(self.DATA_PATH, "nltk_data")
|
||||
import nltk
|
||||
nltk.data.path = [self.NLTK_DATA_PATH] + nltk.data.path
|
||||
# 日志存储路径
|
||||
self.LOG_PATH = os.path.join(self.DATA_PATH, "logs")
|
||||
if not os.path.exists(self.LOG_PATH):
|
||||
os.mkdir(self.LOG_PATH)
|
||||
|
||||
# 模型生成内容(图片、视频、音频等)保存位置
|
||||
self.MEDIA_PATH = os.path.join(self.DATA_PATH, "media")
|
||||
if not os.path.exists(self.MEDIA_PATH):
|
||||
os.mkdir(self.MEDIA_PATH)
|
||||
os.mkdir(os.path.join(self.MEDIA_PATH, "image"))
|
||||
os.mkdir(os.path.join(self.MEDIA_PATH, "audio"))
|
||||
os.mkdir(os.path.join(self.MEDIA_PATH, "video"))
|
||||
|
||||
# 临时文件目录,主要用于文件对话
|
||||
self.BASE_TEMP_DIR = os.path.join(self.DATA_PATH, "temp")
|
||||
if not os.path.exists(self.BASE_TEMP_DIR):
|
||||
os.mkdir(self.BASE_TEMP_DIR)
|
||||
|
||||
logger.info(f"Init data dir: {self.DATA_PATH} success.")
|
||||
|
||||
def get_config(self) -> ConfigBasic:
|
||||
config = ConfigBasic()
|
||||
config.log_verbose = self.LOG_VERBOSE
|
||||
config.CHATCHAT_ROOT = self.CHATCHAT_ROOT
|
||||
config.DATA_PATH = self.DATA_PATH
|
||||
config.IMG_DIR = self.IMG_DIR
|
||||
config.NLTK_DATA_PATH = self.NLTK_DATA_PATH
|
||||
config.LOG_FORMAT = self.LOG_FORMAT
|
||||
config.LOG_PATH = self.LOG_PATH
|
||||
config.MEDIA_PATH = self.MEDIA_PATH
|
||||
config.BASE_TEMP_DIR = self.BASE_TEMP_DIR
|
||||
return config
|
||||
|
||||
|
||||
class ConfigWorkSpace:
|
||||
"""
|
||||
工作空间的配置预设,提供ConfigBasic建造方法产生实例。
|
||||
该类的实例对象用于存储工作空间的配置信息,如工作空间的路径等
|
||||
工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中。
|
||||
注意:不存在则读取默认
|
||||
"""
|
||||
_config_factory: ConfigBasicFactory = ConfigBasicFactory()
|
||||
|
||||
def __init__(self):
|
||||
self.workspace = os.path.join(os.path.expanduser("~"), ".config", "chatchat/workspace")
|
||||
if not os.path.exists(self.workspace):
|
||||
os.makedirs(self.workspace, exist_ok=True)
|
||||
self.workspace_config = os.path.join(self.workspace, "workspace_config.json")
|
||||
# 初始化工作空间配置,转换成json格式,实现ConfigBasic的实例化
|
||||
with open(self.workspace_config, "w") as f:
|
||||
config_json = json.loads(f.read())
|
||||
|
||||
if config_json:
|
||||
|
||||
_config_factory = ConfigBasicFactory()
|
||||
if config_json.get("log_verbose"):
|
||||
_config_factory.log_verbose(config_json.get("log_verbose"))
|
||||
if config_json.get("CHATCHAT_ROOT"):
|
||||
_config_factory.chatchat_root(config_json.get("CHATCHAT_ROOT"))
|
||||
if config_json.get("DATA_PATH"):
|
||||
_config_factory.data_path(config_json.get("DATA_PATH"))
|
||||
if config_json.get("LOG_FORMAT"):
|
||||
_config_factory.log_format(config_json.get("LOG_FORMAT"))
|
||||
|
||||
self._config_factory = _config_factory
|
||||
|
||||
def get_config(self) -> ConfigBasic:
|
||||
return self._config_factory.get_config()
|
||||
|
||||
def set_log_verbose(self, verbose: bool):
|
||||
self._config_factory.log_verbose(verbose)
|
||||
self._store_config()
|
||||
|
||||
def set_chatchat_root(self, root: str):
|
||||
self._config_factory.chatchat_root(root)
|
||||
self._store_config()
|
||||
|
||||
def set_data_path(self, path: str):
|
||||
self._config_factory.data_path(path)
|
||||
self._store_config()
|
||||
|
||||
def set_log_format(self, log_format: str):
|
||||
self._config_factory.log_format(log_format)
|
||||
self._store_config()
|
||||
|
||||
def _store_config(self):
|
||||
with open(self.workspace_config, "w") as f:
|
||||
config = self._config_factory.get_config()
|
||||
config_json = {
|
||||
"log_verbose": config.log_verbose,
|
||||
"CHATCHAT_ROOT": config.CHATCHAT_ROOT,
|
||||
"DATA_PATH": config.DATA_PATH,
|
||||
"LOG_FORMAT": config.LOG_FORMAT
|
||||
}
|
||||
f.write(json.dumps(config_json, indent=4, ensure_ascii=False))
|
||||
|
||||
|
||||
def _load_mod(mod, attr):
|
||||
attr_cfg = None
|
||||
for name, obj in vars(mod).items():
|
||||
@ -38,7 +196,6 @@ def _import_config_mod_load(import_config_mod: str) -> Dict:
|
||||
)
|
||||
user_import = False
|
||||
if user_import:
|
||||
|
||||
# Dynamic loading {config}.py file
|
||||
py_path = os.path.join(user_config_path, import_config_mod + ".py")
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
@ -69,7 +226,7 @@ def _import_config_mod_load(import_config_mod: str) -> Dict:
|
||||
)
|
||||
raise RuntimeError(f"Failed to load user config from {user_config_path}")
|
||||
# 当前文件路径
|
||||
py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py")
|
||||
py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py")
|
||||
|
||||
spec = importlib.util.spec_from_file_location(f"*",
|
||||
py_path)
|
||||
@ -118,6 +275,7 @@ def _import_data_path() -> Any:
|
||||
|
||||
return DATA_PATH
|
||||
|
||||
|
||||
def _import_img_dir() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
@ -285,6 +443,7 @@ def _import_db_root_path() -> Any:
|
||||
|
||||
return DB_ROOT_PATH
|
||||
|
||||
|
||||
def _import_sqlalchemy_database_uri() -> Any:
|
||||
kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
|
||||
load_mod = kb_config_load.get("load_mod")
|
||||
@ -627,4 +786,8 @@ __all__ = [
|
||||
"API_SERVER",
|
||||
|
||||
|
||||
"ConfigBasic",
|
||||
"ConfigBasicFactory",
|
||||
"ConfigWorkSpace",
|
||||
|
||||
]
|
||||
|
||||
90
libs/chatchat-server/tests/conftest.py
Normal file
90
libs/chatchat-server/tests/conftest.py
Normal file
@ -0,0 +1,90 @@
|
||||
"""Configuration for unit tests."""
|
||||
import logging
|
||||
from importlib import util
|
||||
from typing import Dict, List, Sequence
|
||||
|
||||
import pytest
|
||||
from pytest import Config, Function, Parser
|
||||
|
||||
|
||||
def pytest_addoption(parser: Parser) -> None:
|
||||
"""Add custom command line options to pytest."""
|
||||
parser.addoption(
|
||||
"--only-extended",
|
||||
action="store_true",
|
||||
help="Only run extended tests. Does not allow skipping any extended tests.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--only-core",
|
||||
action="store_true",
|
||||
help="Only run core tests. Never runs any extended tests.",
|
||||
)
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
|
||||
"""Add implementations for handling custom markers.
|
||||
|
||||
At the moment, this adds support for a custom `requires` marker.
|
||||
|
||||
The `requires` marker is used to denote tests that require one or more packages
|
||||
to be installed to run. If the package is not installed, the test is skipped.
|
||||
|
||||
The `requires` marker syntax is:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@pytest.mark.requires("package1", "package2")
|
||||
def test_something():
|
||||
...
|
||||
"""
|
||||
# Mapping from the name of a package to whether it is installed or not.
|
||||
# Used to avoid repeated calls to `util.find_spec`
|
||||
required_pkgs_info: Dict[str, bool] = {}
|
||||
|
||||
only_extended = config.getoption("--only-extended") or False
|
||||
only_core = config.getoption("--only-core") or False
|
||||
|
||||
if only_extended and only_core:
|
||||
raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
|
||||
|
||||
for item in items:
|
||||
requires_marker = item.get_closest_marker("requires")
|
||||
if requires_marker is not None:
|
||||
if only_core:
|
||||
item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
|
||||
continue
|
||||
|
||||
# Iterate through the list of required packages
|
||||
required_pkgs = requires_marker.args
|
||||
for pkg in required_pkgs:
|
||||
# If we haven't yet checked whether the pkg is installed
|
||||
# let's check it and store the result.
|
||||
if pkg not in required_pkgs_info:
|
||||
try:
|
||||
installed = util.find_spec(pkg) is not None
|
||||
except Exception:
|
||||
installed = False
|
||||
required_pkgs_info[pkg] = installed
|
||||
|
||||
if not required_pkgs_info[pkg]:
|
||||
if only_extended:
|
||||
pytest.fail(
|
||||
f"Package `{pkg}` is not installed but is required for "
|
||||
f"extended tests. Please install the given package and "
|
||||
f"try again.",
|
||||
)
|
||||
|
||||
else:
|
||||
# If the package is not installed, we immediately break
|
||||
# and mark the test as skipped.
|
||||
item.add_marker(
|
||||
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
|
||||
)
|
||||
break
|
||||
else:
|
||||
if only_extended:
|
||||
item.add_marker(
|
||||
pytest.mark.skip(reason="Skipping not an extended test.")
|
||||
)
|
||||
|
||||
|
||||
20
libs/chatchat-server/tests/unit_tests/config/test_config.py
Normal file
20
libs/chatchat-server/tests/unit_tests/config/test_config.py
Normal file
@ -0,0 +1,20 @@
|
||||
from chatchat.configs import ConfigBasicFactory, ConfigBasic
|
||||
import os
|
||||
|
||||
|
||||
def test_config_factory_def():
|
||||
test_config_factory = ConfigBasicFactory()
|
||||
config: ConfigBasic = test_config_factory.get_config()
|
||||
assert config is not None
|
||||
assert config.log_verbose is False
|
||||
assert config.CHATCHAT_ROOT is not None
|
||||
assert config.DATA_PATH is not None
|
||||
assert config.IMG_DIR is not None
|
||||
assert config.NLTK_DATA_PATH is not None
|
||||
assert config.LOG_FORMAT is not None
|
||||
assert config.LOG_PATH is not None
|
||||
assert config.MEDIA_PATH is not None
|
||||
|
||||
assert os.path.exists(os.path.join(config.MEDIA_PATH, "image"))
|
||||
assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio"))
|
||||
assert os.path.exists(os.path.join(config.MEDIA_PATH, "video"))
|
||||
Loading…
x
Reference in New Issue
Block a user