From f76e484ff517bb60a1452b06688adb038e9a30f4 Mon Sep 17 00:00:00 2001 From: zR <2448370773@qq.com> Date: Sat, 8 Jun 2024 16:33:47 +0800 Subject: [PATCH 1/9] =?UTF-8?q?=E6=B6=88=E9=99=A4=E8=AD=A6=E5=91=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- libs/chatchat-server/poetry.toml | 3 --- libs/chatchat-server/pyproject.toml | 2 +- libs/model-providers/poetry.toml | 3 --- poetry.toml | 4 ---- 4 files changed, 1 insertion(+), 11 deletions(-) diff --git a/libs/chatchat-server/poetry.toml b/libs/chatchat-server/poetry.toml index 3b673ef3..44b8767e 100644 --- a/libs/chatchat-server/poetry.toml +++ b/libs/chatchat-server/poetry.toml @@ -1,9 +1,6 @@ [virtualenvs] in-project = true -[installer] -modern-installation = false - [plugins] [plugins.pypi_mirror] url = "https://pypi.tuna.tsinghua.edu.cn/simple" diff --git a/libs/chatchat-server/pyproject.toml b/libs/chatchat-server/pyproject.toml index 3bb9870b..cf980bf5 100644 --- a/libs/chatchat-server/pyproject.toml +++ b/libs/chatchat-server/pyproject.toml @@ -261,4 +261,4 @@ location = ".env" [[tool.poetry.source]] name = "tsinghua" url = "https://pypi.tuna.tsinghua.edu.cn/simple/" -priority = "default" +priority = "primary" diff --git a/libs/model-providers/poetry.toml b/libs/model-providers/poetry.toml index 3b673ef3..44b8767e 100644 --- a/libs/model-providers/poetry.toml +++ b/libs/model-providers/poetry.toml @@ -1,9 +1,6 @@ [virtualenvs] in-project = true -[installer] -modern-installation = false - [plugins] [plugins.pypi_mirror] url = "https://pypi.tuna.tsinghua.edu.cn/simple" diff --git a/poetry.toml b/poetry.toml index 3b673ef3..68cda26b 100644 --- a/poetry.toml +++ b/poetry.toml @@ -1,9 +1,5 @@ [virtualenvs] in-project = true -[installer] -modern-installation = false - -[plugins] [plugins.pypi_mirror] url = "https://pypi.tuna.tsinghua.edu.cn/simple" From 84bafe97235fbb7f29b156c3ed4bfcca8ded295b Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Sun, 9 Jun 2024 19:59:54 +0800 Subject: [PATCH 2/9] =?UTF-8?q?=E7=94=A8=E6=88=B7=E5=B7=A5=E4=BD=9C?= =?UTF-8?q?=E7=A9=BA=E9=97=B4=E6=93=8D=E4=BD=9C=20(#4156)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 工作空间的配置预设,提供ConfigBasic建造方法产生实例。 该类的实例对象用于存储工作空间的配置信息,如工作空间的路径等 工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中。 注意:不存在则读取默认 提供了操作入口 指令` chatchat-config` 工作空间配置 options: ``` -h, --help show this help message and exit -v {true,false}, --verbose {true,false} 是否开启详细日志 -d DATA, --data DATA 数据存放路径 -f FORMAT, --format FORMAT 日志格式 --clear 清除配置 ``` --- .../chatchat/config_work_space.py | 47 ++++ .../chatchat/configs/__init__.py | 93 ++++++-- .../chatchat/configs/_basic_config.py | 219 ++++++++++++++---- libs/chatchat-server/pyproject.toml | 1 + libs/chatchat-server/tests/conftest.py | 90 +++++++ .../tests/unit_tests/config/test_config.py | 56 +++++ 6 files changed, 436 insertions(+), 70 deletions(-) create mode 100644 libs/chatchat-server/chatchat/config_work_space.py create mode 100644 libs/chatchat-server/tests/conftest.py create mode 100644 libs/chatchat-server/tests/unit_tests/config/test_config.py diff --git a/libs/chatchat-server/chatchat/config_work_space.py b/libs/chatchat-server/chatchat/config_work_space.py new file mode 100644 index 00000000..6b3f2907 --- /dev/null +++ b/libs/chatchat-server/chatchat/config_work_space.py @@ -0,0 +1,47 @@ +from chatchat.configs import config_workspace as workspace + + +def main(): + import argparse + + parser = argparse.ArgumentParser(description="指令` chatchat-config` 工作空间配置") + # 只能选择true或false + parser.add_argument( + "-v", + "--verbose", + choices=["true", "false"], + help="是否开启详细日志" + ) + parser.add_argument( + "-d", + "--data", + help="数据存放路径" + ) + parser.add_argument( + "-f", + "--format", + help="日志格式" + ) + parser.add_argument( + "--clear", + action="store_true", + help="清除配置" + ) + args = parser.parse_args() + + if args.verbose: + if args.verbose.lower() == "true": + workspace.set_log_verbose(True) + else: + workspace.set_log_verbose(False) + if args.data: + workspace.set_data_path(args.data) + if args.format: + workspace.set_log_format(args.format) + if args.clear: + workspace.clear() + print(workspace.get_config()) + + +if __name__ == "__main__": + main() diff --git a/libs/chatchat-server/chatchat/configs/__init__.py b/libs/chatchat-server/chatchat/configs/__init__.py index bca795be..c1191b08 100644 --- a/libs/chatchat-server/chatchat/configs/__init__.py +++ b/libs/chatchat-server/chatchat/configs/__init__.py @@ -1,8 +1,10 @@ import importlib import importlib.util import os +from pathlib import Path from typing import Dict, Any +import json import logging logger = logging.getLogger() @@ -38,7 +40,6 @@ def _import_config_mod_load(import_config_mod: str) -> Dict: ) user_import = False if user_import: - # Dynamic loading {config}.py file py_path = os.path.join(user_config_path, import_config_mod + ".py") spec = importlib.util.spec_from_file_location( @@ -69,7 +70,7 @@ def _import_config_mod_load(import_config_mod: str) -> Dict: ) raise RuntimeError(f"Failed to load user config from {user_config_path}") # 当前文件路径 - py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py") + py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py") spec = importlib.util.spec_from_file_location(f"*", py_path) @@ -95,75 +96,108 @@ CONFIG_IMPORTS = { } +def _import_ConfigBasic() -> Any: + basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") + load_mod = basic_config_load.get("load_mod") + ConfigBasic = load_mod(basic_config_load.get("module"), "ConfigBasic") + + return ConfigBasic + + +def _import_ConfigBasicFactory() -> Any: + basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") + load_mod = basic_config_load.get("load_mod") + ConfigBasicFactory = load_mod(basic_config_load.get("module"), "ConfigBasicFactory") + + return ConfigBasicFactory + + +def _import_ConfigWorkSpace() -> Any: + basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") + load_mod = basic_config_load.get("load_mod") + ConfigWorkSpace = load_mod(basic_config_load.get("module"), "ConfigWorkSpace") + + return ConfigWorkSpace + + +def _import_config_workspace() -> Any: + basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") + load_mod = basic_config_load.get("load_mod") + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + return config_workspace + def _import_log_verbose() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - log_verbose = load_mod(basic_config_load.get("module"), "log_verbose") - - return log_verbose + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + return config_workspace.get_config().log_verbose def _import_chatchat_root() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - CHATCHAT_ROOT = load_mod(basic_config_load.get("module"), "CHATCHAT_ROOT") + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") - return CHATCHAT_ROOT + return config_workspace.get_config().CHATCHAT_ROOT def _import_data_path() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - DATA_PATH = load_mod(basic_config_load.get("module"), "DATA_PATH") - return DATA_PATH + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + return config_workspace.get_config().DATA_PATH + def _import_img_dir() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - IMG_DIR = load_mod(basic_config_load.get("module"), "IMG_DIR") - return IMG_DIR + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + + return config_workspace.get_config().IMG_DIR def _import_nltk_data_path() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - NLTK_DATA_PATH = load_mod(basic_config_load.get("module"), "NLTK_DATA_PATH") + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") - return NLTK_DATA_PATH + return config_workspace.get_config().NLTK_DATA_PATH def _import_log_format() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - LOG_FORMAT = load_mod(basic_config_load.get("module"), "LOG_FORMAT") - return LOG_FORMAT + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + + return config_workspace.get_config().LOG_FORMAT def _import_log_path() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - LOG_PATH = load_mod(basic_config_load.get("module"), "LOG_PATH") - return LOG_PATH + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + + return config_workspace.get_config().LOG_PATH def _import_media_path() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - MEDIA_PATH = load_mod(basic_config_load.get("module"), "MEDIA_PATH") - return MEDIA_PATH + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + + return config_workspace.get_config().MEDIA_PATH def _import_base_temp_dir() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - BASE_TEMP_DIR = load_mod(basic_config_load.get("module"), "BASE_TEMP_DIR") - - return BASE_TEMP_DIR + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + return config_workspace.get_config().BASE_TEMP_DIR def _import_default_knowledge_base() -> Any: @@ -285,6 +319,7 @@ def _import_db_root_path() -> Any: return DB_ROOT_PATH + def _import_sqlalchemy_database_uri() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") @@ -478,7 +513,15 @@ def _import_api_server() -> Any: def __getattr__(name: str) -> Any: - if name == "log_verbose": + if name == "ConfigBasic": + return _import_ConfigBasic() + elif name == "ConfigBasicFactory": + return _import_ConfigBasicFactory() + elif name == "ConfigWorkSpace": + return _import_ConfigWorkSpace() + elif name == "config_workspace": + return _import_config_workspace() + elif name == "log_verbose": return _import_log_verbose() elif name == "CHATCHAT_ROOT": return _import_chatchat_root() @@ -578,6 +621,7 @@ VERSION = "v0.3.0-preview" __all__ = [ "VERSION", + "config_workspace", "log_verbose", "CHATCHAT_ROOT", "DATA_PATH", @@ -626,5 +670,8 @@ __all__ = [ "WEBUI_SERVER", "API_SERVER", + "ConfigBasic", + "ConfigBasicFactory", + "ConfigWorkSpace", ] diff --git a/libs/chatchat-server/chatchat/configs/_basic_config.py b/libs/chatchat-server/chatchat/configs/_basic_config.py index 46eec939..61108bde 100644 --- a/libs/chatchat-server/chatchat/configs/_basic_config.py +++ b/libs/chatchat-server/chatchat/configs/_basic_config.py @@ -1,55 +1,180 @@ -import logging import os +import json from pathlib import Path +import logging -import langchain - - -# 是否显示详细日志 -log_verbose = False -langchain.verbose = False - -# 通常情况下不需要更改以下内容 - -# chatchat 项目根目录 -CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent) - -# 用户数据根目录 -DATA_PATH = os.path.join(CHATCHAT_ROOT, "data") -if not os.path.exists(DATA_PATH): - os.mkdir(DATA_PATH) - -# 项目相关图片 -IMG_DIR = os.path.join(CHATCHAT_ROOT, "img") -if not os.path.exists(IMG_DIR): - os.mkdir(IMG_DIR) - -# nltk 模型存储路径 -NLTK_DATA_PATH = os.path.join(DATA_PATH, "nltk_data") -import nltk -nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path - -# 日志格式 -LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" logger = logging.getLogger() -logger.setLevel(logging.INFO) -logging.basicConfig(format=LOG_FORMAT) -# 日志存储路径 -LOG_PATH = os.path.join(DATA_PATH, "logs") -if not os.path.exists(LOG_PATH): - os.mkdir(LOG_PATH) +class ConfigBasic: + log_verbose: bool + """是否开启日志详细信息""" + CHATCHAT_ROOT: str + """项目根目录""" + DATA_PATH: str + """用户数据根目录""" + IMG_DIR: str + """项目相关图片""" + NLTK_DATA_PATH: str + """nltk 模型存储路径""" + LOG_FORMAT: str + """日志格式""" + LOG_PATH: str + """日志存储路径""" + MEDIA_PATH: str + """模型生成内容(图片、视频、音频等)保存位置""" + BASE_TEMP_DIR: str + """临时文件目录,主要用于文件对话""" -# 模型生成内容(图片、视频、音频等)保存位置 -MEDIA_PATH = os.path.join(DATA_PATH, "media") -if not os.path.exists(MEDIA_PATH): - os.mkdir(MEDIA_PATH) - os.mkdir(os.path.join(MEDIA_PATH, "image")) - os.mkdir(os.path.join(MEDIA_PATH, "audio")) - os.mkdir(os.path.join(MEDIA_PATH, "video")) + def __str__(self): + return f"ConfigBasic(log_verbose={self.log_verbose}, CHATCHAT_ROOT={self.CHATCHAT_ROOT}, DATA_PATH={self.DATA_PATH}, IMG_DIR={self.IMG_DIR}, NLTK_DATA_PATH={self.NLTK_DATA_PATH}, LOG_FORMAT={self.LOG_FORMAT}, LOG_PATH={self.LOG_PATH}, MEDIA_PATH={self.MEDIA_PATH}, BASE_TEMP_DIR={self.BASE_TEMP_DIR})" -# 临时文件目录,主要用于文件对话 -BASE_TEMP_DIR = os.path.join(DATA_PATH, "temp") -if not os.path.exists(BASE_TEMP_DIR): - os.mkdir(BASE_TEMP_DIR) + +class ConfigBasicFactory: + """Basic config for ChatChat """ + + def __init__(self): + # 日志格式 + self.LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" + logging.basicConfig(format=self.LOG_FORMAT) + self.LOG_VERBOSE = False + self.CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent) + # 用户数据根目录 + self.DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data") + self._DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data") + if not os.path.exists(self._DATA_PATH): + os.makedirs(self._DATA_PATH, exist_ok=True) + + self._init_data_dir() + + # 项目相关图片 + self.IMG_DIR = os.path.join(self.CHATCHAT_ROOT, "img") + if not os.path.exists(self.IMG_DIR): + os.makedirs(self.IMG_DIR, exist_ok=True) + + def log_verbose(self, verbose: bool): + self.LOG_VERBOSE = verbose + + def data_path(self, path: str): + self.DATA_PATH = path + if not os.path.exists(self.DATA_PATH): + os.makedirs(self.DATA_PATH, exist_ok=True) + # 复制_DATA_PATH数据到DATA_PATH + if self._DATA_PATH != self.DATA_PATH: + os.system(f"cp -r {self._DATA_PATH}/* {self.DATA_PATH}") + + self._init_data_dir() + + def log_format(self, log_format: str): + self.LOG_FORMAT = log_format + logging.basicConfig(format=self.LOG_FORMAT) + + def _init_data_dir(self): + logger.info(f"Init data dir: {self.DATA_PATH}") + # nltk 模型存储路径 + self.NLTK_DATA_PATH = os.path.join(self.DATA_PATH, "nltk_data") + import nltk + nltk.data.path = [self.NLTK_DATA_PATH] + nltk.data.path + # 日志存储路径 + self.LOG_PATH = os.path.join(self.DATA_PATH, "logs") + if not os.path.exists(self.LOG_PATH): + os.makedirs(self.LOG_PATH, exist_ok=True) + + # 模型生成内容(图片、视频、音频等)保存位置 + self.MEDIA_PATH = os.path.join(self.DATA_PATH, "media") + if not os.path.exists(self.MEDIA_PATH): + os.makedirs(self.MEDIA_PATH, exist_ok=True) + os.makedirs(os.path.join(self.MEDIA_PATH, "image"), exist_ok=True) + os.makedirs(os.path.join(self.MEDIA_PATH, "audio"), exist_ok=True) + os.makedirs(os.path.join(self.MEDIA_PATH, "video"), exist_ok=True) + + # 临时文件目录,主要用于文件对话 + self.BASE_TEMP_DIR = os.path.join(self.DATA_PATH, "temp") + if not os.path.exists(self.BASE_TEMP_DIR): + os.makedirs(self.BASE_TEMP_DIR, exist_ok=True) + + logger.info(f"Init data dir: {self.DATA_PATH} success.") + + def get_config(self) -> ConfigBasic: + config = ConfigBasic() + config.log_verbose = self.LOG_VERBOSE + config.CHATCHAT_ROOT = self.CHATCHAT_ROOT + config.DATA_PATH = self.DATA_PATH + config.IMG_DIR = self.IMG_DIR + config.NLTK_DATA_PATH = self.NLTK_DATA_PATH + config.LOG_FORMAT = self.LOG_FORMAT + config.LOG_PATH = self.LOG_PATH + config.MEDIA_PATH = self.MEDIA_PATH + config.BASE_TEMP_DIR = self.BASE_TEMP_DIR + return config + + +class ConfigWorkSpace: + """ + 工作空间的配置预设,提供ConfigBasic建造方法产生实例。 + 该类的实例对象用于存储工作空间的配置信息,如工作空间的路径等 + 工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中。 + 注意:不存在则读取默认 + """ + _config_factory: ConfigBasicFactory = ConfigBasicFactory() + + def __init__(self): + self.workspace = os.path.join(os.path.expanduser("~"), ".config", "chatchat/workspace") + if not os.path.exists(self.workspace): + os.makedirs(self.workspace, exist_ok=True) + self.workspace_config = os.path.join(self.workspace, "workspace_config.json") + # 初始化工作空间配置,转换成json格式,实现ConfigBasic的实例化 + + config_json = self._load_config() + + if config_json: + + _config_factory = ConfigBasicFactory() + if config_json.get("log_verbose"): + _config_factory.log_verbose(config_json.get("log_verbose")) + if config_json.get("DATA_PATH"): + _config_factory.data_path(config_json.get("DATA_PATH")) + if config_json.get("LOG_FORMAT"): + _config_factory.log_format(config_json.get("LOG_FORMAT")) + + self._config_factory = _config_factory + + def get_config(self) -> ConfigBasic: + return self._config_factory.get_config() + + def set_log_verbose(self, verbose: bool): + self._config_factory.log_verbose(verbose) + self._store_config() + + def set_data_path(self, path: str): + self._config_factory.data_path(path) + self._store_config() + + def set_log_format(self, log_format: str): + self._config_factory.log_format(log_format) + self._store_config() + + def clear(self): + logger.info("Clear workspace config.") + os.remove(self.workspace_config) + + def _load_config(self): + try: + with open(self.workspace_config, "r") as f: + return json.loads(f.read()) + except FileNotFoundError: + return None + + def _store_config(self): + with open(self.workspace_config, "w") as f: + config = self._config_factory.get_config() + config_json = { + "log_verbose": config.log_verbose, + "CHATCHAT_ROOT": config.CHATCHAT_ROOT, + "DATA_PATH": config.DATA_PATH, + "LOG_FORMAT": config.LOG_FORMAT + } + f.write(json.dumps(config_json, indent=4, ensure_ascii=False)) + + +config_workspace: ConfigWorkSpace = ConfigWorkSpace() diff --git a/libs/chatchat-server/pyproject.toml b/libs/chatchat-server/pyproject.toml index f9adff8d..e105db24 100644 --- a/libs/chatchat-server/pyproject.toml +++ b/libs/chatchat-server/pyproject.toml @@ -11,6 +11,7 @@ packages = [ [tool.poetry.scripts] chatchat = 'chatchat.startup:main' chatchat-kb = 'chatchat.init_database:main' +chatchat-config = 'chatchat.config_work_space:main' [tool.poetry.dependencies] python = ">=3.8.1,<3.12,!=3.9.7" diff --git a/libs/chatchat-server/tests/conftest.py b/libs/chatchat-server/tests/conftest.py new file mode 100644 index 00000000..031384e8 --- /dev/null +++ b/libs/chatchat-server/tests/conftest.py @@ -0,0 +1,90 @@ +"""Configuration for unit tests.""" +import logging +from importlib import util +from typing import Dict, List, Sequence + +import pytest +from pytest import Config, Function, Parser + + +def pytest_addoption(parser: Parser) -> None: + """Add custom command line options to pytest.""" + parser.addoption( + "--only-extended", + action="store_true", + help="Only run extended tests. Does not allow skipping any extended tests.", + ) + parser.addoption( + "--only-core", + action="store_true", + help="Only run core tests. Never runs any extended tests.", + ) + + +def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None: + """Add implementations for handling custom markers. + + At the moment, this adds support for a custom `requires` marker. + + The `requires` marker is used to denote tests that require one or more packages + to be installed to run. If the package is not installed, the test is skipped. + + The `requires` marker syntax is: + + .. code-block:: python + + @pytest.mark.requires("package1", "package2") + def test_something(): + ... + """ + # Mapping from the name of a package to whether it is installed or not. + # Used to avoid repeated calls to `util.find_spec` + required_pkgs_info: Dict[str, bool] = {} + + only_extended = config.getoption("--only-extended") or False + only_core = config.getoption("--only-core") or False + + if only_extended and only_core: + raise ValueError("Cannot specify both `--only-extended` and `--only-core`.") + + for item in items: + requires_marker = item.get_closest_marker("requires") + if requires_marker is not None: + if only_core: + item.add_marker(pytest.mark.skip(reason="Skipping not a core test.")) + continue + + # Iterate through the list of required packages + required_pkgs = requires_marker.args + for pkg in required_pkgs: + # If we haven't yet checked whether the pkg is installed + # let's check it and store the result. + if pkg not in required_pkgs_info: + try: + installed = util.find_spec(pkg) is not None + except Exception: + installed = False + required_pkgs_info[pkg] = installed + + if not required_pkgs_info[pkg]: + if only_extended: + pytest.fail( + f"Package `{pkg}` is not installed but is required for " + f"extended tests. Please install the given package and " + f"try again.", + ) + + else: + # If the package is not installed, we immediately break + # and mark the test as skipped. + item.add_marker( + pytest.mark.skip(reason=f"Requires pkg: `{pkg}`") + ) + break + else: + if only_extended: + item.add_marker( + pytest.mark.skip(reason="Skipping not an extended test.") + ) + + diff --git a/libs/chatchat-server/tests/unit_tests/config/test_config.py b/libs/chatchat-server/tests/unit_tests/config/test_config.py new file mode 100644 index 00000000..96fbd5dc --- /dev/null +++ b/libs/chatchat-server/tests/unit_tests/config/test_config.py @@ -0,0 +1,56 @@ +from pathlib import Path + +from chatchat.configs import ConfigBasicFactory, ConfigBasic, ConfigWorkSpace +import os + + +def test_config_factory_def(): + test_config_factory = ConfigBasicFactory() + config: ConfigBasic = test_config_factory.get_config() + assert config is not None + assert config.log_verbose is False + assert config.CHATCHAT_ROOT is not None + assert config.DATA_PATH is not None + assert config.IMG_DIR is not None + assert config.NLTK_DATA_PATH is not None + assert config.LOG_FORMAT is not None + assert config.LOG_PATH is not None + assert config.MEDIA_PATH is not None + + assert os.path.exists(os.path.join(config.MEDIA_PATH, "image")) + assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio")) + assert os.path.exists(os.path.join(config.MEDIA_PATH, "video")) + + +def test_workspace(): + config_workspace = ConfigWorkSpace() + assert config_workspace.get_config() is not None + base_root = os.path.join(Path(__file__).absolute().parent, "chatchat") + config_workspace.set_data_path(os.path.join(base_root, "data")) + config_workspace.set_log_verbose(True) + config_workspace.set_log_format(" %(message)s") + + config: ConfigBasic = config_workspace.get_config() + assert config.log_verbose is True + assert config.DATA_PATH == os.path.join(base_root, "data") + assert config.IMG_DIR is not None + assert config.NLTK_DATA_PATH == os.path.join(base_root, "data", "nltk_data") + assert config.LOG_FORMAT == " %(message)s" + assert config.LOG_PATH == os.path.join(base_root, "data", "logs") + assert config.MEDIA_PATH == os.path.join(base_root, "data", "media") + + assert os.path.exists(os.path.join(config.MEDIA_PATH, "image")) + assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio")) + assert os.path.exists(os.path.join(config.MEDIA_PATH, "video")) + config_workspace.clear() + + +def test_workspace_default(): + from chatchat.configs import (log_verbose, DATA_PATH, IMG_DIR, NLTK_DATA_PATH, LOG_FORMAT, LOG_PATH, MEDIA_PATH) + assert log_verbose is False + assert DATA_PATH is not None + assert IMG_DIR is not None + assert NLTK_DATA_PATH is not None + assert LOG_FORMAT is not None + assert LOG_PATH is not None + assert MEDIA_PATH is not None From 6019af4f331ead6f6202155b77d0abb4fc11b103 Mon Sep 17 00:00:00 2001 From: glide-the <2533736852@qq.com> Date: Sun, 9 Jun 2024 21:22:27 +0800 Subject: [PATCH 3/9] =?UTF-8?q?=E9=85=8D=E7=BD=AE=E8=B7=AF=E5=BE=84?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- libs/chatchat-server/chatchat/configs/_kb_config.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/libs/chatchat-server/chatchat/configs/_kb_config.py b/libs/chatchat-server/chatchat/configs/_kb_config.py index 9fed2715..40b4626b 100644 --- a/libs/chatchat-server/chatchat/configs/_kb_config.py +++ b/libs/chatchat-server/chatchat/configs/_kb_config.py @@ -1,11 +1,13 @@ import os from pathlib import Path -# chatchat 项目根目录 -CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent) +import sys +sys.path.append(str(Path(__file__).parent)) + +from _basic_config import config_workspace # 用户数据根目录 -DATA_PATH = os.path.join(CHATCHAT_ROOT, "data") +DATA_PATH = config_workspace.get_config().DATA_PATH # 默认使用的知识库 DEFAULT_KNOWLEDGE_BASE = "samples" From b110fcd01bdfc2ede931ba8e9d75c8db76ef7e87 Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Sun, 9 Jun 2024 22:51:48 +0800 Subject: [PATCH 4/9] fix faiss_cache bug --- .../chatchat/server/knowledge_base/kb_cache/faiss_cache.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/faiss_cache.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/faiss_cache.py index ec8adba2..36d21162 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/faiss_cache.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/faiss_cache.py @@ -110,7 +110,8 @@ class KBFaissPool(_FaissPool): if os.path.isfile(os.path.join(vs_path, "index.faiss")): embeddings = get_Embeddings(embed_model=embed_model) - vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True) + vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True, + allow_dangerous_deserialization=True) elif create: # create an empty vector store if not os.path.exists(vs_path): From b56283eb01699f03c228f466234e30e3d6550392 Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Mon, 10 Jun 2024 00:36:03 +0800 Subject: [PATCH 5/9] Feature(File RAG): add file_rag in chatchat-server, add ensemble retriever and vectorstore retriever. --- .../search_local_knowledgebase.py | 7 +-- .../chatchat/server/chat/file_chat.py | 6 +-- .../chatchat/server/file_rag/__init__.py | 0 .../document_loaders/FilteredCSVloader.py | 0 .../document_loaders/__init__.py | 0 .../document_loaders/mydocloader.py | 0 .../document_loaders/myimgloader.py | 2 +- .../document_loaders/mypdfloader.py | 2 +- .../document_loaders/mypptloader.py | 0 .../{ => file_rag}/document_loaders/ocr.py | 0 .../server/file_rag/retrievers/__init__.py | 3 ++ .../server/file_rag/retrievers/base.py | 24 ++++++++++ .../server/file_rag/retrievers/ensemble.py | 47 +++++++++++++++++++ .../server/file_rag/retrievers/vectorstore.py | 33 +++++++++++++ .../{ => file_rag}/text_splitter/__init__.py | 0 .../text_splitter/ali_text_splitter.py | 0 .../chinese_recursive_text_splitter.py | 0 .../text_splitter/chinese_text_splitter.py | 0 .../text_splitter/zh_title_enhance.py | 0 .../chatchat/server/file_rag/utils.py | 13 +++++ .../server/knowledge_base/kb_doc_api.py | 3 +- .../kb_service/chromadb_kb_service.py | 12 +++-- .../kb_service/es_kb_service.py | 9 +++- .../kb_service/faiss_kb_service.py | 13 +++-- .../kb_service/milvus_kb_service.py | 16 +++++-- .../kb_service/pg_kb_service.py | 12 +++-- .../kb_service/zilliz_kb_service.py | 15 +++--- .../chatchat/server/knowledge_base/utils.py | 2 +- .../chatchat/webui_pages/dialogue/dialogue.py | 24 +++++----- libs/chatchat-server/pyproject.toml | 2 + 30 files changed, 198 insertions(+), 47 deletions(-) create mode 100644 libs/chatchat-server/chatchat/server/file_rag/__init__.py rename libs/chatchat-server/chatchat/server/{ => file_rag}/document_loaders/FilteredCSVloader.py (100%) rename libs/chatchat-server/chatchat/server/{ => file_rag}/document_loaders/__init__.py (100%) rename libs/chatchat-server/chatchat/server/{ => file_rag}/document_loaders/mydocloader.py (100%) rename libs/chatchat-server/chatchat/server/{ => file_rag}/document_loaders/myimgloader.py (92%) rename libs/chatchat-server/chatchat/server/{ => file_rag}/document_loaders/mypdfloader.py (98%) rename libs/chatchat-server/chatchat/server/{ => file_rag}/document_loaders/mypptloader.py (100%) rename libs/chatchat-server/chatchat/server/{ => file_rag}/document_loaders/ocr.py (100%) create mode 100644 libs/chatchat-server/chatchat/server/file_rag/retrievers/__init__.py create mode 100644 libs/chatchat-server/chatchat/server/file_rag/retrievers/base.py create mode 100644 libs/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py create mode 100644 libs/chatchat-server/chatchat/server/file_rag/retrievers/vectorstore.py rename libs/chatchat-server/chatchat/server/{ => file_rag}/text_splitter/__init__.py (100%) rename libs/chatchat-server/chatchat/server/{ => file_rag}/text_splitter/ali_text_splitter.py (100%) rename libs/chatchat-server/chatchat/server/{ => file_rag}/text_splitter/chinese_recursive_text_splitter.py (100%) rename libs/chatchat-server/chatchat/server/{ => file_rag}/text_splitter/chinese_text_splitter.py (100%) rename libs/chatchat-server/chatchat/server/{ => file_rag}/text_splitter/zh_title_enhance.py (100%) create mode 100644 libs/chatchat-server/chatchat/server/file_rag/utils.py diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py index e7524f2a..5f8fde72 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py @@ -1,13 +1,14 @@ from urllib.parse import urlencode from chatchat.server.utils import get_tool_config from chatchat.server.pydantic_v1 import Field -from .tools_registry import regist_tool, BaseToolOutput +from chatchat.server.agent.tools_factory.tools_registry import regist_tool, BaseToolOutput from chatchat.server.knowledge_base.kb_api import list_kbs from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId from chatchat.configs import KB_INFO -template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool. The 'database' should be one of the above [{key}]." +template = ("Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on " + "this knowledge use this tool. The 'database' should be one of the above [{key}].") KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()]) template_knowledge = template.format(KB_info=KB_info_str, key="samples") @@ -49,7 +50,7 @@ def search_local_knowledgebase( database: str = Field(description="Database for Knowledge Search", choices=[kb.kb_name for kb in list_kbs().data]), query: str = Field(description="Query for Knowledge Search"), ): - '''''' + """""" tool_config = get_tool_config("search_local_knowledgebase") ret = search_knowledgebase(query=query, database=database, config=tool_config) return KBToolOutput(ret, database=database) diff --git a/libs/chatchat-server/chatchat/server/chat/file_chat.py b/libs/chatchat-server/chatchat/server/chat/file_chat.py index f2a8e67a..a0e67d0d 100644 --- a/libs/chatchat-server/chatchat/server/chat/file_chat.py +++ b/libs/chatchat-server/chatchat/server/chat/file_chat.py @@ -63,10 +63,10 @@ def upload_temp_docs( chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), ) -> BaseResponse: - ''' + """ 将文件保存到临时目录,并进行向量化。 返回临时目录名称作为ID,同时也是临时向量库的ID。 - ''' + """ if prev_id is not None: memo_faiss_pool.pop(prev_id) @@ -134,7 +134,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples= docs = [x[0] for x in docs] context = "\n".join([doc.page_content for doc in docs]) - if len(docs) == 0: ## 如果没有找到相关文档,使用Empty模板 + if len(docs) == 0: # 如果没有找到相关文档,使用Empty模板 prompt_template = get_prompt_template("knowledge_base_chat", "empty") else: prompt_template = get_prompt_template("knowledge_base_chat", prompt_name) diff --git a/libs/chatchat-server/chatchat/server/file_rag/__init__.py b/libs/chatchat-server/chatchat/server/file_rag/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/chatchat-server/chatchat/server/document_loaders/FilteredCSVloader.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/FilteredCSVloader.py similarity index 100% rename from libs/chatchat-server/chatchat/server/document_loaders/FilteredCSVloader.py rename to libs/chatchat-server/chatchat/server/file_rag/document_loaders/FilteredCSVloader.py diff --git a/libs/chatchat-server/chatchat/server/document_loaders/__init__.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/__init__.py similarity index 100% rename from libs/chatchat-server/chatchat/server/document_loaders/__init__.py rename to libs/chatchat-server/chatchat/server/file_rag/document_loaders/__init__.py diff --git a/libs/chatchat-server/chatchat/server/document_loaders/mydocloader.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mydocloader.py similarity index 100% rename from libs/chatchat-server/chatchat/server/document_loaders/mydocloader.py rename to libs/chatchat-server/chatchat/server/file_rag/document_loaders/mydocloader.py diff --git a/libs/chatchat-server/chatchat/server/document_loaders/myimgloader.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/myimgloader.py similarity index 92% rename from libs/chatchat-server/chatchat/server/document_loaders/myimgloader.py rename to libs/chatchat-server/chatchat/server/file_rag/document_loaders/myimgloader.py index 6b195cce..c6fda01e 100644 --- a/libs/chatchat-server/chatchat/server/document_loaders/myimgloader.py +++ b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/myimgloader.py @@ -1,6 +1,6 @@ from typing import List from langchain_community.document_loaders.unstructured import UnstructuredFileLoader -from chatchat.server.document_loaders.ocr import get_ocr +from chatchat.server.file_rag.document_loaders.ocr import get_ocr class RapidOCRLoader(UnstructuredFileLoader): diff --git a/libs/chatchat-server/chatchat/server/document_loaders/mypdfloader.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypdfloader.py similarity index 98% rename from libs/chatchat-server/chatchat/server/document_loaders/mypdfloader.py rename to libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypdfloader.py index 9e8796a4..c6a178f8 100644 --- a/libs/chatchat-server/chatchat/server/document_loaders/mypdfloader.py +++ b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypdfloader.py @@ -4,7 +4,7 @@ import cv2 from PIL import Image import numpy as np from chatchat.configs import PDF_OCR_THRESHOLD -from chatchat.server.document_loaders.ocr import get_ocr +from chatchat.server.file_rag.document_loaders.ocr import get_ocr import tqdm diff --git a/libs/chatchat-server/chatchat/server/document_loaders/mypptloader.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypptloader.py similarity index 100% rename from libs/chatchat-server/chatchat/server/document_loaders/mypptloader.py rename to libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypptloader.py diff --git a/libs/chatchat-server/chatchat/server/document_loaders/ocr.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/ocr.py similarity index 100% rename from libs/chatchat-server/chatchat/server/document_loaders/ocr.py rename to libs/chatchat-server/chatchat/server/file_rag/document_loaders/ocr.py diff --git a/libs/chatchat-server/chatchat/server/file_rag/retrievers/__init__.py b/libs/chatchat-server/chatchat/server/file_rag/retrievers/__init__.py new file mode 100644 index 00000000..2cf3617f --- /dev/null +++ b/libs/chatchat-server/chatchat/server/file_rag/retrievers/__init__.py @@ -0,0 +1,3 @@ +from chatchat.server.file_rag.retrievers.base import BaseRetrieverService +from chatchat.server.file_rag.retrievers.vectorstore import VectorstoreRetrieverService +from chatchat.server.file_rag.retrievers.ensemble import EnsembleRetrieverService \ No newline at end of file diff --git a/libs/chatchat-server/chatchat/server/file_rag/retrievers/base.py b/libs/chatchat-server/chatchat/server/file_rag/retrievers/base.py new file mode 100644 index 00000000..7e4d0646 --- /dev/null +++ b/libs/chatchat-server/chatchat/server/file_rag/retrievers/base.py @@ -0,0 +1,24 @@ +from langchain.vectorstores import VectorStore +from abc import ABCMeta, abstractmethod + + +class BaseRetrieverService(metaclass=ABCMeta): + def __init__(self, **kwargs): + self.do_init(**kwargs) + + @abstractmethod + def do_init(self, **kwargs): + pass + + + @abstractmethod + def from_vectorstore( + vectorstore: VectorStore, + top_k: int, + score_threshold: int or float, + ): + pass + + @abstractmethod + def get_relevant_documents(self, query: str): + pass diff --git a/libs/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py b/libs/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py new file mode 100644 index 00000000..cb09b633 --- /dev/null +++ b/libs/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py @@ -0,0 +1,47 @@ +from chatchat.server.file_rag.retrievers.base import BaseRetrieverService +from langchain.vectorstores import VectorStore +from langchain_core.retrievers import BaseRetriever +from langchain_community.retrievers import BM25Retriever +from langchain.retrievers import EnsembleRetriever + + +class EnsembleRetrieverService(BaseRetrieverService): + def do_init( + self, + retriever: BaseRetriever = None, + top_k: int = 5, + ): + self.vs = None + self.top_k = top_k + self.retriever = retriever + + + @staticmethod + def from_vectorstore( + vectorstore: VectorStore, + top_k: int, + score_threshold: int or float, + ): + faiss_retriever = vectorstore.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={ + "score_threshold": score_threshold, + "k": top_k + } + ) + # TODO: 换个不用torch的实现方式 + from cutword.cutword import Cutter + cutter = Cutter() + docs = list(vectorstore.docstore._dict.values()) + bm25_retriever = BM25Retriever.from_documents( + docs, + preprocess_func=cutter.cutword + ) + bm25_retriever.k = top_k + ensemble_retriever = EnsembleRetriever( + retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5] + ) + return EnsembleRetrieverService(retriever=ensemble_retriever) + + def get_relevant_documents(self, query: str): + return self.retriever.get_relevant_documents(query)[:self.top_k] diff --git a/libs/chatchat-server/chatchat/server/file_rag/retrievers/vectorstore.py b/libs/chatchat-server/chatchat/server/file_rag/retrievers/vectorstore.py new file mode 100644 index 00000000..b6d382fa --- /dev/null +++ b/libs/chatchat-server/chatchat/server/file_rag/retrievers/vectorstore.py @@ -0,0 +1,33 @@ +from chatchat.server.file_rag.retrievers.base import BaseRetrieverService +from langchain.vectorstores import VectorStore +from langchain_core.retrievers import BaseRetriever + + +class VectorstoreRetrieverService(BaseRetrieverService): + def do_init( + self, + retriever: BaseRetriever = None, + top_k: int = 5, + ): + self.vs = None + self.top_k = top_k + self.retriever = retriever + + + @staticmethod + def from_vectorstore( + vectorstore: VectorStore, + top_k: int, + score_threshold: int or float, + ): + retriever = vectorstore.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={ + "score_threshold": score_threshold, + "k": top_k + } + ) + return VectorstoreRetrieverService(retriever=retriever) + + def get_relevant_documents(self, query: str): + return self.retriever.get_relevant_documents(query)[:self.top_k] diff --git a/libs/chatchat-server/chatchat/server/text_splitter/__init__.py b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/__init__.py similarity index 100% rename from libs/chatchat-server/chatchat/server/text_splitter/__init__.py rename to libs/chatchat-server/chatchat/server/file_rag/text_splitter/__init__.py diff --git a/libs/chatchat-server/chatchat/server/text_splitter/ali_text_splitter.py b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/ali_text_splitter.py similarity index 100% rename from libs/chatchat-server/chatchat/server/text_splitter/ali_text_splitter.py rename to libs/chatchat-server/chatchat/server/file_rag/text_splitter/ali_text_splitter.py diff --git a/libs/chatchat-server/chatchat/server/text_splitter/chinese_recursive_text_splitter.py b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_recursive_text_splitter.py similarity index 100% rename from libs/chatchat-server/chatchat/server/text_splitter/chinese_recursive_text_splitter.py rename to libs/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_recursive_text_splitter.py diff --git a/libs/chatchat-server/chatchat/server/text_splitter/chinese_text_splitter.py b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_text_splitter.py similarity index 100% rename from libs/chatchat-server/chatchat/server/text_splitter/chinese_text_splitter.py rename to libs/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_text_splitter.py diff --git a/libs/chatchat-server/chatchat/server/text_splitter/zh_title_enhance.py b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/zh_title_enhance.py similarity index 100% rename from libs/chatchat-server/chatchat/server/text_splitter/zh_title_enhance.py rename to libs/chatchat-server/chatchat/server/file_rag/text_splitter/zh_title_enhance.py diff --git a/libs/chatchat-server/chatchat/server/file_rag/utils.py b/libs/chatchat-server/chatchat/server/file_rag/utils.py new file mode 100644 index 00000000..ddf64e3d --- /dev/null +++ b/libs/chatchat-server/chatchat/server/file_rag/utils.py @@ -0,0 +1,13 @@ +from chatchat.server.file_rag.retrievers import ( + BaseRetrieverService, + VectorstoreRetrieverService, + EnsembleRetrieverService, +) + +Retrivals = { + "vectorstore": VectorstoreRetrieverService, + "ensemble": EnsembleRetrieverService, +} + +def get_Retriever(type: str = "vectorstore") -> BaseRetrieverService: + return Retrivals[type] \ No newline at end of file diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py index 0e92c091..3e40786d 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py @@ -37,7 +37,8 @@ def search_docs( if kb is not None: if query: docs = kb.search_docs(query, top_k, score_threshold) - data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs] + # data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs] + data = [DocumentWithVSId(**x.dict(), id=x.metadata.get("id")) for x in docs] elif file_name or metadata: data = kb.list_docs(file_name=file_name, metadata=metadata) for d in data: diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/chromadb_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/chromadb_kb_service.py index 0834c87d..c6c46622 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/chromadb_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/chromadb_kb_service.py @@ -9,6 +9,7 @@ from chatchat.configs import SCORE_THRESHOLD from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path from chatchat.server.utils import get_Embeddings +from chatchat.server.file_rag.utils import get_Retriever def _get_result_to_documents(get_result: GetResult) -> List[Document]: @@ -75,10 +76,13 @@ class ChromaKBService(KBService): def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[ Tuple[Document, float]]: - embed_func = get_Embeddings(self.embed_model) - embeddings = embed_func.embed_query(query) - query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k) - return _results_to_docs_and_scores(query_result) + retriever = get_Retriever("vectorstore").from_vectorstore( + self.collection, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) + return docs def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: doc_infos = [] diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py index 19813bf1..aef63710 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py @@ -8,6 +8,7 @@ from chatchat.server.knowledge_base.utils import KnowledgeFile from chatchat.server.utils import get_Embeddings from elasticsearch import Elasticsearch, BadRequestError from chatchat.configs import kbs_config, KB_ROOT_PATH +from chatchat.server.file_rag.utils import get_Retriever import logging @@ -107,8 +108,12 @@ class ESKBService(KBService): def do_search(self, query:str, top_k: int, score_threshold: float): # 文本相似性检索 - docs = self.db.similarity_search_with_score(query=query, - k=top_k) + retriever = get_Retriever("vectorstore").from_vectorstore( + self.db, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) return docs def get_doc_by_ids(self, ids: List[str]) -> List[Document]: diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py index 95f7cd64..52738ae8 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py @@ -5,9 +5,9 @@ from chatchat.configs import SCORE_THRESHOLD from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType from chatchat.server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path -from chatchat.server.utils import get_Embeddings from langchain.docstore.document import Document -from typing import List, Dict, Optional, Tuple +from typing import List, Dict, Tuple +from chatchat.server.file_rag.utils import get_Retriever class FaissKBService(KBService): @@ -62,10 +62,13 @@ class FaissKBService(KBService): top_k: int, score_threshold: float = SCORE_THRESHOLD, ) -> List[Tuple[Document, float]]: - embed_func = get_Embeddings(self.embed_model) - embeddings = embed_func.embed_query(query) with self.load_vector_store().acquire() as vs: - docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold) + retriever = get_Retriever("ensemble").from_vectorstore( + vs, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) return docs def do_add_doc(self, diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py index ab0a77e9..8eddb5f4 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py @@ -10,7 +10,7 @@ from chatchat.server.db.repository import list_file_num_docs_id_by_kb_name_and_f from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType, \ score_threshold_process from chatchat.server.knowledge_base.utils import KnowledgeFile -from chatchat.server.utils import get_Embeddings +from chatchat.server.file_rag.utils import get_Retriever class MilvusKBService(KBService): @@ -67,10 +67,16 @@ class MilvusKBService(KBService): def do_search(self, query: str, top_k: int, score_threshold: float): self._load_milvus() - embed_func = get_Embeddings(self.embed_model) - embeddings = embed_func.embed_query(query) - docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k) - return score_threshold_process(score_threshold, top_k, docs) + # embed_func = get_Embeddings(self.embed_model) + # embeddings = embed_func.embed_query(query) + # docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k) + retriever = get_Retriever("vectorstore").from_vectorstore( + self.milvus, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) + return docs def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: for doc in docs: diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/pg_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/pg_kb_service.py index 8c3a0cf6..473c7f30 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/pg_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/pg_kb_service.py @@ -15,6 +15,7 @@ import shutil import sqlalchemy from sqlalchemy.engine.base import Engine from sqlalchemy.orm import Session +from chatchat.server.file_rag.utils import get_Retriever class PGKBService(KBService): @@ -60,10 +61,13 @@ class PGKBService(KBService): shutil.rmtree(self.kb_path) def do_search(self, query: str, top_k: int, score_threshold: float): - embed_func = get_Embeddings(self.embed_model) - embeddings = embed_func.embed_query(query) - docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k) - return score_threshold_process(score_threshold, top_k, docs) + retriever = get_Retriever("vectorstore").from_vectorstore( + self.pg_vector, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) + return docs def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: ids = self.pg_vector.add_documents(docs) diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/zilliz_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/zilliz_kb_service.py index 51e21b10..336eaa48 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/zilliz_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/zilliz_kb_service.py @@ -1,5 +1,4 @@ -from typing import List, Dict, Optional -from langchain.embeddings.base import Embeddings +from typing import List, Dict from langchain.schema import Document from langchain.vectorstores import Zilliz from chatchat.configs import kbs_config @@ -7,6 +6,7 @@ from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedV score_threshold_process from chatchat.server.knowledge_base.utils import KnowledgeFile from chatchat.server.utils import get_Embeddings +from chatchat.server.file_rag.utils import get_Retriever class ZillizKBService(KBService): @@ -60,10 +60,13 @@ class ZillizKBService(KBService): def do_search(self, query: str, top_k: int, score_threshold: float): self._load_zilliz() - embed_func = get_Embeddings(self.embed_model) - embeddings = embed_func.embed_query(query) - docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k) - return score_threshold_process(score_threshold, top_k, docs) + retriever = get_Retriever("vectorstore").from_vectorstore( + self.zilliz, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) + return docs def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: for doc in docs: diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/utils.py b/libs/chatchat-server/chatchat/server/knowledge_base/utils.py index c5dd442b..a03393dd 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/utils.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/utils.py @@ -10,7 +10,7 @@ from chatchat.configs import ( TEXT_SPLITTER_NAME, ) import importlib -from chatchat.server.text_splitter import zh_title_enhance as func_zh_title_enhance +from chatchat.server.file_rag.text_splitter import zh_title_enhance as func_zh_title_enhance import langchain_community.document_loaders from langchain.docstore.document import Document from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter diff --git a/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py b/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py index cdd40aa6..4ec3a80b 100644 --- a/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py +++ b/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py @@ -24,28 +24,28 @@ chat_box = ChatBox( def save_session(): - '''save session state to chat context''' + """save session state to chat context""" chat_box.context_from_session(exclude=["selected_page", "prompt", "cur_conv_name"]) def restore_session(): - '''restore sesstion state from chat context''' + """restore sesstion state from chat context""" chat_box.context_to_session(exclude=["selected_page", "prompt", "cur_conv_name"]) def rerun(): - ''' + """ save chat context before rerun - ''' + """ save_session() st.rerun() def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]: - ''' + """ 返回消息历史。 content_in_expander控制是否返回expander元素中的内容,一般导出的时候可以选上,传入LLM的history不需要 - ''' + """ def filter(msg): content = [x for x in msg["elements"] if x._output_method in ["markdown", "text"]] @@ -66,10 +66,10 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) -> @st.cache_data def upload_temp_docs(files, _api: ApiRequest) -> str: - ''' + """ 将文件上传到临时目录,用于文件对话 返回临时向量库ID - ''' + """ return _api.upload_temp_docs(files).get("data", {}).get("id") @@ -157,11 +157,13 @@ def dialogue_page( tools = list_tools(api) tool_names = ["None"] + list(tools) if use_agent: - # selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", check_all=True, key="selected_tools") + # selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", + # check_all=True, key="selected_tools") selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"], key="selected_tools") else: - # selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", key="selected_tool") + # selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", + # key="selected_tool") selected_tool = st.selectbox("选择工具", tool_names, format_func=lambda x: tools.get(x, {"title": "None"})["title"], key="selected_tool") @@ -338,7 +340,7 @@ def dialogue_page( elif d.status == AgentStatus.agent_finish: text = d.choices[0].delta.content or "" chat_box.update_msg(text.replace("\n", "\n\n")) - elif d.status == None: # not agent chat + elif d.status is None: # not agent chat if getattr(d, "is_ref", False): chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete", title="参考资料")) diff --git a/libs/chatchat-server/pyproject.toml b/libs/chatchat-server/pyproject.toml index e105db24..136b8b1b 100644 --- a/libs/chatchat-server/pyproject.toml +++ b/libs/chatchat-server/pyproject.toml @@ -29,6 +29,8 @@ unstructured = "~0.11.0" python-magic-bin = {version = "*", platform = "win32"} SQLAlchemy = "~2.0.25" faiss-cpu = "~1.7.4" +cutword = "0.1.0" +rank_bm25 = "0.2.2" # accelerate = "~0.24.1" # spacy = "~3.7.2" PyMuPDF = "~1.23.16" From 0abb5a1f9d29a70a533be6534f71e40cf6abf02b Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Mon, 10 Jun 2024 00:36:46 +0800 Subject: [PATCH 6/9] Feature(File RAG): add file_rag in chatchat-server, add ensemble retriever and vectorstore retriever. --- libs/chatchat-server/chatchat/configs/_model_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/chatchat-server/chatchat/configs/_model_config.py b/libs/chatchat-server/chatchat/configs/_model_config.py index 35625d56..afb0e8ea 100644 --- a/libs/chatchat-server/chatchat/configs/_model_config.py +++ b/libs/chatchat-server/chatchat/configs/_model_config.py @@ -149,7 +149,7 @@ TOOL_CONFIG = { "search_local_knowledgebase": { "use": False, "top_k": 3, - "score_threshold": 1, + "score_threshold": 1.0, "conclude_prompt": { "with_result": '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题",' From d71c9b0a27304fadfb4c69f308c61efdf2af2fb3 Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Mon, 10 Jun 2024 10:11:31 +0800 Subject: [PATCH 7/9] fix xinference manager bug --- tools/model_loaders/xinference_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/model_loaders/xinference_manager.py b/tools/model_loaders/xinference_manager.py index 650d7cc2..d628fcb0 100644 --- a/tools/model_loaders/xinference_manager.py +++ b/tools/model_loaders/xinference_manager.py @@ -137,7 +137,7 @@ model_format = None model_quant = None if model_type == "LLM": - cur_family = xf_llm.LLMFamilyV1.model_validate(cur_reg) + cur_family = xf_llm.LLMFamilyV1.parse_obj(cur_reg) cur_spec = None model_formats = [] for spec in cur_reg["model_specs"]: From 1987063a76a14ef5cfa745baaa51cb8f64c504ec Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Mon, 10 Jun 2024 16:33:13 +0800 Subject: [PATCH 8/9] Fix(File RAG): use jieba instead of cutword --- .../chatchat/server/file_rag/retrievers/ensemble.py | 7 ++++--- libs/chatchat-server/pyproject.toml | 3 ++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/libs/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py b/libs/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py index cb09b633..5d6b17a6 100644 --- a/libs/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py +++ b/libs/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py @@ -30,12 +30,13 @@ class EnsembleRetrieverService(BaseRetrieverService): } ) # TODO: 换个不用torch的实现方式 - from cutword.cutword import Cutter - cutter = Cutter() + # from cutword.cutword import Cutter + import jieba + # cutter = Cutter() docs = list(vectorstore.docstore._dict.values()) bm25_retriever = BM25Retriever.from_documents( docs, - preprocess_func=cutter.cutword + preprocess_func=jieba.lcut_for_search, ) bm25_retriever.k = top_k ensemble_retriever = EnsembleRetriever( diff --git a/libs/chatchat-server/pyproject.toml b/libs/chatchat-server/pyproject.toml index 136b8b1b..85646274 100644 --- a/libs/chatchat-server/pyproject.toml +++ b/libs/chatchat-server/pyproject.toml @@ -29,7 +29,8 @@ unstructured = "~0.11.0" python-magic-bin = {version = "*", platform = "win32"} SQLAlchemy = "~2.0.25" faiss-cpu = "~1.7.4" -cutword = "0.1.0" +#cutword = "0.1.0" +jieba = "0.42.1" rank_bm25 = "0.2.2" # accelerate = "~0.24.1" # spacy = "~3.7.2" From 360fc012663f89f7b313ba272a63fa64a85b1b5d Mon Sep 17 00:00:00 2001 From: imClumsyPanda Date: Mon, 10 Jun 2024 16:33:54 +0800 Subject: [PATCH 9/9] Fix(File RAG): update kb_doc_api.py --- .../chatchat/server/knowledge_base/kb_doc_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py index 3e40786d..e73627a1 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py @@ -28,7 +28,7 @@ def search_docs( description="知识库匹配相关度阈值,取值范围在0-1之间," "SCORE越小,相关度越高," "取到1相当于不筛选,建议设置在0.5左右", - ge=0, le=1), + ge=0.0, le=1.0), file_name: str = Body("", description="文件名称,支持 sql 通配符"), metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"), ) -> List[Dict]: