mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-07 23:43:30 +08:00
ConfigWorkSpace 单元测试
This commit is contained in:
parent
c67ce2c280
commit
1ebd1aa8cc
@ -10,162 +10,6 @@ import logging
|
|||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
class ConfigBasic:
|
|
||||||
log_verbose: bool
|
|
||||||
CHATCHAT_ROOT: str
|
|
||||||
DATA_PATH: str
|
|
||||||
IMG_DIR: str
|
|
||||||
NLTK_DATA_PATH: str
|
|
||||||
LOG_FORMAT: str
|
|
||||||
LOG_PATH: str
|
|
||||||
MEDIA_PATH: str
|
|
||||||
BASE_TEMP_DIR: str
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigBasicFactory:
|
|
||||||
"""Basic config for ChatChat """
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
# 日志格式
|
|
||||||
self.LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
|
|
||||||
logging.basicConfig(format=self.LOG_FORMAT)
|
|
||||||
self.LOG_VERBOSE = False
|
|
||||||
self.CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent)
|
|
||||||
# 用户数据根目录
|
|
||||||
self.DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data")
|
|
||||||
self._DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data")
|
|
||||||
if not os.path.exists(self._DATA_PATH):
|
|
||||||
os.mkdir(self.DATA_PATH)
|
|
||||||
|
|
||||||
self._init_data_dir()
|
|
||||||
|
|
||||||
# 项目相关图片
|
|
||||||
self.IMG_DIR = os.path.join(self.CHATCHAT_ROOT, "img")
|
|
||||||
if not os.path.exists(self.IMG_DIR):
|
|
||||||
os.mkdir(self.IMG_DIR)
|
|
||||||
|
|
||||||
def log_verbose(self, verbose: bool):
|
|
||||||
self.LOG_VERBOSE = verbose
|
|
||||||
|
|
||||||
def chatchat_root(self, root: str):
|
|
||||||
self.CHATCHAT_ROOT = root
|
|
||||||
|
|
||||||
def data_path(self, path: str):
|
|
||||||
self.DATA_PATH = path
|
|
||||||
if not os.path.exists(self.DATA_PATH):
|
|
||||||
os.mkdir(self.DATA_PATH)
|
|
||||||
# 复制_DATA_PATH数据到DATA_PATH
|
|
||||||
os.system(f"cp -r {self._DATA_PATH} {self.DATA_PATH}")
|
|
||||||
|
|
||||||
self._init_data_dir()
|
|
||||||
|
|
||||||
def log_format(self, log_format: str):
|
|
||||||
self.LOG_FORMAT = log_format
|
|
||||||
logging.basicConfig(format=self.LOG_FORMAT)
|
|
||||||
|
|
||||||
def _init_data_dir(self):
|
|
||||||
logger.info(f"Init data dir: {self.DATA_PATH}")
|
|
||||||
# nltk 模型存储路径
|
|
||||||
self.NLTK_DATA_PATH = os.path.join(self.DATA_PATH, "nltk_data")
|
|
||||||
import nltk
|
|
||||||
nltk.data.path = [self.NLTK_DATA_PATH] + nltk.data.path
|
|
||||||
# 日志存储路径
|
|
||||||
self.LOG_PATH = os.path.join(self.DATA_PATH, "logs")
|
|
||||||
if not os.path.exists(self.LOG_PATH):
|
|
||||||
os.mkdir(self.LOG_PATH)
|
|
||||||
|
|
||||||
# 模型生成内容(图片、视频、音频等)保存位置
|
|
||||||
self.MEDIA_PATH = os.path.join(self.DATA_PATH, "media")
|
|
||||||
if not os.path.exists(self.MEDIA_PATH):
|
|
||||||
os.mkdir(self.MEDIA_PATH)
|
|
||||||
os.mkdir(os.path.join(self.MEDIA_PATH, "image"))
|
|
||||||
os.mkdir(os.path.join(self.MEDIA_PATH, "audio"))
|
|
||||||
os.mkdir(os.path.join(self.MEDIA_PATH, "video"))
|
|
||||||
|
|
||||||
# 临时文件目录,主要用于文件对话
|
|
||||||
self.BASE_TEMP_DIR = os.path.join(self.DATA_PATH, "temp")
|
|
||||||
if not os.path.exists(self.BASE_TEMP_DIR):
|
|
||||||
os.mkdir(self.BASE_TEMP_DIR)
|
|
||||||
|
|
||||||
logger.info(f"Init data dir: {self.DATA_PATH} success.")
|
|
||||||
|
|
||||||
def get_config(self) -> ConfigBasic:
|
|
||||||
config = ConfigBasic()
|
|
||||||
config.log_verbose = self.LOG_VERBOSE
|
|
||||||
config.CHATCHAT_ROOT = self.CHATCHAT_ROOT
|
|
||||||
config.DATA_PATH = self.DATA_PATH
|
|
||||||
config.IMG_DIR = self.IMG_DIR
|
|
||||||
config.NLTK_DATA_PATH = self.NLTK_DATA_PATH
|
|
||||||
config.LOG_FORMAT = self.LOG_FORMAT
|
|
||||||
config.LOG_PATH = self.LOG_PATH
|
|
||||||
config.MEDIA_PATH = self.MEDIA_PATH
|
|
||||||
config.BASE_TEMP_DIR = self.BASE_TEMP_DIR
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigWorkSpace:
|
|
||||||
"""
|
|
||||||
工作空间的配置预设,提供ConfigBasic建造方法产生实例。
|
|
||||||
该类的实例对象用于存储工作空间的配置信息,如工作空间的路径等
|
|
||||||
工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中。
|
|
||||||
注意:不存在则读取默认
|
|
||||||
"""
|
|
||||||
_config_factory: ConfigBasicFactory = ConfigBasicFactory()
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.workspace = os.path.join(os.path.expanduser("~"), ".config", "chatchat/workspace")
|
|
||||||
if not os.path.exists(self.workspace):
|
|
||||||
os.makedirs(self.workspace, exist_ok=True)
|
|
||||||
self.workspace_config = os.path.join(self.workspace, "workspace_config.json")
|
|
||||||
# 初始化工作空间配置,转换成json格式,实现ConfigBasic的实例化
|
|
||||||
with open(self.workspace_config, "w") as f:
|
|
||||||
config_json = json.loads(f.read())
|
|
||||||
|
|
||||||
if config_json:
|
|
||||||
|
|
||||||
_config_factory = ConfigBasicFactory()
|
|
||||||
if config_json.get("log_verbose"):
|
|
||||||
_config_factory.log_verbose(config_json.get("log_verbose"))
|
|
||||||
if config_json.get("CHATCHAT_ROOT"):
|
|
||||||
_config_factory.chatchat_root(config_json.get("CHATCHAT_ROOT"))
|
|
||||||
if config_json.get("DATA_PATH"):
|
|
||||||
_config_factory.data_path(config_json.get("DATA_PATH"))
|
|
||||||
if config_json.get("LOG_FORMAT"):
|
|
||||||
_config_factory.log_format(config_json.get("LOG_FORMAT"))
|
|
||||||
|
|
||||||
self._config_factory = _config_factory
|
|
||||||
|
|
||||||
def get_config(self) -> ConfigBasic:
|
|
||||||
return self._config_factory.get_config()
|
|
||||||
|
|
||||||
def set_log_verbose(self, verbose: bool):
|
|
||||||
self._config_factory.log_verbose(verbose)
|
|
||||||
self._store_config()
|
|
||||||
|
|
||||||
def set_chatchat_root(self, root: str):
|
|
||||||
self._config_factory.chatchat_root(root)
|
|
||||||
self._store_config()
|
|
||||||
|
|
||||||
def set_data_path(self, path: str):
|
|
||||||
self._config_factory.data_path(path)
|
|
||||||
self._store_config()
|
|
||||||
|
|
||||||
def set_log_format(self, log_format: str):
|
|
||||||
self._config_factory.log_format(log_format)
|
|
||||||
self._store_config()
|
|
||||||
|
|
||||||
def _store_config(self):
|
|
||||||
with open(self.workspace_config, "w") as f:
|
|
||||||
config = self._config_factory.get_config()
|
|
||||||
config_json = {
|
|
||||||
"log_verbose": config.log_verbose,
|
|
||||||
"CHATCHAT_ROOT": config.CHATCHAT_ROOT,
|
|
||||||
"DATA_PATH": config.DATA_PATH,
|
|
||||||
"LOG_FORMAT": config.LOG_FORMAT
|
|
||||||
}
|
|
||||||
f.write(json.dumps(config_json, indent=4, ensure_ascii=False))
|
|
||||||
|
|
||||||
|
|
||||||
def _load_mod(mod, attr):
|
def _load_mod(mod, attr):
|
||||||
attr_cfg = None
|
attr_cfg = None
|
||||||
for name, obj in vars(mod).items():
|
for name, obj in vars(mod).items():
|
||||||
@ -252,76 +96,102 @@ CONFIG_IMPORTS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _import_ConfigBasic() -> Any:
|
||||||
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
|
load_mod = basic_config_load.get("load_mod")
|
||||||
|
ConfigBasic = load_mod(basic_config_load.get("module"), "ConfigBasic")
|
||||||
|
|
||||||
|
return ConfigBasic
|
||||||
|
|
||||||
|
|
||||||
|
def _import_ConfigBasicFactory() -> Any:
|
||||||
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
|
load_mod = basic_config_load.get("load_mod")
|
||||||
|
ConfigBasicFactory = load_mod(basic_config_load.get("module"), "ConfigBasicFactory")
|
||||||
|
|
||||||
|
return ConfigBasicFactory
|
||||||
|
|
||||||
|
|
||||||
|
def _import_ConfigWorkSpace() -> Any:
|
||||||
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
|
load_mod = basic_config_load.get("load_mod")
|
||||||
|
ConfigWorkSpace = load_mod(basic_config_load.get("module"), "ConfigWorkSpace")
|
||||||
|
|
||||||
|
return ConfigWorkSpace
|
||||||
|
|
||||||
|
|
||||||
def _import_log_verbose() -> Any:
|
def _import_log_verbose() -> Any:
|
||||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
load_mod = basic_config_load.get("load_mod")
|
load_mod = basic_config_load.get("load_mod")
|
||||||
log_verbose = load_mod(basic_config_load.get("module"), "log_verbose")
|
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||||
|
return config_workspace.get_config().log_verbose
|
||||||
return log_verbose
|
|
||||||
|
|
||||||
|
|
||||||
def _import_chatchat_root() -> Any:
|
def _import_chatchat_root() -> Any:
|
||||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
load_mod = basic_config_load.get("load_mod")
|
load_mod = basic_config_load.get("load_mod")
|
||||||
CHATCHAT_ROOT = load_mod(basic_config_load.get("module"), "CHATCHAT_ROOT")
|
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||||
|
|
||||||
return CHATCHAT_ROOT
|
return config_workspace.get_config().CHATCHAT_ROOT
|
||||||
|
|
||||||
|
|
||||||
def _import_data_path() -> Any:
|
def _import_data_path() -> Any:
|
||||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
load_mod = basic_config_load.get("load_mod")
|
load_mod = basic_config_load.get("load_mod")
|
||||||
DATA_PATH = load_mod(basic_config_load.get("module"), "DATA_PATH")
|
|
||||||
|
|
||||||
return DATA_PATH
|
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||||
|
return config_workspace.get_config().DATA_PATH
|
||||||
|
|
||||||
|
|
||||||
def _import_img_dir() -> Any:
|
def _import_img_dir() -> Any:
|
||||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
load_mod = basic_config_load.get("load_mod")
|
load_mod = basic_config_load.get("load_mod")
|
||||||
IMG_DIR = load_mod(basic_config_load.get("module"), "IMG_DIR")
|
|
||||||
|
|
||||||
return IMG_DIR
|
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||||
|
|
||||||
|
return config_workspace.get_config().IMG_DIR
|
||||||
|
|
||||||
|
|
||||||
def _import_nltk_data_path() -> Any:
|
def _import_nltk_data_path() -> Any:
|
||||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
load_mod = basic_config_load.get("load_mod")
|
load_mod = basic_config_load.get("load_mod")
|
||||||
NLTK_DATA_PATH = load_mod(basic_config_load.get("module"), "NLTK_DATA_PATH")
|
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||||
|
|
||||||
return NLTK_DATA_PATH
|
return config_workspace.get_config().NLTK_DATA_PATH
|
||||||
|
|
||||||
|
|
||||||
def _import_log_format() -> Any:
|
def _import_log_format() -> Any:
|
||||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
load_mod = basic_config_load.get("load_mod")
|
load_mod = basic_config_load.get("load_mod")
|
||||||
LOG_FORMAT = load_mod(basic_config_load.get("module"), "LOG_FORMAT")
|
|
||||||
|
|
||||||
return LOG_FORMAT
|
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||||
|
|
||||||
|
return config_workspace.get_config().LOG_FORMAT
|
||||||
|
|
||||||
|
|
||||||
def _import_log_path() -> Any:
|
def _import_log_path() -> Any:
|
||||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
load_mod = basic_config_load.get("load_mod")
|
load_mod = basic_config_load.get("load_mod")
|
||||||
LOG_PATH = load_mod(basic_config_load.get("module"), "LOG_PATH")
|
|
||||||
|
|
||||||
return LOG_PATH
|
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||||
|
|
||||||
|
return config_workspace.get_config().LOG_PATH
|
||||||
|
|
||||||
|
|
||||||
def _import_media_path() -> Any:
|
def _import_media_path() -> Any:
|
||||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
load_mod = basic_config_load.get("load_mod")
|
load_mod = basic_config_load.get("load_mod")
|
||||||
MEDIA_PATH = load_mod(basic_config_load.get("module"), "MEDIA_PATH")
|
|
||||||
|
|
||||||
return MEDIA_PATH
|
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||||
|
|
||||||
|
return config_workspace.get_config().MEDIA_PATH
|
||||||
|
|
||||||
|
|
||||||
def _import_base_temp_dir() -> Any:
|
def _import_base_temp_dir() -> Any:
|
||||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
load_mod = basic_config_load.get("load_mod")
|
load_mod = basic_config_load.get("load_mod")
|
||||||
BASE_TEMP_DIR = load_mod(basic_config_load.get("module"), "BASE_TEMP_DIR")
|
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||||
|
return config_workspace.get_config().BASE_TEMP_DIR
|
||||||
return BASE_TEMP_DIR
|
|
||||||
|
|
||||||
|
|
||||||
def _import_default_knowledge_base() -> Any:
|
def _import_default_knowledge_base() -> Any:
|
||||||
@ -637,7 +507,13 @@ def _import_api_server() -> Any:
|
|||||||
|
|
||||||
|
|
||||||
def __getattr__(name: str) -> Any:
|
def __getattr__(name: str) -> Any:
|
||||||
if name == "log_verbose":
|
if name == "ConfigBasic":
|
||||||
|
return _import_ConfigBasic()
|
||||||
|
elif name == "ConfigBasicFactory":
|
||||||
|
return _import_ConfigBasicFactory()
|
||||||
|
elif name == "ConfigWorkSpace":
|
||||||
|
return _import_ConfigWorkSpace()
|
||||||
|
elif name == "log_verbose":
|
||||||
return _import_log_verbose()
|
return _import_log_verbose()
|
||||||
elif name == "CHATCHAT_ROOT":
|
elif name == "CHATCHAT_ROOT":
|
||||||
return _import_chatchat_root()
|
return _import_chatchat_root()
|
||||||
@ -785,7 +661,6 @@ __all__ = [
|
|||||||
"WEBUI_SERVER",
|
"WEBUI_SERVER",
|
||||||
"API_SERVER",
|
"API_SERVER",
|
||||||
|
|
||||||
|
|
||||||
"ConfigBasic",
|
"ConfigBasic",
|
||||||
"ConfigBasicFactory",
|
"ConfigBasicFactory",
|
||||||
"ConfigWorkSpace",
|
"ConfigWorkSpace",
|
||||||
|
|||||||
@ -1,55 +1,177 @@
|
|||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import logging
|
||||||
|
|
||||||
import langchain
|
|
||||||
|
|
||||||
|
|
||||||
# 是否显示详细日志
|
|
||||||
log_verbose = False
|
|
||||||
langchain.verbose = False
|
|
||||||
|
|
||||||
# 通常情况下不需要更改以下内容
|
|
||||||
|
|
||||||
# chatchat 项目根目录
|
|
||||||
CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent)
|
|
||||||
|
|
||||||
# 用户数据根目录
|
|
||||||
DATA_PATH = os.path.join(CHATCHAT_ROOT, "data")
|
|
||||||
if not os.path.exists(DATA_PATH):
|
|
||||||
os.mkdir(DATA_PATH)
|
|
||||||
|
|
||||||
# 项目相关图片
|
|
||||||
IMG_DIR = os.path.join(CHATCHAT_ROOT, "img")
|
|
||||||
if not os.path.exists(IMG_DIR):
|
|
||||||
os.mkdir(IMG_DIR)
|
|
||||||
|
|
||||||
# nltk 模型存储路径
|
|
||||||
NLTK_DATA_PATH = os.path.join(DATA_PATH, "nltk_data")
|
|
||||||
import nltk
|
|
||||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
|
||||||
|
|
||||||
# 日志格式
|
|
||||||
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
logging.basicConfig(format=LOG_FORMAT)
|
|
||||||
|
|
||||||
|
|
||||||
# 日志存储路径
|
class ConfigBasic:
|
||||||
LOG_PATH = os.path.join(DATA_PATH, "logs")
|
log_verbose: bool
|
||||||
if not os.path.exists(LOG_PATH):
|
"""是否开启日志详细信息"""
|
||||||
os.mkdir(LOG_PATH)
|
CHATCHAT_ROOT: str
|
||||||
|
"""项目根目录"""
|
||||||
|
DATA_PATH: str
|
||||||
|
"""用户数据根目录"""
|
||||||
|
IMG_DIR: str
|
||||||
|
"""项目相关图片"""
|
||||||
|
NLTK_DATA_PATH: str
|
||||||
|
"""nltk 模型存储路径"""
|
||||||
|
LOG_FORMAT: str
|
||||||
|
"""日志格式"""
|
||||||
|
LOG_PATH: str
|
||||||
|
"""日志存储路径"""
|
||||||
|
MEDIA_PATH: str
|
||||||
|
"""模型生成内容(图片、视频、音频等)保存位置"""
|
||||||
|
BASE_TEMP_DIR: str
|
||||||
|
"""临时文件目录,主要用于文件对话"""
|
||||||
|
|
||||||
# 模型生成内容(图片、视频、音频等)保存位置
|
|
||||||
MEDIA_PATH = os.path.join(DATA_PATH, "media")
|
|
||||||
if not os.path.exists(MEDIA_PATH):
|
|
||||||
os.mkdir(MEDIA_PATH)
|
|
||||||
os.mkdir(os.path.join(MEDIA_PATH, "image"))
|
|
||||||
os.mkdir(os.path.join(MEDIA_PATH, "audio"))
|
|
||||||
os.mkdir(os.path.join(MEDIA_PATH, "video"))
|
|
||||||
|
|
||||||
# 临时文件目录,主要用于文件对话
|
class ConfigBasicFactory:
|
||||||
BASE_TEMP_DIR = os.path.join(DATA_PATH, "temp")
|
"""Basic config for ChatChat """
|
||||||
if not os.path.exists(BASE_TEMP_DIR):
|
|
||||||
os.mkdir(BASE_TEMP_DIR)
|
def __init__(self):
|
||||||
|
# 日志格式
|
||||||
|
self.LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
|
||||||
|
logging.basicConfig(format=self.LOG_FORMAT)
|
||||||
|
self.LOG_VERBOSE = False
|
||||||
|
self.CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent)
|
||||||
|
# 用户数据根目录
|
||||||
|
self.DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data")
|
||||||
|
self._DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data")
|
||||||
|
if not os.path.exists(self._DATA_PATH):
|
||||||
|
os.makedirs(self._DATA_PATH, exist_ok=True)
|
||||||
|
|
||||||
|
self._init_data_dir()
|
||||||
|
|
||||||
|
# 项目相关图片
|
||||||
|
self.IMG_DIR = os.path.join(self.CHATCHAT_ROOT, "img")
|
||||||
|
if not os.path.exists(self.IMG_DIR):
|
||||||
|
os.makedirs(self.IMG_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
def log_verbose(self, verbose: bool):
|
||||||
|
self.LOG_VERBOSE = verbose
|
||||||
|
|
||||||
|
def data_path(self, path: str):
|
||||||
|
self.DATA_PATH = path
|
||||||
|
if not os.path.exists(self.DATA_PATH):
|
||||||
|
os.makedirs(self.DATA_PATH, exist_ok=True)
|
||||||
|
# 复制_DATA_PATH数据到DATA_PATH
|
||||||
|
if self._DATA_PATH != self.DATA_PATH:
|
||||||
|
os.system(f"cp -r {self._DATA_PATH}/* {self.DATA_PATH}")
|
||||||
|
|
||||||
|
self._init_data_dir()
|
||||||
|
|
||||||
|
def log_format(self, log_format: str):
|
||||||
|
self.LOG_FORMAT = log_format
|
||||||
|
logging.basicConfig(format=self.LOG_FORMAT)
|
||||||
|
|
||||||
|
def _init_data_dir(self):
|
||||||
|
logger.info(f"Init data dir: {self.DATA_PATH}")
|
||||||
|
# nltk 模型存储路径
|
||||||
|
self.NLTK_DATA_PATH = os.path.join(self.DATA_PATH, "nltk_data")
|
||||||
|
import nltk
|
||||||
|
nltk.data.path = [self.NLTK_DATA_PATH] + nltk.data.path
|
||||||
|
# 日志存储路径
|
||||||
|
self.LOG_PATH = os.path.join(self.DATA_PATH, "logs")
|
||||||
|
if not os.path.exists(self.LOG_PATH):
|
||||||
|
os.makedirs(self.LOG_PATH, exist_ok=True)
|
||||||
|
|
||||||
|
# 模型生成内容(图片、视频、音频等)保存位置
|
||||||
|
self.MEDIA_PATH = os.path.join(self.DATA_PATH, "media")
|
||||||
|
if not os.path.exists(self.MEDIA_PATH):
|
||||||
|
os.makedirs(self.MEDIA_PATH, exist_ok=True)
|
||||||
|
os.makedirs(os.path.join(self.MEDIA_PATH, "image"), exist_ok=True)
|
||||||
|
os.makedirs(os.path.join(self.MEDIA_PATH, "audio"), exist_ok=True)
|
||||||
|
os.makedirs(os.path.join(self.MEDIA_PATH, "video"), exist_ok=True)
|
||||||
|
|
||||||
|
# 临时文件目录,主要用于文件对话
|
||||||
|
self.BASE_TEMP_DIR = os.path.join(self.DATA_PATH, "temp")
|
||||||
|
if not os.path.exists(self.BASE_TEMP_DIR):
|
||||||
|
os.makedirs(self.BASE_TEMP_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
logger.info(f"Init data dir: {self.DATA_PATH} success.")
|
||||||
|
|
||||||
|
def get_config(self) -> ConfigBasic:
|
||||||
|
config = ConfigBasic()
|
||||||
|
config.log_verbose = self.LOG_VERBOSE
|
||||||
|
config.CHATCHAT_ROOT = self.CHATCHAT_ROOT
|
||||||
|
config.DATA_PATH = self.DATA_PATH
|
||||||
|
config.IMG_DIR = self.IMG_DIR
|
||||||
|
config.NLTK_DATA_PATH = self.NLTK_DATA_PATH
|
||||||
|
config.LOG_FORMAT = self.LOG_FORMAT
|
||||||
|
config.LOG_PATH = self.LOG_PATH
|
||||||
|
config.MEDIA_PATH = self.MEDIA_PATH
|
||||||
|
config.BASE_TEMP_DIR = self.BASE_TEMP_DIR
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigWorkSpace:
|
||||||
|
"""
|
||||||
|
工作空间的配置预设,提供ConfigBasic建造方法产生实例。
|
||||||
|
该类的实例对象用于存储工作空间的配置信息,如工作空间的路径等
|
||||||
|
工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中。
|
||||||
|
注意:不存在则读取默认
|
||||||
|
"""
|
||||||
|
_config_factory: ConfigBasicFactory = ConfigBasicFactory()
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.workspace = os.path.join(os.path.expanduser("~"), ".config", "chatchat/workspace")
|
||||||
|
if not os.path.exists(self.workspace):
|
||||||
|
os.makedirs(self.workspace, exist_ok=True)
|
||||||
|
self.workspace_config = os.path.join(self.workspace, "workspace_config.json")
|
||||||
|
# 初始化工作空间配置,转换成json格式,实现ConfigBasic的实例化
|
||||||
|
|
||||||
|
config_json = self._load_config()
|
||||||
|
|
||||||
|
if config_json:
|
||||||
|
|
||||||
|
_config_factory = ConfigBasicFactory()
|
||||||
|
if config_json.get("log_verbose"):
|
||||||
|
_config_factory.log_verbose(config_json.get("log_verbose"))
|
||||||
|
if config_json.get("DATA_PATH"):
|
||||||
|
_config_factory.data_path(config_json.get("DATA_PATH"))
|
||||||
|
if config_json.get("LOG_FORMAT"):
|
||||||
|
_config_factory.log_format(config_json.get("LOG_FORMAT"))
|
||||||
|
|
||||||
|
self._config_factory = _config_factory
|
||||||
|
|
||||||
|
def get_config(self) -> ConfigBasic:
|
||||||
|
return self._config_factory.get_config()
|
||||||
|
|
||||||
|
def set_log_verbose(self, verbose: bool):
|
||||||
|
self._config_factory.log_verbose(verbose)
|
||||||
|
self._store_config()
|
||||||
|
|
||||||
|
def set_data_path(self, path: str):
|
||||||
|
self._config_factory.data_path(path)
|
||||||
|
self._store_config()
|
||||||
|
|
||||||
|
def set_log_format(self, log_format: str):
|
||||||
|
self._config_factory.log_format(log_format)
|
||||||
|
self._store_config()
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
logger.info("Clear workspace config.")
|
||||||
|
os.remove(self.workspace_config)
|
||||||
|
|
||||||
|
def _load_config(self):
|
||||||
|
try:
|
||||||
|
with open(self.workspace_config, "r") as f:
|
||||||
|
return json.loads(f.read())
|
||||||
|
except FileNotFoundError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _store_config(self):
|
||||||
|
with open(self.workspace_config, "w") as f:
|
||||||
|
config = self._config_factory.get_config()
|
||||||
|
config_json = {
|
||||||
|
"log_verbose": config.log_verbose,
|
||||||
|
"CHATCHAT_ROOT": config.CHATCHAT_ROOT,
|
||||||
|
"DATA_PATH": config.DATA_PATH,
|
||||||
|
"LOG_FORMAT": config.LOG_FORMAT
|
||||||
|
}
|
||||||
|
f.write(json.dumps(config_json, indent=4, ensure_ascii=False))
|
||||||
|
|
||||||
|
|
||||||
|
config_workspace: ConfigWorkSpace = ConfigWorkSpace()
|
||||||
|
|||||||
@ -1,4 +1,6 @@
|
|||||||
from chatchat.configs import ConfigBasicFactory, ConfigBasic
|
from pathlib import Path
|
||||||
|
|
||||||
|
from chatchat.configs import ConfigBasicFactory, ConfigBasic, ConfigWorkSpace
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
@ -18,3 +20,37 @@ def test_config_factory_def():
|
|||||||
assert os.path.exists(os.path.join(config.MEDIA_PATH, "image"))
|
assert os.path.exists(os.path.join(config.MEDIA_PATH, "image"))
|
||||||
assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio"))
|
assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio"))
|
||||||
assert os.path.exists(os.path.join(config.MEDIA_PATH, "video"))
|
assert os.path.exists(os.path.join(config.MEDIA_PATH, "video"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_workspace():
|
||||||
|
config_workspace = ConfigWorkSpace()
|
||||||
|
assert config_workspace.get_config() is not None
|
||||||
|
base_root = os.path.join(Path(__file__).absolute().parent, "chatchat")
|
||||||
|
config_workspace.set_data_path(os.path.join(base_root, "data"))
|
||||||
|
config_workspace.set_log_verbose(True)
|
||||||
|
config_workspace.set_log_format(" %(message)s")
|
||||||
|
|
||||||
|
config: ConfigBasic = config_workspace.get_config()
|
||||||
|
assert config.log_verbose is True
|
||||||
|
assert config.DATA_PATH == os.path.join(base_root, "data")
|
||||||
|
assert config.IMG_DIR is not None
|
||||||
|
assert config.NLTK_DATA_PATH == os.path.join(base_root, "data", "nltk_data")
|
||||||
|
assert config.LOG_FORMAT == " %(message)s"
|
||||||
|
assert config.LOG_PATH == os.path.join(base_root, "data", "logs")
|
||||||
|
assert config.MEDIA_PATH == os.path.join(base_root, "data", "media")
|
||||||
|
|
||||||
|
assert os.path.exists(os.path.join(config.MEDIA_PATH, "image"))
|
||||||
|
assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio"))
|
||||||
|
assert os.path.exists(os.path.join(config.MEDIA_PATH, "video"))
|
||||||
|
config_workspace.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def test_workspace_default():
|
||||||
|
from chatchat.configs import (log_verbose, DATA_PATH, IMG_DIR, NLTK_DATA_PATH, LOG_FORMAT, LOG_PATH, MEDIA_PATH)
|
||||||
|
assert log_verbose is False
|
||||||
|
assert DATA_PATH is not None
|
||||||
|
assert IMG_DIR is not None
|
||||||
|
assert NLTK_DATA_PATH is not None
|
||||||
|
assert LOG_FORMAT is not None
|
||||||
|
assert LOG_PATH is not None
|
||||||
|
assert MEDIA_PATH is not None
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user