mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 21:37:20 +08:00
用户工作空间操作 (#4156)
工作空间的配置预设,提供ConfigBasic建造方法产生实例。
该类的实例对象用于存储工作空间的配置信息,如工作空间的路径等
工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中。
注意:不存在则读取默认
提供了操作入口
指令` chatchat-config` 工作空间配置
options:
```
-h, --help show this help message and exit
-v {true,false}, --verbose {true,false}
是否开启详细日志
-d DATA, --data DATA 数据存放路径
-f FORMAT, --format FORMAT
日志格式
--clear 清除配置
```
This commit is contained in:
parent
38ca2edb41
commit
84bafe9723
47
libs/chatchat-server/chatchat/config_work_space.py
Normal file
47
libs/chatchat-server/chatchat/config_work_space.py
Normal file
@ -0,0 +1,47 @@
|
||||
from chatchat.configs import config_workspace as workspace
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="指令` chatchat-config` 工作空间配置")
|
||||
# 只能选择true或false
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--verbose",
|
||||
choices=["true", "false"],
|
||||
help="是否开启详细日志"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--data",
|
||||
help="数据存放路径"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--format",
|
||||
help="日志格式"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clear",
|
||||
action="store_true",
|
||||
help="清除配置"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.verbose:
|
||||
if args.verbose.lower() == "true":
|
||||
workspace.set_log_verbose(True)
|
||||
else:
|
||||
workspace.set_log_verbose(False)
|
||||
if args.data:
|
||||
workspace.set_data_path(args.data)
|
||||
if args.format:
|
||||
workspace.set_log_format(args.format)
|
||||
if args.clear:
|
||||
workspace.clear()
|
||||
print(workspace.get_config())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,8 +1,10 @@
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger()
|
||||
@ -38,7 +40,6 @@ def _import_config_mod_load(import_config_mod: str) -> Dict:
|
||||
)
|
||||
user_import = False
|
||||
if user_import:
|
||||
|
||||
# Dynamic loading {config}.py file
|
||||
py_path = os.path.join(user_config_path, import_config_mod + ".py")
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
@ -69,7 +70,7 @@ def _import_config_mod_load(import_config_mod: str) -> Dict:
|
||||
)
|
||||
raise RuntimeError(f"Failed to load user config from {user_config_path}")
|
||||
# 当前文件路径
|
||||
py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py")
|
||||
py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py")
|
||||
|
||||
spec = importlib.util.spec_from_file_location(f"*",
|
||||
py_path)
|
||||
@ -95,75 +96,108 @@ CONFIG_IMPORTS = {
|
||||
}
|
||||
|
||||
|
||||
def _import_ConfigBasic() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
ConfigBasic = load_mod(basic_config_load.get("module"), "ConfigBasic")
|
||||
|
||||
return ConfigBasic
|
||||
|
||||
|
||||
def _import_ConfigBasicFactory() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
ConfigBasicFactory = load_mod(basic_config_load.get("module"), "ConfigBasicFactory")
|
||||
|
||||
return ConfigBasicFactory
|
||||
|
||||
|
||||
def _import_ConfigWorkSpace() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
ConfigWorkSpace = load_mod(basic_config_load.get("module"), "ConfigWorkSpace")
|
||||
|
||||
return ConfigWorkSpace
|
||||
|
||||
|
||||
def _import_config_workspace() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||
return config_workspace
|
||||
|
||||
def _import_log_verbose() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
log_verbose = load_mod(basic_config_load.get("module"), "log_verbose")
|
||||
|
||||
return log_verbose
|
||||
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||
return config_workspace.get_config().log_verbose
|
||||
|
||||
|
||||
def _import_chatchat_root() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
CHATCHAT_ROOT = load_mod(basic_config_load.get("module"), "CHATCHAT_ROOT")
|
||||
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||
|
||||
return CHATCHAT_ROOT
|
||||
return config_workspace.get_config().CHATCHAT_ROOT
|
||||
|
||||
|
||||
def _import_data_path() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
DATA_PATH = load_mod(basic_config_load.get("module"), "DATA_PATH")
|
||||
|
||||
return DATA_PATH
|
||||
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||
return config_workspace.get_config().DATA_PATH
|
||||
|
||||
|
||||
def _import_img_dir() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
IMG_DIR = load_mod(basic_config_load.get("module"), "IMG_DIR")
|
||||
|
||||
return IMG_DIR
|
||||
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||
|
||||
return config_workspace.get_config().IMG_DIR
|
||||
|
||||
|
||||
def _import_nltk_data_path() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
NLTK_DATA_PATH = load_mod(basic_config_load.get("module"), "NLTK_DATA_PATH")
|
||||
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||
|
||||
return NLTK_DATA_PATH
|
||||
return config_workspace.get_config().NLTK_DATA_PATH
|
||||
|
||||
|
||||
def _import_log_format() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
LOG_FORMAT = load_mod(basic_config_load.get("module"), "LOG_FORMAT")
|
||||
|
||||
return LOG_FORMAT
|
||||
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||
|
||||
return config_workspace.get_config().LOG_FORMAT
|
||||
|
||||
|
||||
def _import_log_path() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
LOG_PATH = load_mod(basic_config_load.get("module"), "LOG_PATH")
|
||||
|
||||
return LOG_PATH
|
||||
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||
|
||||
return config_workspace.get_config().LOG_PATH
|
||||
|
||||
|
||||
def _import_media_path() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
MEDIA_PATH = load_mod(basic_config_load.get("module"), "MEDIA_PATH")
|
||||
|
||||
return MEDIA_PATH
|
||||
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||
|
||||
return config_workspace.get_config().MEDIA_PATH
|
||||
|
||||
|
||||
def _import_base_temp_dir() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
BASE_TEMP_DIR = load_mod(basic_config_load.get("module"), "BASE_TEMP_DIR")
|
||||
|
||||
return BASE_TEMP_DIR
|
||||
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||
return config_workspace.get_config().BASE_TEMP_DIR
|
||||
|
||||
|
||||
def _import_default_knowledge_base() -> Any:
|
||||
@ -285,6 +319,7 @@ def _import_db_root_path() -> Any:
|
||||
|
||||
return DB_ROOT_PATH
|
||||
|
||||
|
||||
def _import_sqlalchemy_database_uri() -> Any:
|
||||
kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
|
||||
load_mod = kb_config_load.get("load_mod")
|
||||
@ -478,7 +513,15 @@ def _import_api_server() -> Any:
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name == "log_verbose":
|
||||
if name == "ConfigBasic":
|
||||
return _import_ConfigBasic()
|
||||
elif name == "ConfigBasicFactory":
|
||||
return _import_ConfigBasicFactory()
|
||||
elif name == "ConfigWorkSpace":
|
||||
return _import_ConfigWorkSpace()
|
||||
elif name == "config_workspace":
|
||||
return _import_config_workspace()
|
||||
elif name == "log_verbose":
|
||||
return _import_log_verbose()
|
||||
elif name == "CHATCHAT_ROOT":
|
||||
return _import_chatchat_root()
|
||||
@ -578,6 +621,7 @@ VERSION = "v0.3.0-preview"
|
||||
|
||||
__all__ = [
|
||||
"VERSION",
|
||||
"config_workspace",
|
||||
"log_verbose",
|
||||
"CHATCHAT_ROOT",
|
||||
"DATA_PATH",
|
||||
@ -626,5 +670,8 @@ __all__ = [
|
||||
"WEBUI_SERVER",
|
||||
"API_SERVER",
|
||||
|
||||
"ConfigBasic",
|
||||
"ConfigBasicFactory",
|
||||
"ConfigWorkSpace",
|
||||
|
||||
]
|
||||
|
||||
@ -1,55 +1,180 @@
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
import langchain
|
||||
|
||||
|
||||
# 是否显示详细日志
|
||||
log_verbose = False
|
||||
langchain.verbose = False
|
||||
|
||||
# 通常情况下不需要更改以下内容
|
||||
|
||||
# chatchat 项目根目录
|
||||
CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent)
|
||||
|
||||
# 用户数据根目录
|
||||
DATA_PATH = os.path.join(CHATCHAT_ROOT, "data")
|
||||
if not os.path.exists(DATA_PATH):
|
||||
os.mkdir(DATA_PATH)
|
||||
|
||||
# 项目相关图片
|
||||
IMG_DIR = os.path.join(CHATCHAT_ROOT, "img")
|
||||
if not os.path.exists(IMG_DIR):
|
||||
os.mkdir(IMG_DIR)
|
||||
|
||||
# nltk 模型存储路径
|
||||
NLTK_DATA_PATH = os.path.join(DATA_PATH, "nltk_data")
|
||||
import nltk
|
||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||
|
||||
# 日志格式
|
||||
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.INFO)
|
||||
logging.basicConfig(format=LOG_FORMAT)
|
||||
|
||||
|
||||
# 日志存储路径
|
||||
LOG_PATH = os.path.join(DATA_PATH, "logs")
|
||||
if not os.path.exists(LOG_PATH):
|
||||
os.mkdir(LOG_PATH)
|
||||
class ConfigBasic:
|
||||
log_verbose: bool
|
||||
"""是否开启日志详细信息"""
|
||||
CHATCHAT_ROOT: str
|
||||
"""项目根目录"""
|
||||
DATA_PATH: str
|
||||
"""用户数据根目录"""
|
||||
IMG_DIR: str
|
||||
"""项目相关图片"""
|
||||
NLTK_DATA_PATH: str
|
||||
"""nltk 模型存储路径"""
|
||||
LOG_FORMAT: str
|
||||
"""日志格式"""
|
||||
LOG_PATH: str
|
||||
"""日志存储路径"""
|
||||
MEDIA_PATH: str
|
||||
"""模型生成内容(图片、视频、音频等)保存位置"""
|
||||
BASE_TEMP_DIR: str
|
||||
"""临时文件目录,主要用于文件对话"""
|
||||
|
||||
# 模型生成内容(图片、视频、音频等)保存位置
|
||||
MEDIA_PATH = os.path.join(DATA_PATH, "media")
|
||||
if not os.path.exists(MEDIA_PATH):
|
||||
os.mkdir(MEDIA_PATH)
|
||||
os.mkdir(os.path.join(MEDIA_PATH, "image"))
|
||||
os.mkdir(os.path.join(MEDIA_PATH, "audio"))
|
||||
os.mkdir(os.path.join(MEDIA_PATH, "video"))
|
||||
def __str__(self):
|
||||
return f"ConfigBasic(log_verbose={self.log_verbose}, CHATCHAT_ROOT={self.CHATCHAT_ROOT}, DATA_PATH={self.DATA_PATH}, IMG_DIR={self.IMG_DIR}, NLTK_DATA_PATH={self.NLTK_DATA_PATH}, LOG_FORMAT={self.LOG_FORMAT}, LOG_PATH={self.LOG_PATH}, MEDIA_PATH={self.MEDIA_PATH}, BASE_TEMP_DIR={self.BASE_TEMP_DIR})"
|
||||
|
||||
# 临时文件目录,主要用于文件对话
|
||||
BASE_TEMP_DIR = os.path.join(DATA_PATH, "temp")
|
||||
if not os.path.exists(BASE_TEMP_DIR):
|
||||
os.mkdir(BASE_TEMP_DIR)
|
||||
|
||||
class ConfigBasicFactory:
|
||||
"""Basic config for ChatChat """
|
||||
|
||||
def __init__(self):
|
||||
# 日志格式
|
||||
self.LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
|
||||
logging.basicConfig(format=self.LOG_FORMAT)
|
||||
self.LOG_VERBOSE = False
|
||||
self.CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent)
|
||||
# 用户数据根目录
|
||||
self.DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data")
|
||||
self._DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data")
|
||||
if not os.path.exists(self._DATA_PATH):
|
||||
os.makedirs(self._DATA_PATH, exist_ok=True)
|
||||
|
||||
self._init_data_dir()
|
||||
|
||||
# 项目相关图片
|
||||
self.IMG_DIR = os.path.join(self.CHATCHAT_ROOT, "img")
|
||||
if not os.path.exists(self.IMG_DIR):
|
||||
os.makedirs(self.IMG_DIR, exist_ok=True)
|
||||
|
||||
def log_verbose(self, verbose: bool):
|
||||
self.LOG_VERBOSE = verbose
|
||||
|
||||
def data_path(self, path: str):
|
||||
self.DATA_PATH = path
|
||||
if not os.path.exists(self.DATA_PATH):
|
||||
os.makedirs(self.DATA_PATH, exist_ok=True)
|
||||
# 复制_DATA_PATH数据到DATA_PATH
|
||||
if self._DATA_PATH != self.DATA_PATH:
|
||||
os.system(f"cp -r {self._DATA_PATH}/* {self.DATA_PATH}")
|
||||
|
||||
self._init_data_dir()
|
||||
|
||||
def log_format(self, log_format: str):
|
||||
self.LOG_FORMAT = log_format
|
||||
logging.basicConfig(format=self.LOG_FORMAT)
|
||||
|
||||
def _init_data_dir(self):
|
||||
logger.info(f"Init data dir: {self.DATA_PATH}")
|
||||
# nltk 模型存储路径
|
||||
self.NLTK_DATA_PATH = os.path.join(self.DATA_PATH, "nltk_data")
|
||||
import nltk
|
||||
nltk.data.path = [self.NLTK_DATA_PATH] + nltk.data.path
|
||||
# 日志存储路径
|
||||
self.LOG_PATH = os.path.join(self.DATA_PATH, "logs")
|
||||
if not os.path.exists(self.LOG_PATH):
|
||||
os.makedirs(self.LOG_PATH, exist_ok=True)
|
||||
|
||||
# 模型生成内容(图片、视频、音频等)保存位置
|
||||
self.MEDIA_PATH = os.path.join(self.DATA_PATH, "media")
|
||||
if not os.path.exists(self.MEDIA_PATH):
|
||||
os.makedirs(self.MEDIA_PATH, exist_ok=True)
|
||||
os.makedirs(os.path.join(self.MEDIA_PATH, "image"), exist_ok=True)
|
||||
os.makedirs(os.path.join(self.MEDIA_PATH, "audio"), exist_ok=True)
|
||||
os.makedirs(os.path.join(self.MEDIA_PATH, "video"), exist_ok=True)
|
||||
|
||||
# 临时文件目录,主要用于文件对话
|
||||
self.BASE_TEMP_DIR = os.path.join(self.DATA_PATH, "temp")
|
||||
if not os.path.exists(self.BASE_TEMP_DIR):
|
||||
os.makedirs(self.BASE_TEMP_DIR, exist_ok=True)
|
||||
|
||||
logger.info(f"Init data dir: {self.DATA_PATH} success.")
|
||||
|
||||
def get_config(self) -> ConfigBasic:
|
||||
config = ConfigBasic()
|
||||
config.log_verbose = self.LOG_VERBOSE
|
||||
config.CHATCHAT_ROOT = self.CHATCHAT_ROOT
|
||||
config.DATA_PATH = self.DATA_PATH
|
||||
config.IMG_DIR = self.IMG_DIR
|
||||
config.NLTK_DATA_PATH = self.NLTK_DATA_PATH
|
||||
config.LOG_FORMAT = self.LOG_FORMAT
|
||||
config.LOG_PATH = self.LOG_PATH
|
||||
config.MEDIA_PATH = self.MEDIA_PATH
|
||||
config.BASE_TEMP_DIR = self.BASE_TEMP_DIR
|
||||
return config
|
||||
|
||||
|
||||
class ConfigWorkSpace:
|
||||
"""
|
||||
工作空间的配置预设,提供ConfigBasic建造方法产生实例。
|
||||
该类的实例对象用于存储工作空间的配置信息,如工作空间的路径等
|
||||
工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中。
|
||||
注意:不存在则读取默认
|
||||
"""
|
||||
_config_factory: ConfigBasicFactory = ConfigBasicFactory()
|
||||
|
||||
def __init__(self):
|
||||
self.workspace = os.path.join(os.path.expanduser("~"), ".config", "chatchat/workspace")
|
||||
if not os.path.exists(self.workspace):
|
||||
os.makedirs(self.workspace, exist_ok=True)
|
||||
self.workspace_config = os.path.join(self.workspace, "workspace_config.json")
|
||||
# 初始化工作空间配置,转换成json格式,实现ConfigBasic的实例化
|
||||
|
||||
config_json = self._load_config()
|
||||
|
||||
if config_json:
|
||||
|
||||
_config_factory = ConfigBasicFactory()
|
||||
if config_json.get("log_verbose"):
|
||||
_config_factory.log_verbose(config_json.get("log_verbose"))
|
||||
if config_json.get("DATA_PATH"):
|
||||
_config_factory.data_path(config_json.get("DATA_PATH"))
|
||||
if config_json.get("LOG_FORMAT"):
|
||||
_config_factory.log_format(config_json.get("LOG_FORMAT"))
|
||||
|
||||
self._config_factory = _config_factory
|
||||
|
||||
def get_config(self) -> ConfigBasic:
|
||||
return self._config_factory.get_config()
|
||||
|
||||
def set_log_verbose(self, verbose: bool):
|
||||
self._config_factory.log_verbose(verbose)
|
||||
self._store_config()
|
||||
|
||||
def set_data_path(self, path: str):
|
||||
self._config_factory.data_path(path)
|
||||
self._store_config()
|
||||
|
||||
def set_log_format(self, log_format: str):
|
||||
self._config_factory.log_format(log_format)
|
||||
self._store_config()
|
||||
|
||||
def clear(self):
|
||||
logger.info("Clear workspace config.")
|
||||
os.remove(self.workspace_config)
|
||||
|
||||
def _load_config(self):
|
||||
try:
|
||||
with open(self.workspace_config, "r") as f:
|
||||
return json.loads(f.read())
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
def _store_config(self):
|
||||
with open(self.workspace_config, "w") as f:
|
||||
config = self._config_factory.get_config()
|
||||
config_json = {
|
||||
"log_verbose": config.log_verbose,
|
||||
"CHATCHAT_ROOT": config.CHATCHAT_ROOT,
|
||||
"DATA_PATH": config.DATA_PATH,
|
||||
"LOG_FORMAT": config.LOG_FORMAT
|
||||
}
|
||||
f.write(json.dumps(config_json, indent=4, ensure_ascii=False))
|
||||
|
||||
|
||||
config_workspace: ConfigWorkSpace = ConfigWorkSpace()
|
||||
|
||||
@ -11,6 +11,7 @@ packages = [
|
||||
[tool.poetry.scripts]
|
||||
chatchat = 'chatchat.startup:main'
|
||||
chatchat-kb = 'chatchat.init_database:main'
|
||||
chatchat-config = 'chatchat.config_work_space:main'
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<3.12,!=3.9.7"
|
||||
|
||||
90
libs/chatchat-server/tests/conftest.py
Normal file
90
libs/chatchat-server/tests/conftest.py
Normal file
@ -0,0 +1,90 @@
|
||||
"""Configuration for unit tests."""
|
||||
import logging
|
||||
from importlib import util
|
||||
from typing import Dict, List, Sequence
|
||||
|
||||
import pytest
|
||||
from pytest import Config, Function, Parser
|
||||
|
||||
|
||||
def pytest_addoption(parser: Parser) -> None:
|
||||
"""Add custom command line options to pytest."""
|
||||
parser.addoption(
|
||||
"--only-extended",
|
||||
action="store_true",
|
||||
help="Only run extended tests. Does not allow skipping any extended tests.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--only-core",
|
||||
action="store_true",
|
||||
help="Only run core tests. Never runs any extended tests.",
|
||||
)
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
|
||||
"""Add implementations for handling custom markers.
|
||||
|
||||
At the moment, this adds support for a custom `requires` marker.
|
||||
|
||||
The `requires` marker is used to denote tests that require one or more packages
|
||||
to be installed to run. If the package is not installed, the test is skipped.
|
||||
|
||||
The `requires` marker syntax is:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@pytest.mark.requires("package1", "package2")
|
||||
def test_something():
|
||||
...
|
||||
"""
|
||||
# Mapping from the name of a package to whether it is installed or not.
|
||||
# Used to avoid repeated calls to `util.find_spec`
|
||||
required_pkgs_info: Dict[str, bool] = {}
|
||||
|
||||
only_extended = config.getoption("--only-extended") or False
|
||||
only_core = config.getoption("--only-core") or False
|
||||
|
||||
if only_extended and only_core:
|
||||
raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
|
||||
|
||||
for item in items:
|
||||
requires_marker = item.get_closest_marker("requires")
|
||||
if requires_marker is not None:
|
||||
if only_core:
|
||||
item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
|
||||
continue
|
||||
|
||||
# Iterate through the list of required packages
|
||||
required_pkgs = requires_marker.args
|
||||
for pkg in required_pkgs:
|
||||
# If we haven't yet checked whether the pkg is installed
|
||||
# let's check it and store the result.
|
||||
if pkg not in required_pkgs_info:
|
||||
try:
|
||||
installed = util.find_spec(pkg) is not None
|
||||
except Exception:
|
||||
installed = False
|
||||
required_pkgs_info[pkg] = installed
|
||||
|
||||
if not required_pkgs_info[pkg]:
|
||||
if only_extended:
|
||||
pytest.fail(
|
||||
f"Package `{pkg}` is not installed but is required for "
|
||||
f"extended tests. Please install the given package and "
|
||||
f"try again.",
|
||||
)
|
||||
|
||||
else:
|
||||
# If the package is not installed, we immediately break
|
||||
# and mark the test as skipped.
|
||||
item.add_marker(
|
||||
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
|
||||
)
|
||||
break
|
||||
else:
|
||||
if only_extended:
|
||||
item.add_marker(
|
||||
pytest.mark.skip(reason="Skipping not an extended test.")
|
||||
)
|
||||
|
||||
|
||||
56
libs/chatchat-server/tests/unit_tests/config/test_config.py
Normal file
56
libs/chatchat-server/tests/unit_tests/config/test_config.py
Normal file
@ -0,0 +1,56 @@
|
||||
from pathlib import Path
|
||||
|
||||
from chatchat.configs import ConfigBasicFactory, ConfigBasic, ConfigWorkSpace
|
||||
import os
|
||||
|
||||
|
||||
def test_config_factory_def():
|
||||
test_config_factory = ConfigBasicFactory()
|
||||
config: ConfigBasic = test_config_factory.get_config()
|
||||
assert config is not None
|
||||
assert config.log_verbose is False
|
||||
assert config.CHATCHAT_ROOT is not None
|
||||
assert config.DATA_PATH is not None
|
||||
assert config.IMG_DIR is not None
|
||||
assert config.NLTK_DATA_PATH is not None
|
||||
assert config.LOG_FORMAT is not None
|
||||
assert config.LOG_PATH is not None
|
||||
assert config.MEDIA_PATH is not None
|
||||
|
||||
assert os.path.exists(os.path.join(config.MEDIA_PATH, "image"))
|
||||
assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio"))
|
||||
assert os.path.exists(os.path.join(config.MEDIA_PATH, "video"))
|
||||
|
||||
|
||||
def test_workspace():
|
||||
config_workspace = ConfigWorkSpace()
|
||||
assert config_workspace.get_config() is not None
|
||||
base_root = os.path.join(Path(__file__).absolute().parent, "chatchat")
|
||||
config_workspace.set_data_path(os.path.join(base_root, "data"))
|
||||
config_workspace.set_log_verbose(True)
|
||||
config_workspace.set_log_format(" %(message)s")
|
||||
|
||||
config: ConfigBasic = config_workspace.get_config()
|
||||
assert config.log_verbose is True
|
||||
assert config.DATA_PATH == os.path.join(base_root, "data")
|
||||
assert config.IMG_DIR is not None
|
||||
assert config.NLTK_DATA_PATH == os.path.join(base_root, "data", "nltk_data")
|
||||
assert config.LOG_FORMAT == " %(message)s"
|
||||
assert config.LOG_PATH == os.path.join(base_root, "data", "logs")
|
||||
assert config.MEDIA_PATH == os.path.join(base_root, "data", "media")
|
||||
|
||||
assert os.path.exists(os.path.join(config.MEDIA_PATH, "image"))
|
||||
assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio"))
|
||||
assert os.path.exists(os.path.join(config.MEDIA_PATH, "video"))
|
||||
config_workspace.clear()
|
||||
|
||||
|
||||
def test_workspace_default():
|
||||
from chatchat.configs import (log_verbose, DATA_PATH, IMG_DIR, NLTK_DATA_PATH, LOG_FORMAT, LOG_PATH, MEDIA_PATH)
|
||||
assert log_verbose is False
|
||||
assert DATA_PATH is not None
|
||||
assert IMG_DIR is not None
|
||||
assert NLTK_DATA_PATH is not None
|
||||
assert LOG_FORMAT is not None
|
||||
assert LOG_PATH is not None
|
||||
assert MEDIA_PATH is not None
|
||||
Loading…
x
Reference in New Issue
Block a user