用户工作空间操作 (#4156)

工作空间的配置预设,提供ConfigBasic建造方法产生实例。
  该类的实例对象用于存储工作空间的配置信息,如工作空间的路径等
  工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中。
  注意:不存在则读取默认

提供了操作入口
指令` chatchat-config` 工作空间配置

options:
```
  -h, --help            show this help message and exit
  -v {true,false}, --verbose {true,false}
                        是否开启详细日志
  -d DATA, --data DATA  数据存放路径
  -f FORMAT, --format FORMAT
                        日志格式
  --clear               清除配置
```
This commit is contained in:
glide-the 2024-06-09 19:59:54 +08:00 committed by GitHub
parent 38ca2edb41
commit 84bafe9723
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 436 additions and 70 deletions

View File

@ -0,0 +1,47 @@
from chatchat.configs import config_workspace as workspace
def main():
import argparse
parser = argparse.ArgumentParser(description="指令` chatchat-config` 工作空间配置")
# 只能选择true或false
parser.add_argument(
"-v",
"--verbose",
choices=["true", "false"],
help="是否开启详细日志"
)
parser.add_argument(
"-d",
"--data",
help="数据存放路径"
)
parser.add_argument(
"-f",
"--format",
help="日志格式"
)
parser.add_argument(
"--clear",
action="store_true",
help="清除配置"
)
args = parser.parse_args()
if args.verbose:
if args.verbose.lower() == "true":
workspace.set_log_verbose(True)
else:
workspace.set_log_verbose(False)
if args.data:
workspace.set_data_path(args.data)
if args.format:
workspace.set_log_format(args.format)
if args.clear:
workspace.clear()
print(workspace.get_config())
if __name__ == "__main__":
main()

View File

@ -1,8 +1,10 @@
import importlib
import importlib.util
import os
from pathlib import Path
from typing import Dict, Any
import json
import logging
logger = logging.getLogger()
@ -38,7 +40,6 @@ def _import_config_mod_load(import_config_mod: str) -> Dict:
)
user_import = False
if user_import:
# Dynamic loading {config}.py file
py_path = os.path.join(user_config_path, import_config_mod + ".py")
spec = importlib.util.spec_from_file_location(
@ -69,7 +70,7 @@ def _import_config_mod_load(import_config_mod: str) -> Dict:
)
raise RuntimeError(f"Failed to load user config from {user_config_path}")
# 当前文件路径
py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py")
py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py")
spec = importlib.util.spec_from_file_location(f"*",
py_path)
@ -95,75 +96,108 @@ CONFIG_IMPORTS = {
}
def _import_ConfigBasic() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
ConfigBasic = load_mod(basic_config_load.get("module"), "ConfigBasic")
return ConfigBasic
def _import_ConfigBasicFactory() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
ConfigBasicFactory = load_mod(basic_config_load.get("module"), "ConfigBasicFactory")
return ConfigBasicFactory
def _import_ConfigWorkSpace() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
ConfigWorkSpace = load_mod(basic_config_load.get("module"), "ConfigWorkSpace")
return ConfigWorkSpace
def _import_config_workspace() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace
def _import_log_verbose() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
log_verbose = load_mod(basic_config_load.get("module"), "log_verbose")
return log_verbose
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace.get_config().log_verbose
def _import_chatchat_root() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
CHATCHAT_ROOT = load_mod(basic_config_load.get("module"), "CHATCHAT_ROOT")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return CHATCHAT_ROOT
return config_workspace.get_config().CHATCHAT_ROOT
def _import_data_path() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
DATA_PATH = load_mod(basic_config_load.get("module"), "DATA_PATH")
return DATA_PATH
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace.get_config().DATA_PATH
def _import_img_dir() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
IMG_DIR = load_mod(basic_config_load.get("module"), "IMG_DIR")
return IMG_DIR
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace.get_config().IMG_DIR
def _import_nltk_data_path() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
NLTK_DATA_PATH = load_mod(basic_config_load.get("module"), "NLTK_DATA_PATH")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return NLTK_DATA_PATH
return config_workspace.get_config().NLTK_DATA_PATH
def _import_log_format() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
LOG_FORMAT = load_mod(basic_config_load.get("module"), "LOG_FORMAT")
return LOG_FORMAT
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace.get_config().LOG_FORMAT
def _import_log_path() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
LOG_PATH = load_mod(basic_config_load.get("module"), "LOG_PATH")
return LOG_PATH
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace.get_config().LOG_PATH
def _import_media_path() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
MEDIA_PATH = load_mod(basic_config_load.get("module"), "MEDIA_PATH")
return MEDIA_PATH
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace.get_config().MEDIA_PATH
def _import_base_temp_dir() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
BASE_TEMP_DIR = load_mod(basic_config_load.get("module"), "BASE_TEMP_DIR")
return BASE_TEMP_DIR
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace.get_config().BASE_TEMP_DIR
def _import_default_knowledge_base() -> Any:
@ -285,6 +319,7 @@ def _import_db_root_path() -> Any:
return DB_ROOT_PATH
def _import_sqlalchemy_database_uri() -> Any:
kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
load_mod = kb_config_load.get("load_mod")
@ -478,7 +513,15 @@ def _import_api_server() -> Any:
def __getattr__(name: str) -> Any:
if name == "log_verbose":
if name == "ConfigBasic":
return _import_ConfigBasic()
elif name == "ConfigBasicFactory":
return _import_ConfigBasicFactory()
elif name == "ConfigWorkSpace":
return _import_ConfigWorkSpace()
elif name == "config_workspace":
return _import_config_workspace()
elif name == "log_verbose":
return _import_log_verbose()
elif name == "CHATCHAT_ROOT":
return _import_chatchat_root()
@ -578,6 +621,7 @@ VERSION = "v0.3.0-preview"
__all__ = [
"VERSION",
"config_workspace",
"log_verbose",
"CHATCHAT_ROOT",
"DATA_PATH",
@ -626,5 +670,8 @@ __all__ = [
"WEBUI_SERVER",
"API_SERVER",
"ConfigBasic",
"ConfigBasicFactory",
"ConfigWorkSpace",
]

View File

@ -1,55 +1,180 @@
import logging
import os
import json
from pathlib import Path
import logging
import langchain
# 是否显示详细日志
log_verbose = False
langchain.verbose = False
# 通常情况下不需要更改以下内容
# chatchat 项目根目录
CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent)
# 用户数据根目录
DATA_PATH = os.path.join(CHATCHAT_ROOT, "data")
if not os.path.exists(DATA_PATH):
os.mkdir(DATA_PATH)
# 项目相关图片
IMG_DIR = os.path.join(CHATCHAT_ROOT, "img")
if not os.path.exists(IMG_DIR):
os.mkdir(IMG_DIR)
# nltk 模型存储路径
NLTK_DATA_PATH = os.path.join(DATA_PATH, "nltk_data")
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
# 日志格式
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT)
# 日志存储路径
LOG_PATH = os.path.join(DATA_PATH, "logs")
if not os.path.exists(LOG_PATH):
os.mkdir(LOG_PATH)
class ConfigBasic:
log_verbose: bool
"""是否开启日志详细信息"""
CHATCHAT_ROOT: str
"""项目根目录"""
DATA_PATH: str
"""用户数据根目录"""
IMG_DIR: str
"""项目相关图片"""
NLTK_DATA_PATH: str
"""nltk 模型存储路径"""
LOG_FORMAT: str
"""日志格式"""
LOG_PATH: str
"""日志存储路径"""
MEDIA_PATH: str
"""模型生成内容(图片、视频、音频等)保存位置"""
BASE_TEMP_DIR: str
"""临时文件目录,主要用于文件对话"""
# 模型生成内容(图片、视频、音频等)保存位置
MEDIA_PATH = os.path.join(DATA_PATH, "media")
if not os.path.exists(MEDIA_PATH):
os.mkdir(MEDIA_PATH)
os.mkdir(os.path.join(MEDIA_PATH, "image"))
os.mkdir(os.path.join(MEDIA_PATH, "audio"))
os.mkdir(os.path.join(MEDIA_PATH, "video"))
def __str__(self):
return f"ConfigBasic(log_verbose={self.log_verbose}, CHATCHAT_ROOT={self.CHATCHAT_ROOT}, DATA_PATH={self.DATA_PATH}, IMG_DIR={self.IMG_DIR}, NLTK_DATA_PATH={self.NLTK_DATA_PATH}, LOG_FORMAT={self.LOG_FORMAT}, LOG_PATH={self.LOG_PATH}, MEDIA_PATH={self.MEDIA_PATH}, BASE_TEMP_DIR={self.BASE_TEMP_DIR})"
# 临时文件目录,主要用于文件对话
BASE_TEMP_DIR = os.path.join(DATA_PATH, "temp")
if not os.path.exists(BASE_TEMP_DIR):
os.mkdir(BASE_TEMP_DIR)
class ConfigBasicFactory:
"""Basic config for ChatChat """
def __init__(self):
# 日志格式
self.LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
logging.basicConfig(format=self.LOG_FORMAT)
self.LOG_VERBOSE = False
self.CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent)
# 用户数据根目录
self.DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data")
self._DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data")
if not os.path.exists(self._DATA_PATH):
os.makedirs(self._DATA_PATH, exist_ok=True)
self._init_data_dir()
# 项目相关图片
self.IMG_DIR = os.path.join(self.CHATCHAT_ROOT, "img")
if not os.path.exists(self.IMG_DIR):
os.makedirs(self.IMG_DIR, exist_ok=True)
def log_verbose(self, verbose: bool):
self.LOG_VERBOSE = verbose
def data_path(self, path: str):
self.DATA_PATH = path
if not os.path.exists(self.DATA_PATH):
os.makedirs(self.DATA_PATH, exist_ok=True)
# 复制_DATA_PATH数据到DATA_PATH
if self._DATA_PATH != self.DATA_PATH:
os.system(f"cp -r {self._DATA_PATH}/* {self.DATA_PATH}")
self._init_data_dir()
def log_format(self, log_format: str):
self.LOG_FORMAT = log_format
logging.basicConfig(format=self.LOG_FORMAT)
def _init_data_dir(self):
logger.info(f"Init data dir: {self.DATA_PATH}")
# nltk 模型存储路径
self.NLTK_DATA_PATH = os.path.join(self.DATA_PATH, "nltk_data")
import nltk
nltk.data.path = [self.NLTK_DATA_PATH] + nltk.data.path
# 日志存储路径
self.LOG_PATH = os.path.join(self.DATA_PATH, "logs")
if not os.path.exists(self.LOG_PATH):
os.makedirs(self.LOG_PATH, exist_ok=True)
# 模型生成内容(图片、视频、音频等)保存位置
self.MEDIA_PATH = os.path.join(self.DATA_PATH, "media")
if not os.path.exists(self.MEDIA_PATH):
os.makedirs(self.MEDIA_PATH, exist_ok=True)
os.makedirs(os.path.join(self.MEDIA_PATH, "image"), exist_ok=True)
os.makedirs(os.path.join(self.MEDIA_PATH, "audio"), exist_ok=True)
os.makedirs(os.path.join(self.MEDIA_PATH, "video"), exist_ok=True)
# 临时文件目录,主要用于文件对话
self.BASE_TEMP_DIR = os.path.join(self.DATA_PATH, "temp")
if not os.path.exists(self.BASE_TEMP_DIR):
os.makedirs(self.BASE_TEMP_DIR, exist_ok=True)
logger.info(f"Init data dir: {self.DATA_PATH} success.")
def get_config(self) -> ConfigBasic:
config = ConfigBasic()
config.log_verbose = self.LOG_VERBOSE
config.CHATCHAT_ROOT = self.CHATCHAT_ROOT
config.DATA_PATH = self.DATA_PATH
config.IMG_DIR = self.IMG_DIR
config.NLTK_DATA_PATH = self.NLTK_DATA_PATH
config.LOG_FORMAT = self.LOG_FORMAT
config.LOG_PATH = self.LOG_PATH
config.MEDIA_PATH = self.MEDIA_PATH
config.BASE_TEMP_DIR = self.BASE_TEMP_DIR
return config
class ConfigWorkSpace:
"""
工作空间的配置预设提供ConfigBasic建造方法产生实例
该类的实例对象用于存储工作空间的配置信息如工作空间的路径等
工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中
注意不存在则读取默认
"""
_config_factory: ConfigBasicFactory = ConfigBasicFactory()
def __init__(self):
self.workspace = os.path.join(os.path.expanduser("~"), ".config", "chatchat/workspace")
if not os.path.exists(self.workspace):
os.makedirs(self.workspace, exist_ok=True)
self.workspace_config = os.path.join(self.workspace, "workspace_config.json")
# 初始化工作空间配置转换成json格式实现ConfigBasic的实例化
config_json = self._load_config()
if config_json:
_config_factory = ConfigBasicFactory()
if config_json.get("log_verbose"):
_config_factory.log_verbose(config_json.get("log_verbose"))
if config_json.get("DATA_PATH"):
_config_factory.data_path(config_json.get("DATA_PATH"))
if config_json.get("LOG_FORMAT"):
_config_factory.log_format(config_json.get("LOG_FORMAT"))
self._config_factory = _config_factory
def get_config(self) -> ConfigBasic:
return self._config_factory.get_config()
def set_log_verbose(self, verbose: bool):
self._config_factory.log_verbose(verbose)
self._store_config()
def set_data_path(self, path: str):
self._config_factory.data_path(path)
self._store_config()
def set_log_format(self, log_format: str):
self._config_factory.log_format(log_format)
self._store_config()
def clear(self):
logger.info("Clear workspace config.")
os.remove(self.workspace_config)
def _load_config(self):
try:
with open(self.workspace_config, "r") as f:
return json.loads(f.read())
except FileNotFoundError:
return None
def _store_config(self):
with open(self.workspace_config, "w") as f:
config = self._config_factory.get_config()
config_json = {
"log_verbose": config.log_verbose,
"CHATCHAT_ROOT": config.CHATCHAT_ROOT,
"DATA_PATH": config.DATA_PATH,
"LOG_FORMAT": config.LOG_FORMAT
}
f.write(json.dumps(config_json, indent=4, ensure_ascii=False))
config_workspace: ConfigWorkSpace = ConfigWorkSpace()

View File

@ -11,6 +11,7 @@ packages = [
[tool.poetry.scripts]
chatchat = 'chatchat.startup:main'
chatchat-kb = 'chatchat.init_database:main'
chatchat-config = 'chatchat.config_work_space:main'
[tool.poetry.dependencies]
python = ">=3.8.1,<3.12,!=3.9.7"

View File

@ -0,0 +1,90 @@
"""Configuration for unit tests."""
import logging
from importlib import util
from typing import Dict, List, Sequence
import pytest
from pytest import Config, Function, Parser
def pytest_addoption(parser: Parser) -> None:
"""Add custom command line options to pytest."""
parser.addoption(
"--only-extended",
action="store_true",
help="Only run extended tests. Does not allow skipping any extended tests.",
)
parser.addoption(
"--only-core",
action="store_true",
help="Only run core tests. Never runs any extended tests.",
)
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
"""Add implementations for handling custom markers.
At the moment, this adds support for a custom `requires` marker.
The `requires` marker is used to denote tests that require one or more packages
to be installed to run. If the package is not installed, the test is skipped.
The `requires` marker syntax is:
.. code-block:: python
@pytest.mark.requires("package1", "package2")
def test_something():
...
"""
# Mapping from the name of a package to whether it is installed or not.
# Used to avoid repeated calls to `util.find_spec`
required_pkgs_info: Dict[str, bool] = {}
only_extended = config.getoption("--only-extended") or False
only_core = config.getoption("--only-core") or False
if only_extended and only_core:
raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
for item in items:
requires_marker = item.get_closest_marker("requires")
if requires_marker is not None:
if only_core:
item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
continue
# Iterate through the list of required packages
required_pkgs = requires_marker.args
for pkg in required_pkgs:
# If we haven't yet checked whether the pkg is installed
# let's check it and store the result.
if pkg not in required_pkgs_info:
try:
installed = util.find_spec(pkg) is not None
except Exception:
installed = False
required_pkgs_info[pkg] = installed
if not required_pkgs_info[pkg]:
if only_extended:
pytest.fail(
f"Package `{pkg}` is not installed but is required for "
f"extended tests. Please install the given package and "
f"try again.",
)
else:
# If the package is not installed, we immediately break
# and mark the test as skipped.
item.add_marker(
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
)
break
else:
if only_extended:
item.add_marker(
pytest.mark.skip(reason="Skipping not an extended test.")
)

View File

@ -0,0 +1,56 @@
from pathlib import Path
from chatchat.configs import ConfigBasicFactory, ConfigBasic, ConfigWorkSpace
import os
def test_config_factory_def():
test_config_factory = ConfigBasicFactory()
config: ConfigBasic = test_config_factory.get_config()
assert config is not None
assert config.log_verbose is False
assert config.CHATCHAT_ROOT is not None
assert config.DATA_PATH is not None
assert config.IMG_DIR is not None
assert config.NLTK_DATA_PATH is not None
assert config.LOG_FORMAT is not None
assert config.LOG_PATH is not None
assert config.MEDIA_PATH is not None
assert os.path.exists(os.path.join(config.MEDIA_PATH, "image"))
assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio"))
assert os.path.exists(os.path.join(config.MEDIA_PATH, "video"))
def test_workspace():
config_workspace = ConfigWorkSpace()
assert config_workspace.get_config() is not None
base_root = os.path.join(Path(__file__).absolute().parent, "chatchat")
config_workspace.set_data_path(os.path.join(base_root, "data"))
config_workspace.set_log_verbose(True)
config_workspace.set_log_format(" %(message)s")
config: ConfigBasic = config_workspace.get_config()
assert config.log_verbose is True
assert config.DATA_PATH == os.path.join(base_root, "data")
assert config.IMG_DIR is not None
assert config.NLTK_DATA_PATH == os.path.join(base_root, "data", "nltk_data")
assert config.LOG_FORMAT == " %(message)s"
assert config.LOG_PATH == os.path.join(base_root, "data", "logs")
assert config.MEDIA_PATH == os.path.join(base_root, "data", "media")
assert os.path.exists(os.path.join(config.MEDIA_PATH, "image"))
assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio"))
assert os.path.exists(os.path.join(config.MEDIA_PATH, "video"))
config_workspace.clear()
def test_workspace_default():
from chatchat.configs import (log_verbose, DATA_PATH, IMG_DIR, NLTK_DATA_PATH, LOG_FORMAT, LOG_PATH, MEDIA_PATH)
assert log_verbose is False
assert DATA_PATH is not None
assert IMG_DIR is not None
assert NLTK_DATA_PATH is not None
assert LOG_FORMAT is not None
assert LOG_PATH is not None
assert MEDIA_PATH is not None