diff --git a/.env b/.env new file mode 100644 index 00000000..f78637c0 --- /dev/null +++ b/.env @@ -0,0 +1 @@ +PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring diff --git a/docs/contributing/code.md b/docs/contributing/code.md index 24599b52..c9d92bd7 100644 --- a/docs/contributing/code.md +++ b/docs/contributing/code.md @@ -15,9 +15,11 @@ Install Poetry: [documentation on how to install it.](https://python-poetry.org/docs/#installing-with-pipx) +> 友情提示 不想安装pipx可以用pip安装poetry,(Tips:如果你没有其它poetry的项目 > 注意: 如果您使用 Conda 或 Pyenv 作为您的环境/包管理器,在安装Poetry之后, > 使用如下命令使 Poetry 使用 virtualenv python environment (`poetry config virtualenvs.prefer-active-python true`) + #### 本地开发环境安装 - 选择主项目目录 diff --git a/libs/chatchat-server/.env b/libs/chatchat-server/.env new file mode 100644 index 00000000..f78637c0 --- /dev/null +++ b/libs/chatchat-server/.env @@ -0,0 +1 @@ +PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring diff --git a/libs/chatchat-server/README.md b/libs/chatchat-server/README.md index 4e0786c8..f2ce0940 100644 --- a/libs/chatchat-server/README.md +++ b/libs/chatchat-server/README.md @@ -29,6 +29,11 @@ vim model_providers.yaml > > 详细配置请参考[README.md](../model-providers/README.md) +- 初始化知识库 +```shell +chatchat-kb -r +``` + - 启动服务 ```shell chatchat -a diff --git a/libs/chatchat-server/chatchat/config_work_space.py b/libs/chatchat-server/chatchat/config_work_space.py new file mode 100644 index 00000000..6b3f2907 --- /dev/null +++ b/libs/chatchat-server/chatchat/config_work_space.py @@ -0,0 +1,47 @@ +from chatchat.configs import config_workspace as workspace + + +def main(): + import argparse + + parser = argparse.ArgumentParser(description="指令` chatchat-config` 工作空间配置") + # 只能选择true或false + parser.add_argument( + "-v", + "--verbose", + choices=["true", "false"], + help="是否开启详细日志" + ) + parser.add_argument( + "-d", + "--data", + help="数据存放路径" + ) + parser.add_argument( + "-f", + "--format", + help="日志格式" + ) + parser.add_argument( + "--clear", + action="store_true", + help="清除配置" + ) + args = parser.parse_args() + + if args.verbose: + if args.verbose.lower() == "true": + workspace.set_log_verbose(True) + else: + workspace.set_log_verbose(False) + if args.data: + workspace.set_data_path(args.data) + if args.format: + workspace.set_log_format(args.format) + if args.clear: + workspace.clear() + print(workspace.get_config()) + + +if __name__ == "__main__": + main() diff --git a/libs/chatchat-server/chatchat/configs/__init__.py b/libs/chatchat-server/chatchat/configs/__init__.py index bca795be..c1191b08 100644 --- a/libs/chatchat-server/chatchat/configs/__init__.py +++ b/libs/chatchat-server/chatchat/configs/__init__.py @@ -1,8 +1,10 @@ import importlib import importlib.util import os +from pathlib import Path from typing import Dict, Any +import json import logging logger = logging.getLogger() @@ -38,7 +40,6 @@ def _import_config_mod_load(import_config_mod: str) -> Dict: ) user_import = False if user_import: - # Dynamic loading {config}.py file py_path = os.path.join(user_config_path, import_config_mod + ".py") spec = importlib.util.spec_from_file_location( @@ -69,7 +70,7 @@ def _import_config_mod_load(import_config_mod: str) -> Dict: ) raise RuntimeError(f"Failed to load user config from {user_config_path}") # 当前文件路径 - py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py") + py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py") spec = importlib.util.spec_from_file_location(f"*", py_path) @@ -95,75 +96,108 @@ CONFIG_IMPORTS = { } +def _import_ConfigBasic() -> Any: + basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") + load_mod = basic_config_load.get("load_mod") + ConfigBasic = load_mod(basic_config_load.get("module"), "ConfigBasic") + + return ConfigBasic + + +def _import_ConfigBasicFactory() -> Any: + basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") + load_mod = basic_config_load.get("load_mod") + ConfigBasicFactory = load_mod(basic_config_load.get("module"), "ConfigBasicFactory") + + return ConfigBasicFactory + + +def _import_ConfigWorkSpace() -> Any: + basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") + load_mod = basic_config_load.get("load_mod") + ConfigWorkSpace = load_mod(basic_config_load.get("module"), "ConfigWorkSpace") + + return ConfigWorkSpace + + +def _import_config_workspace() -> Any: + basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") + load_mod = basic_config_load.get("load_mod") + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + return config_workspace + def _import_log_verbose() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - log_verbose = load_mod(basic_config_load.get("module"), "log_verbose") - - return log_verbose + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + return config_workspace.get_config().log_verbose def _import_chatchat_root() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - CHATCHAT_ROOT = load_mod(basic_config_load.get("module"), "CHATCHAT_ROOT") + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") - return CHATCHAT_ROOT + return config_workspace.get_config().CHATCHAT_ROOT def _import_data_path() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - DATA_PATH = load_mod(basic_config_load.get("module"), "DATA_PATH") - return DATA_PATH + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + return config_workspace.get_config().DATA_PATH + def _import_img_dir() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - IMG_DIR = load_mod(basic_config_load.get("module"), "IMG_DIR") - return IMG_DIR + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + + return config_workspace.get_config().IMG_DIR def _import_nltk_data_path() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - NLTK_DATA_PATH = load_mod(basic_config_load.get("module"), "NLTK_DATA_PATH") + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") - return NLTK_DATA_PATH + return config_workspace.get_config().NLTK_DATA_PATH def _import_log_format() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - LOG_FORMAT = load_mod(basic_config_load.get("module"), "LOG_FORMAT") - return LOG_FORMAT + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + + return config_workspace.get_config().LOG_FORMAT def _import_log_path() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - LOG_PATH = load_mod(basic_config_load.get("module"), "LOG_PATH") - return LOG_PATH + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + + return config_workspace.get_config().LOG_PATH def _import_media_path() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - MEDIA_PATH = load_mod(basic_config_load.get("module"), "MEDIA_PATH") - return MEDIA_PATH + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + + return config_workspace.get_config().MEDIA_PATH def _import_base_temp_dir() -> Any: basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") load_mod = basic_config_load.get("load_mod") - BASE_TEMP_DIR = load_mod(basic_config_load.get("module"), "BASE_TEMP_DIR") - - return BASE_TEMP_DIR + config_workspace = load_mod(basic_config_load.get("module"), "config_workspace") + return config_workspace.get_config().BASE_TEMP_DIR def _import_default_knowledge_base() -> Any: @@ -285,6 +319,7 @@ def _import_db_root_path() -> Any: return DB_ROOT_PATH + def _import_sqlalchemy_database_uri() -> Any: kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") load_mod = kb_config_load.get("load_mod") @@ -478,7 +513,15 @@ def _import_api_server() -> Any: def __getattr__(name: str) -> Any: - if name == "log_verbose": + if name == "ConfigBasic": + return _import_ConfigBasic() + elif name == "ConfigBasicFactory": + return _import_ConfigBasicFactory() + elif name == "ConfigWorkSpace": + return _import_ConfigWorkSpace() + elif name == "config_workspace": + return _import_config_workspace() + elif name == "log_verbose": return _import_log_verbose() elif name == "CHATCHAT_ROOT": return _import_chatchat_root() @@ -578,6 +621,7 @@ VERSION = "v0.3.0-preview" __all__ = [ "VERSION", + "config_workspace", "log_verbose", "CHATCHAT_ROOT", "DATA_PATH", @@ -626,5 +670,8 @@ __all__ = [ "WEBUI_SERVER", "API_SERVER", + "ConfigBasic", + "ConfigBasicFactory", + "ConfigWorkSpace", ] diff --git a/libs/chatchat-server/chatchat/configs/_basic_config.py b/libs/chatchat-server/chatchat/configs/_basic_config.py index 46eec939..61108bde 100644 --- a/libs/chatchat-server/chatchat/configs/_basic_config.py +++ b/libs/chatchat-server/chatchat/configs/_basic_config.py @@ -1,55 +1,180 @@ -import logging import os +import json from pathlib import Path +import logging -import langchain - - -# 是否显示详细日志 -log_verbose = False -langchain.verbose = False - -# 通常情况下不需要更改以下内容 - -# chatchat 项目根目录 -CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent) - -# 用户数据根目录 -DATA_PATH = os.path.join(CHATCHAT_ROOT, "data") -if not os.path.exists(DATA_PATH): - os.mkdir(DATA_PATH) - -# 项目相关图片 -IMG_DIR = os.path.join(CHATCHAT_ROOT, "img") -if not os.path.exists(IMG_DIR): - os.mkdir(IMG_DIR) - -# nltk 模型存储路径 -NLTK_DATA_PATH = os.path.join(DATA_PATH, "nltk_data") -import nltk -nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path - -# 日志格式 -LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" logger = logging.getLogger() -logger.setLevel(logging.INFO) -logging.basicConfig(format=LOG_FORMAT) -# 日志存储路径 -LOG_PATH = os.path.join(DATA_PATH, "logs") -if not os.path.exists(LOG_PATH): - os.mkdir(LOG_PATH) +class ConfigBasic: + log_verbose: bool + """是否开启日志详细信息""" + CHATCHAT_ROOT: str + """项目根目录""" + DATA_PATH: str + """用户数据根目录""" + IMG_DIR: str + """项目相关图片""" + NLTK_DATA_PATH: str + """nltk 模型存储路径""" + LOG_FORMAT: str + """日志格式""" + LOG_PATH: str + """日志存储路径""" + MEDIA_PATH: str + """模型生成内容(图片、视频、音频等)保存位置""" + BASE_TEMP_DIR: str + """临时文件目录,主要用于文件对话""" -# 模型生成内容(图片、视频、音频等)保存位置 -MEDIA_PATH = os.path.join(DATA_PATH, "media") -if not os.path.exists(MEDIA_PATH): - os.mkdir(MEDIA_PATH) - os.mkdir(os.path.join(MEDIA_PATH, "image")) - os.mkdir(os.path.join(MEDIA_PATH, "audio")) - os.mkdir(os.path.join(MEDIA_PATH, "video")) + def __str__(self): + return f"ConfigBasic(log_verbose={self.log_verbose}, CHATCHAT_ROOT={self.CHATCHAT_ROOT}, DATA_PATH={self.DATA_PATH}, IMG_DIR={self.IMG_DIR}, NLTK_DATA_PATH={self.NLTK_DATA_PATH}, LOG_FORMAT={self.LOG_FORMAT}, LOG_PATH={self.LOG_PATH}, MEDIA_PATH={self.MEDIA_PATH}, BASE_TEMP_DIR={self.BASE_TEMP_DIR})" -# 临时文件目录,主要用于文件对话 -BASE_TEMP_DIR = os.path.join(DATA_PATH, "temp") -if not os.path.exists(BASE_TEMP_DIR): - os.mkdir(BASE_TEMP_DIR) + +class ConfigBasicFactory: + """Basic config for ChatChat """ + + def __init__(self): + # 日志格式 + self.LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" + logging.basicConfig(format=self.LOG_FORMAT) + self.LOG_VERBOSE = False + self.CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent) + # 用户数据根目录 + self.DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data") + self._DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data") + if not os.path.exists(self._DATA_PATH): + os.makedirs(self._DATA_PATH, exist_ok=True) + + self._init_data_dir() + + # 项目相关图片 + self.IMG_DIR = os.path.join(self.CHATCHAT_ROOT, "img") + if not os.path.exists(self.IMG_DIR): + os.makedirs(self.IMG_DIR, exist_ok=True) + + def log_verbose(self, verbose: bool): + self.LOG_VERBOSE = verbose + + def data_path(self, path: str): + self.DATA_PATH = path + if not os.path.exists(self.DATA_PATH): + os.makedirs(self.DATA_PATH, exist_ok=True) + # 复制_DATA_PATH数据到DATA_PATH + if self._DATA_PATH != self.DATA_PATH: + os.system(f"cp -r {self._DATA_PATH}/* {self.DATA_PATH}") + + self._init_data_dir() + + def log_format(self, log_format: str): + self.LOG_FORMAT = log_format + logging.basicConfig(format=self.LOG_FORMAT) + + def _init_data_dir(self): + logger.info(f"Init data dir: {self.DATA_PATH}") + # nltk 模型存储路径 + self.NLTK_DATA_PATH = os.path.join(self.DATA_PATH, "nltk_data") + import nltk + nltk.data.path = [self.NLTK_DATA_PATH] + nltk.data.path + # 日志存储路径 + self.LOG_PATH = os.path.join(self.DATA_PATH, "logs") + if not os.path.exists(self.LOG_PATH): + os.makedirs(self.LOG_PATH, exist_ok=True) + + # 模型生成内容(图片、视频、音频等)保存位置 + self.MEDIA_PATH = os.path.join(self.DATA_PATH, "media") + if not os.path.exists(self.MEDIA_PATH): + os.makedirs(self.MEDIA_PATH, exist_ok=True) + os.makedirs(os.path.join(self.MEDIA_PATH, "image"), exist_ok=True) + os.makedirs(os.path.join(self.MEDIA_PATH, "audio"), exist_ok=True) + os.makedirs(os.path.join(self.MEDIA_PATH, "video"), exist_ok=True) + + # 临时文件目录,主要用于文件对话 + self.BASE_TEMP_DIR = os.path.join(self.DATA_PATH, "temp") + if not os.path.exists(self.BASE_TEMP_DIR): + os.makedirs(self.BASE_TEMP_DIR, exist_ok=True) + + logger.info(f"Init data dir: {self.DATA_PATH} success.") + + def get_config(self) -> ConfigBasic: + config = ConfigBasic() + config.log_verbose = self.LOG_VERBOSE + config.CHATCHAT_ROOT = self.CHATCHAT_ROOT + config.DATA_PATH = self.DATA_PATH + config.IMG_DIR = self.IMG_DIR + config.NLTK_DATA_PATH = self.NLTK_DATA_PATH + config.LOG_FORMAT = self.LOG_FORMAT + config.LOG_PATH = self.LOG_PATH + config.MEDIA_PATH = self.MEDIA_PATH + config.BASE_TEMP_DIR = self.BASE_TEMP_DIR + return config + + +class ConfigWorkSpace: + """ + 工作空间的配置预设,提供ConfigBasic建造方法产生实例。 + 该类的实例对象用于存储工作空间的配置信息,如工作空间的路径等 + 工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中。 + 注意:不存在则读取默认 + """ + _config_factory: ConfigBasicFactory = ConfigBasicFactory() + + def __init__(self): + self.workspace = os.path.join(os.path.expanduser("~"), ".config", "chatchat/workspace") + if not os.path.exists(self.workspace): + os.makedirs(self.workspace, exist_ok=True) + self.workspace_config = os.path.join(self.workspace, "workspace_config.json") + # 初始化工作空间配置,转换成json格式,实现ConfigBasic的实例化 + + config_json = self._load_config() + + if config_json: + + _config_factory = ConfigBasicFactory() + if config_json.get("log_verbose"): + _config_factory.log_verbose(config_json.get("log_verbose")) + if config_json.get("DATA_PATH"): + _config_factory.data_path(config_json.get("DATA_PATH")) + if config_json.get("LOG_FORMAT"): + _config_factory.log_format(config_json.get("LOG_FORMAT")) + + self._config_factory = _config_factory + + def get_config(self) -> ConfigBasic: + return self._config_factory.get_config() + + def set_log_verbose(self, verbose: bool): + self._config_factory.log_verbose(verbose) + self._store_config() + + def set_data_path(self, path: str): + self._config_factory.data_path(path) + self._store_config() + + def set_log_format(self, log_format: str): + self._config_factory.log_format(log_format) + self._store_config() + + def clear(self): + logger.info("Clear workspace config.") + os.remove(self.workspace_config) + + def _load_config(self): + try: + with open(self.workspace_config, "r") as f: + return json.loads(f.read()) + except FileNotFoundError: + return None + + def _store_config(self): + with open(self.workspace_config, "w") as f: + config = self._config_factory.get_config() + config_json = { + "log_verbose": config.log_verbose, + "CHATCHAT_ROOT": config.CHATCHAT_ROOT, + "DATA_PATH": config.DATA_PATH, + "LOG_FORMAT": config.LOG_FORMAT + } + f.write(json.dumps(config_json, indent=4, ensure_ascii=False)) + + +config_workspace: ConfigWorkSpace = ConfigWorkSpace() diff --git a/libs/chatchat-server/chatchat/configs/_kb_config.py b/libs/chatchat-server/chatchat/configs/_kb_config.py index 9fed2715..40b4626b 100644 --- a/libs/chatchat-server/chatchat/configs/_kb_config.py +++ b/libs/chatchat-server/chatchat/configs/_kb_config.py @@ -1,11 +1,13 @@ import os from pathlib import Path -# chatchat 项目根目录 -CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent) +import sys +sys.path.append(str(Path(__file__).parent)) + +from _basic_config import config_workspace # 用户数据根目录 -DATA_PATH = os.path.join(CHATCHAT_ROOT, "data") +DATA_PATH = config_workspace.get_config().DATA_PATH # 默认使用的知识库 DEFAULT_KNOWLEDGE_BASE = "samples" diff --git a/libs/chatchat-server/chatchat/configs/_model_config.py b/libs/chatchat-server/chatchat/configs/_model_config.py index 856981e9..afb0e8ea 100644 --- a/libs/chatchat-server/chatchat/configs/_model_config.py +++ b/libs/chatchat-server/chatchat/configs/_model_config.py @@ -118,6 +118,25 @@ MODEL_PLATFORMS = [ "tts_models": [], }, + { + "platform_name": "xinference", + "platform_type": "xinference", + "api_base_url": "http://127.0.0.1:9997/v1", + "api_key": "EMPTY", + "api_concurrencies": 5, + "llm_models": [ + "glm-4", + "qwen2-instruct", + "qwen1.5-chat", + ], + "embed_models": [ + "bge-large-zh-v1.5", + ], + "image_models": [], + "reranking_models": [], + "speech2text_models": [], + "tts_models": [], + }, ] @@ -130,7 +149,7 @@ TOOL_CONFIG = { "search_local_knowledgebase": { "use": False, "top_k": 3, - "score_threshold": 1, + "score_threshold": 1.0, "conclude_prompt": { "with_result": '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题",' @@ -208,5 +227,34 @@ TOOL_CONFIG = { "text2images": { "use": False, }, - + # text2sql使用建议 + # 1、因大模型生成的sql可能与预期有偏差,请务必在测试环境中进行充分测试、评估; + # 2、生产环境中,对于查询操作,由于不确定查询效率,推荐数据库采用主从数据库架构,让text2sql连接从数据库,防止可能的慢查询影响主业务; + # 3、对于写操作应保持谨慎,如不需要写操作,设置read_only为True,最好再从数据库层面收回数据库用户的写权限,防止用户通过自然语言对数据库进行修改操作; + # 4、text2sql与大模型在意图理解、sql转换等方面的能力有关,可切换不同大模型进行测试; + # 5、数据库表名、字段名应与其实际作用保持一致、容易理解,且应对数据库表名、字段进行详细的备注说明,帮助大模型更好理解数据库结构; + # 6、若现有数据库表名难于让大模型理解,可配置下面table_comments字段,补充说明某些表的作用。 + "text2sql": { + "use": False, + # SQLAlchemy连接字符串,支持的数据库有: + # crate、duckdb、googlesql、mssql、mysql、mariadb、oracle、postgresql、sqlite、clickhouse、prestodb + # 不同的数据库请查询SQLAlchemy,修改sqlalchemy_connect_str,配置对应的数据库连接,如sqlite为sqlite:///数据库文件路径,下面示例为mysql + # 如提示缺少对应数据库的驱动,请自行通过poetry安装 + "sqlalchemy_connect_str": "mysql+pymysql://用户名:密码@主机地址/数据库名称e", + # 务必评估是否需要开启read_only,开启后会对sql语句进行检查,请确认text2sql.py中的intercept_sql拦截器是否满足你使用的数据库只读要求 + # 优先推荐从数据库层面对用户权限进行限制 + "read_only": False, + #限定返回的行数 + "top_k":50, + #是否返回中间步骤 + "return_intermediate_steps": True, + #如果想指定特定表,请填写表名称,如["sys_user","sys_dept"],不填写走智能判断应该使用哪些表 + "table_names":[], + #对表名进行额外说明,辅助大模型更好的判断应该使用哪些表,尤其是SQLDatabaseSequentialChain模式下,是根据表名做的预测,很容易误判。 + "table_comments":{ + # 如果出现大模型选错表的情况,可尝试根据实际情况填写表名和说明 + # "tableA":"这是一个用户表,存储了用户的基本信息", + # "tanleB":"角色表", + } + }, } diff --git a/libs/chatchat-server/chatchat/configs/model_providers.yaml b/libs/chatchat-server/chatchat/configs/model_providers.yaml index b47142fc..4032a21e 100644 --- a/libs/chatchat-server/chatchat/configs/model_providers.yaml +++ b/libs/chatchat-server/chatchat/configs/model_providers.yaml @@ -20,16 +20,16 @@ xinference: model_credential: - - model: 'chatglm3-6b' + - model: 'glm-4' model_type: 'llm' model_credentials: server_url: 'http://127.0.0.1:9997/' - model_uid: 'chatglm3-6b' - - model: 'Qwen1.5-14B-Chat' + model_uid: 'glm-4' + - model: 'qwen1.5-chat' model_type: 'llm' model_credentials: server_url: 'http://127.0.0.1:9997/' - model_uid: 'Qwen1.5-14B-Chat' + model_uid: 'qwen1.5-chat' - model: 'bge-large-zh-v1.5' model_type: 'embeddings' model_credentials: @@ -46,4 +46,4 @@ xinference: # model_type: 'llm' # model_credentials: # base_url: 'http://172.21.192.1:11434' -# mode: 'completion' \ No newline at end of file +# mode: 'completion' diff --git a/libs/chatchat-server/chatchat/init_database.py b/libs/chatchat-server/chatchat/init_database.py index 7a1baec7..d08ec87a 100644 --- a/libs/chatchat-server/chatchat/init_database.py +++ b/libs/chatchat-server/chatchat/init_database.py @@ -1,13 +1,12 @@ # Description: 初始化数据库,包括创建表、导入数据、更新向量空间等操作 +from datetime import datetime +import multiprocessing as mp from typing import Dict + from chatchat.server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db, folder2db, prune_db_docs, prune_folder_files) -from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS -import multiprocessing as mp -import logging -logger = logging.getLogger(__name__) +from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS, logger -from datetime import datetime def run_init_model_provider( @@ -34,7 +33,7 @@ def run_init_model_provider( provider_port=provider_port) -if __name__ == "__main__": +def main(): import argparse parser = argparse.ArgumentParser(description="please specify only one operate method once time.") @@ -186,3 +185,7 @@ if __name__ == "__main__": for p in processes.values(): logger.info("Process status: %s", p) + + +if __name__ == "__main__": + main() diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/__init__.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/__init__.py index 8faaedb7..2242434f 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/__init__.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/__init__.py @@ -10,3 +10,4 @@ from .text2image import text2images from .vqa_processor import vqa_processor from .aqa_processor import aqa_processor +from .text2sql import text2sql \ No newline at end of file diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py index e7524f2a..5f8fde72 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py @@ -1,13 +1,14 @@ from urllib.parse import urlencode from chatchat.server.utils import get_tool_config from chatchat.server.pydantic_v1 import Field -from .tools_registry import regist_tool, BaseToolOutput +from chatchat.server.agent.tools_factory.tools_registry import regist_tool, BaseToolOutput from chatchat.server.knowledge_base.kb_api import list_kbs from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId from chatchat.configs import KB_INFO -template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool. The 'database' should be one of the above [{key}]." +template = ("Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on " + "this knowledge use this tool. The 'database' should be one of the above [{key}].") KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()]) template_knowledge = template.format(KB_info=KB_info_str, key="samples") @@ -49,7 +50,7 @@ def search_local_knowledgebase( database: str = Field(description="Database for Knowledge Search", choices=[kb.kb_name for kb in list_kbs().data]), query: str = Field(description="Query for Knowledge Search"), ): - '''''' + """""" tool_config = get_tool_config("search_local_knowledgebase") ret = search_knowledgebase(query=query, database=database, config=tool_config) return KBToolOutput(ret, database=database) diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/text2sql.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/text2sql.py new file mode 100644 index 00000000..ef8ec878 --- /dev/null +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/text2sql.py @@ -0,0 +1,103 @@ +from langchain.utilities import SQLDatabase +from langchain_experimental.sql import SQLDatabaseChain,SQLDatabaseSequentialChain +from chatchat.server.utils import get_tool_config +from chatchat.server.pydantic_v1 import Field +from .tools_registry import regist_tool, BaseToolOutput +from sqlalchemy.exc import OperationalError +from sqlalchemy import event +from langchain_core.prompts.prompt import PromptTemplate +from langchain.chains import LLMChain + +READ_ONLY_PROMPT_TEMPLATE="""You are a MySQL expert. The database is currently in read-only mode. +Given an input question, determine if the related SQL can be executed in read-only mode. +If the SQL can be executed normally, return Answer:'SQL can be executed normally'. +If the SQL cannot be executed normally, return Answer: 'SQL cannot be executed normally'. +Use the following format: + +Answer: Final answer here + +Question: {query} +""" + +# 定义一个拦截器函数来检查SQL语句,以支持read-only,可修改下面的write_operations,以匹配你使用的数据库写操作关键字 +def intercept_sql(conn, cursor, statement, parameters, context, executemany): + # List of SQL keywords that indicate a write operation + write_operations = ("insert", "update", "delete", "create", "drop", "alter", "truncate", "rename") + # Check if the statement starts with any of the write operation keywords + if any(statement.strip().lower().startswith(op) for op in write_operations): + raise OperationalError("Database is read-only. Write operations are not allowed.", params=None, orig=None) + +def query_database(query: str, + config: dict): + top_k = config["top_k"] + return_intermediate_steps = config["return_intermediate_steps"] + sqlalchemy_connect_str = config["sqlalchemy_connect_str"] + read_only = config["read_only"] + db = SQLDatabase.from_uri(sqlalchemy_connect_str) + + from chatchat.server.api_server.chat_routes import global_model_name + from chatchat.server.utils import get_ChatOpenAI + llm = get_ChatOpenAI( + model_name=global_model_name, + temperature=0, + streaming=True, + local_wrap=True, + verbose=True + ) + table_names=config["table_names"] + table_comments=config["table_comments"] + result = None + + #如果发现大模型判断用什么表出现问题,尝试给langchain提供额外的表说明,辅助大模型更好的判断应该使用哪些表,尤其是SQLDatabaseSequentialChain模式下,是根据表名做的预测,很容易误判 + #由于langchain固定了输入参数,所以只能通过query传递额外的表说明 + if table_comments: + TABLE_COMMNET_PROMPT="\n\nI will provide some special notes for a few tables:\n\n" + table_comments_str="\n".join([f"{k}:{v}" for k,v in table_comments.items()]) + query=query+TABLE_COMMNET_PROMPT+table_comments_str+"\n\n" + + if read_only: + # 在read_only下,先让大模型判断只读模式是否能满足需求,避免后续执行过程报错,返回友好提示。 + READ_ONLY_PROMPT = PromptTemplate( + input_variables=["query"], + template=READ_ONLY_PROMPT_TEMPLATE, + ) + read_only_chain = LLMChain( + prompt=READ_ONLY_PROMPT, + llm=llm, + ) + read_only_result = read_only_chain.invoke(query) + if "SQL cannot be executed normally" in read_only_result["text"]: + return "当前数据库为只读状态,无法满足您的需求!" + + # 当然大模型不能保证完全判断准确,为防止大模型判断有误,再从拦截器层面拒绝写操作 + event.listen(db._engine, "before_cursor_execute", intercept_sql) + + #如果不指定table_names,优先走SQLDatabaseSequentialChain,这个链会先预测需要哪些表,然后再将相关表输入SQLDatabaseChain + #这是因为如果不指定table_names,直接走SQLDatabaseChain,Langchain会将全量表结构传递给大模型,可能会因token太长从而引发错误,也浪费资源 + #如果指定了table_names,直接走SQLDatabaseChain,将特定表结构传递给大模型进行判断 + if len(table_names) > 0: + db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True,top_k=top_k,return_intermediate_steps=return_intermediate_steps) + result = db_chain.invoke({"query":query,"table_names_to_use":table_names}) + else: + #先预测会使用哪些表,然后再将问题和预测的表给大模型 + db_chain = SQLDatabaseSequentialChain.from_llm(llm, db, verbose=True,top_k=top_k,return_intermediate_steps=return_intermediate_steps) + result = db_chain.invoke(query) + + context = f"""查询结果:{result['result']}\n\n""" + + intermediate_steps=result["intermediate_steps"] + #如果存在intermediate_steps,且这个数组的长度大于2,则保留最后两个元素,因为前面几个步骤存在示例数据,容易引起误解 + if intermediate_steps: + if len(intermediate_steps)>2: + sql_detail=intermediate_steps[-2:-1][0]["input"] + # sql_detail截取从SQLQuery到Answer:之间的内容 + sql_detail=sql_detail[sql_detail.find("SQLQuery:")+9:sql_detail.find("Answer:")] + context = context+"执行的sql:'"+sql_detail+"'\n\n" + return context + + +@regist_tool(title="Text2Sql") +def text2sql(query: str = Field(description="No need for SQL statements,just input the natural language that you want to chat with database")): + '''Use this tool to chat with database,Input natural language, then it will convert it into SQL and execute it in the database, then return the execution result.''' + tool_config = get_tool_config("text2sql") + return BaseToolOutput(query_database(query=query, config=tool_config)) diff --git a/libs/chatchat-server/chatchat/server/api_server/chat_routes.py b/libs/chatchat-server/chatchat/server/api_server/chat_routes.py index f935ab9c..eb2c74ca 100644 --- a/libs/chatchat-server/chatchat/server/api_server/chat_routes.py +++ b/libs/chatchat-server/chatchat/server/api_server/chat_routes.py @@ -28,6 +28,8 @@ chat_router.post("/file_chat", summary="文件对话" )(file_chat) +#定义全局model信息,用于给Text2Sql中的get_ChatOpenAI提供model_name +global_model_name=None @chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口") async def chat_completions( @@ -51,6 +53,8 @@ async def chat_completions( for key in list(extra): delattr(body, key) + global global_model_name + global_model_name=body.model # check tools & tool_choice in request body if isinstance(body.tool_choice, str): if t := get_tool(body.tool_choice): diff --git a/libs/chatchat-server/chatchat/server/chat/file_chat.py b/libs/chatchat-server/chatchat/server/chat/file_chat.py index f2a8e67a..a0e67d0d 100644 --- a/libs/chatchat-server/chatchat/server/chat/file_chat.py +++ b/libs/chatchat-server/chatchat/server/chat/file_chat.py @@ -63,10 +63,10 @@ def upload_temp_docs( chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), ) -> BaseResponse: - ''' + """ 将文件保存到临时目录,并进行向量化。 返回临时目录名称作为ID,同时也是临时向量库的ID。 - ''' + """ if prev_id is not None: memo_faiss_pool.pop(prev_id) @@ -134,7 +134,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples= docs = [x[0] for x in docs] context = "\n".join([doc.page_content for doc in docs]) - if len(docs) == 0: ## 如果没有找到相关文档,使用Empty模板 + if len(docs) == 0: # 如果没有找到相关文档,使用Empty模板 prompt_template = get_prompt_template("knowledge_base_chat", "empty") else: prompt_template = get_prompt_template("knowledge_base_chat", prompt_name) diff --git a/libs/chatchat-server/chatchat/server/file_rag/__init__.py b/libs/chatchat-server/chatchat/server/file_rag/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/chatchat-server/chatchat/server/document_loaders/FilteredCSVloader.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/FilteredCSVloader.py similarity index 100% rename from libs/chatchat-server/chatchat/server/document_loaders/FilteredCSVloader.py rename to libs/chatchat-server/chatchat/server/file_rag/document_loaders/FilteredCSVloader.py diff --git a/libs/chatchat-server/chatchat/server/document_loaders/__init__.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/__init__.py similarity index 100% rename from libs/chatchat-server/chatchat/server/document_loaders/__init__.py rename to libs/chatchat-server/chatchat/server/file_rag/document_loaders/__init__.py diff --git a/libs/chatchat-server/chatchat/server/document_loaders/mydocloader.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mydocloader.py similarity index 100% rename from libs/chatchat-server/chatchat/server/document_loaders/mydocloader.py rename to libs/chatchat-server/chatchat/server/file_rag/document_loaders/mydocloader.py diff --git a/libs/chatchat-server/chatchat/server/document_loaders/myimgloader.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/myimgloader.py similarity index 92% rename from libs/chatchat-server/chatchat/server/document_loaders/myimgloader.py rename to libs/chatchat-server/chatchat/server/file_rag/document_loaders/myimgloader.py index 6b195cce..c6fda01e 100644 --- a/libs/chatchat-server/chatchat/server/document_loaders/myimgloader.py +++ b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/myimgloader.py @@ -1,6 +1,6 @@ from typing import List from langchain_community.document_loaders.unstructured import UnstructuredFileLoader -from chatchat.server.document_loaders.ocr import get_ocr +from chatchat.server.file_rag.document_loaders.ocr import get_ocr class RapidOCRLoader(UnstructuredFileLoader): diff --git a/libs/chatchat-server/chatchat/server/document_loaders/mypdfloader.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypdfloader.py similarity index 98% rename from libs/chatchat-server/chatchat/server/document_loaders/mypdfloader.py rename to libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypdfloader.py index 9e8796a4..c6a178f8 100644 --- a/libs/chatchat-server/chatchat/server/document_loaders/mypdfloader.py +++ b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypdfloader.py @@ -4,7 +4,7 @@ import cv2 from PIL import Image import numpy as np from chatchat.configs import PDF_OCR_THRESHOLD -from chatchat.server.document_loaders.ocr import get_ocr +from chatchat.server.file_rag.document_loaders.ocr import get_ocr import tqdm diff --git a/libs/chatchat-server/chatchat/server/document_loaders/mypptloader.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypptloader.py similarity index 100% rename from libs/chatchat-server/chatchat/server/document_loaders/mypptloader.py rename to libs/chatchat-server/chatchat/server/file_rag/document_loaders/mypptloader.py diff --git a/libs/chatchat-server/chatchat/server/document_loaders/ocr.py b/libs/chatchat-server/chatchat/server/file_rag/document_loaders/ocr.py similarity index 100% rename from libs/chatchat-server/chatchat/server/document_loaders/ocr.py rename to libs/chatchat-server/chatchat/server/file_rag/document_loaders/ocr.py diff --git a/libs/chatchat-server/chatchat/server/file_rag/retrievers/__init__.py b/libs/chatchat-server/chatchat/server/file_rag/retrievers/__init__.py new file mode 100644 index 00000000..2cf3617f --- /dev/null +++ b/libs/chatchat-server/chatchat/server/file_rag/retrievers/__init__.py @@ -0,0 +1,3 @@ +from chatchat.server.file_rag.retrievers.base import BaseRetrieverService +from chatchat.server.file_rag.retrievers.vectorstore import VectorstoreRetrieverService +from chatchat.server.file_rag.retrievers.ensemble import EnsembleRetrieverService \ No newline at end of file diff --git a/libs/chatchat-server/chatchat/server/file_rag/retrievers/base.py b/libs/chatchat-server/chatchat/server/file_rag/retrievers/base.py new file mode 100644 index 00000000..7e4d0646 --- /dev/null +++ b/libs/chatchat-server/chatchat/server/file_rag/retrievers/base.py @@ -0,0 +1,24 @@ +from langchain.vectorstores import VectorStore +from abc import ABCMeta, abstractmethod + + +class BaseRetrieverService(metaclass=ABCMeta): + def __init__(self, **kwargs): + self.do_init(**kwargs) + + @abstractmethod + def do_init(self, **kwargs): + pass + + + @abstractmethod + def from_vectorstore( + vectorstore: VectorStore, + top_k: int, + score_threshold: int or float, + ): + pass + + @abstractmethod + def get_relevant_documents(self, query: str): + pass diff --git a/libs/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py b/libs/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py new file mode 100644 index 00000000..cb09b633 --- /dev/null +++ b/libs/chatchat-server/chatchat/server/file_rag/retrievers/ensemble.py @@ -0,0 +1,47 @@ +from chatchat.server.file_rag.retrievers.base import BaseRetrieverService +from langchain.vectorstores import VectorStore +from langchain_core.retrievers import BaseRetriever +from langchain_community.retrievers import BM25Retriever +from langchain.retrievers import EnsembleRetriever + + +class EnsembleRetrieverService(BaseRetrieverService): + def do_init( + self, + retriever: BaseRetriever = None, + top_k: int = 5, + ): + self.vs = None + self.top_k = top_k + self.retriever = retriever + + + @staticmethod + def from_vectorstore( + vectorstore: VectorStore, + top_k: int, + score_threshold: int or float, + ): + faiss_retriever = vectorstore.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={ + "score_threshold": score_threshold, + "k": top_k + } + ) + # TODO: 换个不用torch的实现方式 + from cutword.cutword import Cutter + cutter = Cutter() + docs = list(vectorstore.docstore._dict.values()) + bm25_retriever = BM25Retriever.from_documents( + docs, + preprocess_func=cutter.cutword + ) + bm25_retriever.k = top_k + ensemble_retriever = EnsembleRetriever( + retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5] + ) + return EnsembleRetrieverService(retriever=ensemble_retriever) + + def get_relevant_documents(self, query: str): + return self.retriever.get_relevant_documents(query)[:self.top_k] diff --git a/libs/chatchat-server/chatchat/server/file_rag/retrievers/vectorstore.py b/libs/chatchat-server/chatchat/server/file_rag/retrievers/vectorstore.py new file mode 100644 index 00000000..b6d382fa --- /dev/null +++ b/libs/chatchat-server/chatchat/server/file_rag/retrievers/vectorstore.py @@ -0,0 +1,33 @@ +from chatchat.server.file_rag.retrievers.base import BaseRetrieverService +from langchain.vectorstores import VectorStore +from langchain_core.retrievers import BaseRetriever + + +class VectorstoreRetrieverService(BaseRetrieverService): + def do_init( + self, + retriever: BaseRetriever = None, + top_k: int = 5, + ): + self.vs = None + self.top_k = top_k + self.retriever = retriever + + + @staticmethod + def from_vectorstore( + vectorstore: VectorStore, + top_k: int, + score_threshold: int or float, + ): + retriever = vectorstore.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={ + "score_threshold": score_threshold, + "k": top_k + } + ) + return VectorstoreRetrieverService(retriever=retriever) + + def get_relevant_documents(self, query: str): + return self.retriever.get_relevant_documents(query)[:self.top_k] diff --git a/libs/chatchat-server/chatchat/server/text_splitter/__init__.py b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/__init__.py similarity index 100% rename from libs/chatchat-server/chatchat/server/text_splitter/__init__.py rename to libs/chatchat-server/chatchat/server/file_rag/text_splitter/__init__.py diff --git a/libs/chatchat-server/chatchat/server/text_splitter/ali_text_splitter.py b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/ali_text_splitter.py similarity index 100% rename from libs/chatchat-server/chatchat/server/text_splitter/ali_text_splitter.py rename to libs/chatchat-server/chatchat/server/file_rag/text_splitter/ali_text_splitter.py diff --git a/libs/chatchat-server/chatchat/server/text_splitter/chinese_recursive_text_splitter.py b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_recursive_text_splitter.py similarity index 100% rename from libs/chatchat-server/chatchat/server/text_splitter/chinese_recursive_text_splitter.py rename to libs/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_recursive_text_splitter.py diff --git a/libs/chatchat-server/chatchat/server/text_splitter/chinese_text_splitter.py b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_text_splitter.py similarity index 100% rename from libs/chatchat-server/chatchat/server/text_splitter/chinese_text_splitter.py rename to libs/chatchat-server/chatchat/server/file_rag/text_splitter/chinese_text_splitter.py diff --git a/libs/chatchat-server/chatchat/server/text_splitter/zh_title_enhance.py b/libs/chatchat-server/chatchat/server/file_rag/text_splitter/zh_title_enhance.py similarity index 100% rename from libs/chatchat-server/chatchat/server/text_splitter/zh_title_enhance.py rename to libs/chatchat-server/chatchat/server/file_rag/text_splitter/zh_title_enhance.py diff --git a/libs/chatchat-server/chatchat/server/file_rag/utils.py b/libs/chatchat-server/chatchat/server/file_rag/utils.py new file mode 100644 index 00000000..ddf64e3d --- /dev/null +++ b/libs/chatchat-server/chatchat/server/file_rag/utils.py @@ -0,0 +1,13 @@ +from chatchat.server.file_rag.retrievers import ( + BaseRetrieverService, + VectorstoreRetrieverService, + EnsembleRetrieverService, +) + +Retrivals = { + "vectorstore": VectorstoreRetrieverService, + "ensemble": EnsembleRetrieverService, +} + +def get_Retriever(type: str = "vectorstore") -> BaseRetrieverService: + return Retrivals[type] \ No newline at end of file diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_api.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_api.py index ea3d6089..0ccc0704 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_api.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_api.py @@ -14,6 +14,7 @@ def list_kbs(): def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), vector_store_type: str = Body("faiss"), + kb_info: str = Body("", description="知识库内容简介,用于Agent选择知识库。"), embed_model: str = Body(DEFAULT_EMBEDDING_MODEL), ) -> BaseResponse: # Create selected knowledge base @@ -26,7 +27,7 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), if kb is not None: return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}") - kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model) + kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model, kb_info=kb_info) try: kb.create_kb() except Exception as e: diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/base.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/base.py index a6d3f425..a11e0054 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/base.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/base.py @@ -62,7 +62,7 @@ class CachePool: self._cache_num = cache_num self._cache = OrderedDict() self.atomic = threading.RLock() - + def keys(self) -> List[str]: return list(self._cache.keys()) diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/faiss_cache.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/faiss_cache.py index d37bd29d..36d21162 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/faiss_cache.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/faiss_cache.py @@ -95,6 +95,7 @@ class KBFaissPool(_FaissPool): embed_model: str = DEFAULT_EMBEDDING_MODEL, ) -> ThreadSafeFaiss: self.atomic.acquire() + locked = True vector_name = vector_name or embed_model cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些 try: @@ -103,12 +104,14 @@ class KBFaissPool(_FaissPool): self.set((kb_name, vector_name), item) with item.acquire(msg="初始化"): self.atomic.release() + locked = False logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.") vs_path = get_vs_path(kb_name, vector_name) if os.path.isfile(os.path.join(vs_path, "index.faiss")): embeddings = get_Embeddings(embed_model=embed_model) - vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True) + vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True, + allow_dangerous_deserialization=True) elif create: # create an empty vector store if not os.path.exists(vs_path): @@ -121,8 +124,10 @@ class KBFaissPool(_FaissPool): item.finish_loading() else: self.atomic.release() + locked = False except Exception as e: - self.atomic.release() + if locked: # we don't know exception raised before or after atomic.release + self.atomic.release() logger.error(e) raise RuntimeError(f"向量库 {kb_name} 加载失败。") return self.get((kb_name, vector_name)) diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py index 2763028f..3e40786d 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py @@ -1,21 +1,23 @@ +import json import os import urllib +from typing import List, Dict + from fastapi import File, Form, Body, Query, UploadFile +from fastapi.responses import FileResponse +from langchain.docstore.document import Document +from sse_starlette import EventSourceResponse + from chatchat.configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE, logger, log_verbose, ) -from chatchat.server.utils import BaseResponse, ListResponse, run_in_thread_pool +from chatchat.server.db.repository.knowledge_file_repository import get_file_detail from chatchat.server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path, files2docs_in_thread, KnowledgeFile) -from fastapi.responses import FileResponse -from sse_starlette import EventSourceResponse -import json -from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory -from chatchat.server.db.repository.knowledge_file_repository import get_file_detail -from langchain.docstore.document import Document +from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory, get_kb_file_details from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId -from typing import List, Dict +from chatchat.server.utils import BaseResponse, ListResponse, run_in_thread_pool, check_embed_model def search_docs( @@ -35,7 +37,8 @@ def search_docs( if kb is not None: if query: docs = kb.search_docs(query, top_k, score_threshold) - data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs] + # data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs] + data = [DocumentWithVSId(**x.dict(), id=x.metadata.get("id")) for x in docs] elif file_name or metadata: data = kb.list_docs(file_name=file_name, metadata=metadata) for d in data: @@ -55,8 +58,8 @@ def list_files( if kb is None: return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[]) else: - all_doc_names = kb.list_files() - return ListResponse(data=all_doc_names) + all_docs = get_kb_file_details(knowledge_base_name) + return ListResponse(data=all_docs) def _save_files_in_thread(files: List[UploadFile], @@ -352,38 +355,42 @@ def recreate_vector_store( if not kb.exists() and not allow_empty_kb: yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"} else: - if kb.exists(): - kb.clear_vs() - kb.create_kb() - files = list_files_from_folder(knowledge_base_name) - kb_files = [(file, knowledge_base_name) for file in files] - i = 0 - for status, result in files2docs_in_thread(kb_files, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - zh_title_enhance=zh_title_enhance): - if status: - kb_name, file_name, docs = result - kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name) - kb_file.splited_docs = docs - yield json.dumps({ - "code": 200, - "msg": f"({i + 1} / {len(files)}): {file_name}", - "total": len(files), - "finished": i + 1, - "doc": file_name, - }, ensure_ascii=False) - kb.add_doc(kb_file, not_refresh_vs_cache=True) - else: - kb_name, file_name, error = result - msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。" - logger.error(msg) - yield json.dumps({ - "code": 500, - "msg": msg, - }) - i += 1 - if not not_refresh_vs_cache: - kb.save_vector_store() + error_msg = f"could not recreate vector store because failed to access embed model." + if not kb.check_embed_model(error_msg): + yield {"code": 404, "msg": error_msg} + else: + if kb.exists(): + kb.clear_vs() + kb.create_kb() + files = list_files_from_folder(knowledge_base_name) + kb_files = [(file, knowledge_base_name) for file in files] + i = 0 + for status, result in files2docs_in_thread(kb_files, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + zh_title_enhance=zh_title_enhance): + if status: + kb_name, file_name, docs = result + kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name) + kb_file.splited_docs = docs + yield json.dumps({ + "code": 200, + "msg": f"({i + 1} / {len(files)}): {file_name}", + "total": len(files), + "finished": i + 1, + "doc": file_name, + }, ensure_ascii=False) + kb.add_doc(kb_file, not_refresh_vs_cache=True) + else: + kb_name, file_name, error = result + msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。" + logger.error(msg) + yield json.dumps({ + "code": 500, + "msg": msg, + }) + i += 1 + if not not_refresh_vs_cache: + kb.save_vector_store() return EventSourceResponse(output()) diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py index 7cdbf102..ac21863e 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py @@ -5,6 +5,11 @@ import os from pathlib import Path from langchain.docstore.document import Document +from typing import List, Union, Dict, Optional, Tuple + +from chatchat.configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, + DEFAULT_EMBEDDING_MODEL, KB_INFO, logger) +from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema from chatchat.server.db.repository.knowledge_base_repository import ( add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists, load_kb_from_db, get_kb_detail, @@ -14,18 +19,12 @@ from chatchat.server.db.repository.knowledge_file_repository import ( count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db, list_docs_from_db, ) - -from chatchat.configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, - DEFAULT_EMBEDDING_MODEL, KB_INFO) from chatchat.server.knowledge_base.utils import ( get_kb_path, get_doc_path, KnowledgeFile, list_kbs_from_folder, list_files_from_folder, ) - -from typing import List, Union, Dict, Optional, Tuple - from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId -from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema +from chatchat.server.utils import check_embed_model as _check_embed_model class SupportedVSType: FAISS = 'faiss' @@ -41,10 +40,11 @@ class KBService(ABC): def __init__(self, knowledge_base_name: str, + kb_info: str = None, embed_model: str = DEFAULT_EMBEDDING_MODEL, ): self.kb_name = knowledge_base_name - self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库") + self.kb_info = kb_info or KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库") self.embed_model = embed_model self.kb_path = get_kb_path(self.kb_name) self.doc_path = get_doc_path(self.kb_name) @@ -59,6 +59,13 @@ class KBService(ABC): ''' pass + def check_embed_model(self, error_msg: str) -> bool: + if not _check_embed_model(self.embed_model): + logger.error(error_msg, exc_info=True) + return False + else: + return True + def create_kb(self): """ 创建知识库 @@ -93,6 +100,9 @@ class KBService(ABC): 向知识库添加文件 如果指定了docs,则不再将文本向量化,并将数据库对应条目标为custom_docs=True """ + if not self.check_embed_model(f"could not add docs because failed to access embed model."): + return False + if docs: custom_docs = True else: @@ -143,6 +153,9 @@ class KBService(ABC): 使用content中的文件更新向量库 如果指定了docs,则使用自定义docs,并将数据库对应条目标为custom_docs=True """ + if not self.check_embed_model(f"could not update docs because failed to access embed model."): + return False + if os.path.exists(kb_file.filepath): self.delete_doc(kb_file, **kwargs) return self.add_doc(kb_file, docs=docs, **kwargs) @@ -162,6 +175,8 @@ class KBService(ABC): top_k: int = VECTOR_SEARCH_TOP_K, score_threshold: float = SCORE_THRESHOLD, ) ->List[Document]: + if not self.check_embed_model(f"could not search docs because failed to access embed model."): + return [] docs = self.do_search(query, top_k, score_threshold) return docs @@ -176,6 +191,9 @@ class KBService(ABC): 传入参数为: {doc_id: Document, ...} 如果对应 doc_id 的值为 None,或其 page_content 为空,则删除该文档 ''' + if not self.check_embed_model(f"could not update docs because failed to access embed model."): + return False + self.del_doc_by_ids(list(docs.keys())) docs = [] ids = [] @@ -282,31 +300,32 @@ class KBServiceFactory: def get_service(kb_name: str, vector_store_type: Union[str, SupportedVSType], embed_model: str = DEFAULT_EMBEDDING_MODEL, + kb_info: str = None, ) -> KBService: if isinstance(vector_store_type, str): vector_store_type = getattr(SupportedVSType, vector_store_type.upper()) + params = {"knowledge_base_name": kb_name, "embed_model": embed_model, "kb_info": kb_info} if SupportedVSType.FAISS == vector_store_type: from chatchat.server.knowledge_base.kb_service.faiss_kb_service import FaissKBService - return FaissKBService(kb_name, embed_model=embed_model) + return FaissKBService(**params) elif SupportedVSType.PG == vector_store_type: from chatchat.server.knowledge_base.kb_service.pg_kb_service import PGKBService - return PGKBService(kb_name, embed_model=embed_model) + return PGKBService(**params) elif SupportedVSType.MILVUS == vector_store_type: from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService - return MilvusKBService(kb_name, embed_model=embed_model) + return MilvusKBService(**params) elif SupportedVSType.ZILLIZ == vector_store_type: from chatchat.server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService - return ZillizKBService(kb_name, embed_model=embed_model) + return ZillizKBService(**params) elif SupportedVSType.DEFAULT == vector_store_type: from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService - return MilvusKBService(kb_name, - embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config + return MilvusKBService(**params) # other milvus parameters are set in model_config.kbs_config elif SupportedVSType.ES == vector_store_type: from chatchat.server.knowledge_base.kb_service.es_kb_service import ESKBService - return ESKBService(kb_name, embed_model=embed_model) + return ESKBService(**params) elif SupportedVSType.CHROMADB == vector_store_type: from chatchat.server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService - return ChromaKBService(kb_name, embed_model=embed_model) + return ChromaKBService(**params) elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier. from chatchat.server.knowledge_base.kb_service.default_kb_service import DefaultKBService return DefaultKBService(kb_name) diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/chromadb_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/chromadb_kb_service.py index 0834c87d..c6c46622 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/chromadb_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/chromadb_kb_service.py @@ -9,6 +9,7 @@ from chatchat.configs import SCORE_THRESHOLD from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path from chatchat.server.utils import get_Embeddings +from chatchat.server.file_rag.utils import get_Retriever def _get_result_to_documents(get_result: GetResult) -> List[Document]: @@ -75,10 +76,13 @@ class ChromaKBService(KBService): def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[ Tuple[Document, float]]: - embed_func = get_Embeddings(self.embed_model) - embeddings = embed_func.embed_query(query) - query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k) - return _results_to_docs_and_scores(query_result) + retriever = get_Retriever("vectorstore").from_vectorstore( + self.collection, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) + return docs def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: doc_infos = [] diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py index 19813bf1..aef63710 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/es_kb_service.py @@ -8,6 +8,7 @@ from chatchat.server.knowledge_base.utils import KnowledgeFile from chatchat.server.utils import get_Embeddings from elasticsearch import Elasticsearch, BadRequestError from chatchat.configs import kbs_config, KB_ROOT_PATH +from chatchat.server.file_rag.utils import get_Retriever import logging @@ -107,8 +108,12 @@ class ESKBService(KBService): def do_search(self, query:str, top_k: int, score_threshold: float): # 文本相似性检索 - docs = self.db.similarity_search_with_score(query=query, - k=top_k) + retriever = get_Retriever("vectorstore").from_vectorstore( + self.db, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) return docs def get_doc_by_ids(self, ids: List[str]) -> List[Document]: diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py index 95f7cd64..52738ae8 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/faiss_kb_service.py @@ -5,9 +5,9 @@ from chatchat.configs import SCORE_THRESHOLD from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType from chatchat.server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path -from chatchat.server.utils import get_Embeddings from langchain.docstore.document import Document -from typing import List, Dict, Optional, Tuple +from typing import List, Dict, Tuple +from chatchat.server.file_rag.utils import get_Retriever class FaissKBService(KBService): @@ -62,10 +62,13 @@ class FaissKBService(KBService): top_k: int, score_threshold: float = SCORE_THRESHOLD, ) -> List[Tuple[Document, float]]: - embed_func = get_Embeddings(self.embed_model) - embeddings = embed_func.embed_query(query) with self.load_vector_store().acquire() as vs: - docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold) + retriever = get_Retriever("ensemble").from_vectorstore( + vs, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) return docs def do_add_doc(self, diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py index ab0a77e9..8eddb5f4 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py @@ -10,7 +10,7 @@ from chatchat.server.db.repository import list_file_num_docs_id_by_kb_name_and_f from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType, \ score_threshold_process from chatchat.server.knowledge_base.utils import KnowledgeFile -from chatchat.server.utils import get_Embeddings +from chatchat.server.file_rag.utils import get_Retriever class MilvusKBService(KBService): @@ -67,10 +67,16 @@ class MilvusKBService(KBService): def do_search(self, query: str, top_k: int, score_threshold: float): self._load_milvus() - embed_func = get_Embeddings(self.embed_model) - embeddings = embed_func.embed_query(query) - docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k) - return score_threshold_process(score_threshold, top_k, docs) + # embed_func = get_Embeddings(self.embed_model) + # embeddings = embed_func.embed_query(query) + # docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k) + retriever = get_Retriever("vectorstore").from_vectorstore( + self.milvus, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) + return docs def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: for doc in docs: diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/pg_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/pg_kb_service.py index 8c3a0cf6..473c7f30 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/pg_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/pg_kb_service.py @@ -15,6 +15,7 @@ import shutil import sqlalchemy from sqlalchemy.engine.base import Engine from sqlalchemy.orm import Session +from chatchat.server.file_rag.utils import get_Retriever class PGKBService(KBService): @@ -60,10 +61,13 @@ class PGKBService(KBService): shutil.rmtree(self.kb_path) def do_search(self, query: str, top_k: int, score_threshold: float): - embed_func = get_Embeddings(self.embed_model) - embeddings = embed_func.embed_query(query) - docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k) - return score_threshold_process(score_threshold, top_k, docs) + retriever = get_Retriever("vectorstore").from_vectorstore( + self.pg_vector, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) + return docs def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: ids = self.pg_vector.add_documents(docs) diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/zilliz_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/zilliz_kb_service.py index 51e21b10..336eaa48 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/zilliz_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/zilliz_kb_service.py @@ -1,5 +1,4 @@ -from typing import List, Dict, Optional -from langchain.embeddings.base import Embeddings +from typing import List, Dict from langchain.schema import Document from langchain.vectorstores import Zilliz from chatchat.configs import kbs_config @@ -7,6 +6,7 @@ from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedV score_threshold_process from chatchat.server.knowledge_base.utils import KnowledgeFile from chatchat.server.utils import get_Embeddings +from chatchat.server.file_rag.utils import get_Retriever class ZillizKBService(KBService): @@ -60,10 +60,13 @@ class ZillizKBService(KBService): def do_search(self, query: str, top_k: int, score_threshold: float): self._load_zilliz() - embed_func = get_Embeddings(self.embed_model) - embeddings = embed_func.embed_query(query) - docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k) - return score_threshold_process(score_threshold, top_k, docs) + retriever = get_Retriever("vectorstore").from_vectorstore( + self.zilliz, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) + return docs def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: for doc in docs: diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_summary_api.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_summary_api.py index f1b83fa8..fd18fe17 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_summary_api.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_summary_api.py @@ -42,55 +42,59 @@ def recreate_summary_vector_store( if not kb.exists() and not allow_empty_kb: yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"} else: - # 重新创建知识库 - kb_summary = KBSummaryService(knowledge_base_name, embed_model) - kb_summary.drop_kb_summary() - kb_summary.create_kb_summary() + error_msg = f"could not recreate summary vector store because failed to access embed model." + if not kb.check_embed_model(error_msg): + yield {"code": 404, "msg": error_msg} + else: + # 重新创建知识库 + kb_summary = KBSummaryService(knowledge_base_name, embed_model) + kb_summary.drop_kb_summary() + kb_summary.create_kb_summary() - llm = get_ChatOpenAI( - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - local_wrap=True, - ) - reduce_llm = get_ChatOpenAI( - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - local_wrap=True, - ) - # 文本摘要适配器 - summary = SummaryAdapter.form_summary(llm=llm, - reduce_llm=reduce_llm, - overlap_size=OVERLAP_SIZE) - files = list_files_from_folder(knowledge_base_name) + llm = get_ChatOpenAI( + model_name=model_name, + temperature=temperature, + max_tokens=max_tokens, + local_wrap=True, + ) + reduce_llm = get_ChatOpenAI( + model_name=model_name, + temperature=temperature, + max_tokens=max_tokens, + local_wrap=True, + ) + # 文本摘要适配器 + summary = SummaryAdapter.form_summary(llm=llm, + reduce_llm=reduce_llm, + overlap_size=OVERLAP_SIZE) + files = list_files_from_folder(knowledge_base_name) - i = 0 - for i, file_name in enumerate(files): + i = 0 + for i, file_name in enumerate(files): - doc_infos = kb.list_docs(file_name=file_name) - docs = summary.summarize(file_description=file_description, - docs=doc_infos) + doc_infos = kb.list_docs(file_name=file_name) + docs = summary.summarize(file_description=file_description, + docs=doc_infos) - status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs) - if status_kb_summary: - logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成") - yield json.dumps({ - "code": 200, - "msg": f"({i + 1} / {len(files)}): {file_name}", - "total": len(files), - "finished": i + 1, - "doc": file_name, - }, ensure_ascii=False) - else: + status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs) + if status_kb_summary: + logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成") + yield json.dumps({ + "code": 200, + "msg": f"({i + 1} / {len(files)}): {file_name}", + "total": len(files), + "finished": i + 1, + "doc": file_name, + }, ensure_ascii=False) + else: - msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。" - logger.error(msg) - yield json.dumps({ - "code": 500, - "msg": msg, - }) - i += 1 + msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。" + logger.error(msg) + yield json.dumps({ + "code": 500, + "msg": msg, + }) + i += 1 return EventSourceResponse(output()) diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/migrate.py b/libs/chatchat-server/chatchat/server/knowledge_base/migrate.py index f4d35c40..01811691 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/migrate.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/migrate.py @@ -1,6 +1,6 @@ from chatchat.configs import ( DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE, - CHUNK_SIZE, OVERLAP_SIZE + CHUNK_SIZE, OVERLAP_SIZE, logger, log_verbose ) from chatchat.server.knowledge_base.utils import ( get_file_path, list_kbs_from_folder, diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/utils.py b/libs/chatchat-server/chatchat/server/knowledge_base/utils.py index c5dd442b..a03393dd 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/utils.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/utils.py @@ -10,7 +10,7 @@ from chatchat.configs import ( TEXT_SPLITTER_NAME, ) import importlib -from chatchat.server.text_splitter import zh_title_enhance as func_zh_title_enhance +from chatchat.server.file_rag.text_splitter import zh_title_enhance as func_zh_title_enhance import langchain_community.document_loaders from langchain.docstore.document import Document from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter diff --git a/libs/chatchat-server/chatchat/server/utils.py b/libs/chatchat-server/chatchat/server/utils.py index 314ef0cc..fdf1a9ab 100644 --- a/libs/chatchat-server/chatchat/server/utils.py +++ b/libs/chatchat-server/chatchat/server/utils.py @@ -244,6 +244,16 @@ def get_Embeddings( logger.error(f"failed to create Embeddings for model: {embed_model}.", exc_info=True) +def check_embed_model(embed_model: str=DEFAULT_EMBEDDING_MODEL) -> bool: + embeddings = get_Embeddings(embed_model=embed_model) + try: + embeddings.embed_query("this is a test") + return True + except Exception as e: + logger.error(f"failed to access embed model '{embed_model}': {e}", exc_info=True) + return False + + def get_OpenAIClient( platform_name: str = None, model_name: str = None, @@ -696,7 +706,7 @@ def get_temp_dir(id: str = None) -> Tuple[str, str]: ''' 创建一个临时目录,返回(路径,文件夹名称) ''' - from chatchat.configs.basic_config import BASE_TEMP_DIR + from chatchat.configs import BASE_TEMP_DIR import uuid if id is not None: # 如果指定的临时目录已存在,直接返回 diff --git a/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py b/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py index cdd40aa6..4ec3a80b 100644 --- a/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py +++ b/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py @@ -24,28 +24,28 @@ chat_box = ChatBox( def save_session(): - '''save session state to chat context''' + """save session state to chat context""" chat_box.context_from_session(exclude=["selected_page", "prompt", "cur_conv_name"]) def restore_session(): - '''restore sesstion state from chat context''' + """restore sesstion state from chat context""" chat_box.context_to_session(exclude=["selected_page", "prompt", "cur_conv_name"]) def rerun(): - ''' + """ save chat context before rerun - ''' + """ save_session() st.rerun() def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]: - ''' + """ 返回消息历史。 content_in_expander控制是否返回expander元素中的内容,一般导出的时候可以选上,传入LLM的history不需要 - ''' + """ def filter(msg): content = [x for x in msg["elements"] if x._output_method in ["markdown", "text"]] @@ -66,10 +66,10 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) -> @st.cache_data def upload_temp_docs(files, _api: ApiRequest) -> str: - ''' + """ 将文件上传到临时目录,用于文件对话 返回临时向量库ID - ''' + """ return _api.upload_temp_docs(files).get("data", {}).get("id") @@ -157,11 +157,13 @@ def dialogue_page( tools = list_tools(api) tool_names = ["None"] + list(tools) if use_agent: - # selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", check_all=True, key="selected_tools") + # selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", + # check_all=True, key="selected_tools") selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"], key="selected_tools") else: - # selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", key="selected_tool") + # selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", + # key="selected_tool") selected_tool = st.selectbox("选择工具", tool_names, format_func=lambda x: tools.get(x, {"title": "None"})["title"], key="selected_tool") @@ -338,7 +340,7 @@ def dialogue_page( elif d.status == AgentStatus.agent_finish: text = d.choices[0].delta.content or "" chat_box.update_msg(text.replace("\n", "\n\n")) - elif d.status == None: # not agent chat + elif d.status is None: # not agent chat if getattr(d, "is_ref", False): chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete", title="参考资料")) diff --git a/libs/chatchat-server/pyproject.toml b/libs/chatchat-server/pyproject.toml index cf980bf5..279ef64f 100644 --- a/libs/chatchat-server/pyproject.toml +++ b/libs/chatchat-server/pyproject.toml @@ -10,16 +10,18 @@ packages = [ [tool.poetry.scripts] chatchat = 'chatchat.startup:main' +chatchat-kb = 'chatchat.init_database:main' +chatchat-config = 'chatchat.config_work_space:main' [tool.poetry.dependencies] python = ">=3.8.1,<3.12,!=3.9.7" model-providers = "^0.3.0" -langchain = "0.1.5" +langchain = "0.1.17" langchainhub = "0.1.14" -langchain-community = "0.0.17" +langchain-community = "0.0.36" langchain-openai = "0.0.5" -langchain-experimental = "0.0.50" -fastapi = "0.109.2" +langchain-experimental = "0.0.58" +fastapi = "~0.109.2" sse_starlette = "~1.8.2" nltk = "~3.8.1" uvicorn = ">=0.27.0.post1" @@ -27,6 +29,8 @@ unstructured = "~0.11.0" python-magic-bin = {version = "*", platform = "win32"} SQLAlchemy = "~2.0.25" faiss-cpu = "~1.7.4" +cutword = "0.1.0" +rank_bm25 = "0.2.2" # accelerate = "~0.24.1" # spacy = "~3.7.2" PyMuPDF = "~1.23.16" @@ -34,28 +38,29 @@ rapidocr_onnxruntime = "~1.3.8" requests = "~2.31.0" pathlib = "~1.0.1" pytest = "~7.4.3" -pyjwt = "2.8.0" +pyjwt = "~2.8.0" elasticsearch = "*" numexpr = ">=1" #test strsimpy = ">=0.2.1" markdownify = ">=0.11.6" tqdm = ">=4.66.1" websockets = ">=12.0" -numpy = "1.24.4" +numpy = "~1.24.4" pandas = "~1" # test -pydantic = "2.6.4" +pydantic = "~2.6.4" httpx = {version = ">=0.25.2", extras = ["brotli", "http2", "socks"]} python-multipart = "0.0.9" # webui streamlit = "1.34.0" streamlit-option-menu = "0.3.12" streamlit-antd-components = "0.3.1" -streamlit-chatbox = "1.1.12" +streamlit-chatbox = "1.1.12.post2" streamlit-modal = "0.1.0" streamlit-aggrid = "0.3.4.post3" streamlit-extras = "0.4.2" xinference_client = { version = "^0.11.1", optional = true } zhipuai = { version = "^2.1.0", optional = true } +pymysql = "^1.1.0" [tool.poetry.extras] xinference = ["xinference_client"] @@ -254,7 +259,7 @@ ignore-words-list = 'momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy [tool.poetry.plugins.dotenv] ignore = "false" -location = ".env" +dotenv = "dotenv:plugin" # https://python-poetry.org/docs/repositories/ diff --git a/libs/chatchat-server/tests/conftest.py b/libs/chatchat-server/tests/conftest.py new file mode 100644 index 00000000..031384e8 --- /dev/null +++ b/libs/chatchat-server/tests/conftest.py @@ -0,0 +1,90 @@ +"""Configuration for unit tests.""" +import logging +from importlib import util +from typing import Dict, List, Sequence + +import pytest +from pytest import Config, Function, Parser + + +def pytest_addoption(parser: Parser) -> None: + """Add custom command line options to pytest.""" + parser.addoption( + "--only-extended", + action="store_true", + help="Only run extended tests. Does not allow skipping any extended tests.", + ) + parser.addoption( + "--only-core", + action="store_true", + help="Only run core tests. Never runs any extended tests.", + ) + + +def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None: + """Add implementations for handling custom markers. + + At the moment, this adds support for a custom `requires` marker. + + The `requires` marker is used to denote tests that require one or more packages + to be installed to run. If the package is not installed, the test is skipped. + + The `requires` marker syntax is: + + .. code-block:: python + + @pytest.mark.requires("package1", "package2") + def test_something(): + ... + """ + # Mapping from the name of a package to whether it is installed or not. + # Used to avoid repeated calls to `util.find_spec` + required_pkgs_info: Dict[str, bool] = {} + + only_extended = config.getoption("--only-extended") or False + only_core = config.getoption("--only-core") or False + + if only_extended and only_core: + raise ValueError("Cannot specify both `--only-extended` and `--only-core`.") + + for item in items: + requires_marker = item.get_closest_marker("requires") + if requires_marker is not None: + if only_core: + item.add_marker(pytest.mark.skip(reason="Skipping not a core test.")) + continue + + # Iterate through the list of required packages + required_pkgs = requires_marker.args + for pkg in required_pkgs: + # If we haven't yet checked whether the pkg is installed + # let's check it and store the result. + if pkg not in required_pkgs_info: + try: + installed = util.find_spec(pkg) is not None + except Exception: + installed = False + required_pkgs_info[pkg] = installed + + if not required_pkgs_info[pkg]: + if only_extended: + pytest.fail( + f"Package `{pkg}` is not installed but is required for " + f"extended tests. Please install the given package and " + f"try again.", + ) + + else: + # If the package is not installed, we immediately break + # and mark the test as skipped. + item.add_marker( + pytest.mark.skip(reason=f"Requires pkg: `{pkg}`") + ) + break + else: + if only_extended: + item.add_marker( + pytest.mark.skip(reason="Skipping not an extended test.") + ) + + diff --git a/libs/chatchat-server/tests/unit_tests/config/test_config.py b/libs/chatchat-server/tests/unit_tests/config/test_config.py new file mode 100644 index 00000000..96fbd5dc --- /dev/null +++ b/libs/chatchat-server/tests/unit_tests/config/test_config.py @@ -0,0 +1,56 @@ +from pathlib import Path + +from chatchat.configs import ConfigBasicFactory, ConfigBasic, ConfigWorkSpace +import os + + +def test_config_factory_def(): + test_config_factory = ConfigBasicFactory() + config: ConfigBasic = test_config_factory.get_config() + assert config is not None + assert config.log_verbose is False + assert config.CHATCHAT_ROOT is not None + assert config.DATA_PATH is not None + assert config.IMG_DIR is not None + assert config.NLTK_DATA_PATH is not None + assert config.LOG_FORMAT is not None + assert config.LOG_PATH is not None + assert config.MEDIA_PATH is not None + + assert os.path.exists(os.path.join(config.MEDIA_PATH, "image")) + assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio")) + assert os.path.exists(os.path.join(config.MEDIA_PATH, "video")) + + +def test_workspace(): + config_workspace = ConfigWorkSpace() + assert config_workspace.get_config() is not None + base_root = os.path.join(Path(__file__).absolute().parent, "chatchat") + config_workspace.set_data_path(os.path.join(base_root, "data")) + config_workspace.set_log_verbose(True) + config_workspace.set_log_format(" %(message)s") + + config: ConfigBasic = config_workspace.get_config() + assert config.log_verbose is True + assert config.DATA_PATH == os.path.join(base_root, "data") + assert config.IMG_DIR is not None + assert config.NLTK_DATA_PATH == os.path.join(base_root, "data", "nltk_data") + assert config.LOG_FORMAT == " %(message)s" + assert config.LOG_PATH == os.path.join(base_root, "data", "logs") + assert config.MEDIA_PATH == os.path.join(base_root, "data", "media") + + assert os.path.exists(os.path.join(config.MEDIA_PATH, "image")) + assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio")) + assert os.path.exists(os.path.join(config.MEDIA_PATH, "video")) + config_workspace.clear() + + +def test_workspace_default(): + from chatchat.configs import (log_verbose, DATA_PATH, IMG_DIR, NLTK_DATA_PATH, LOG_FORMAT, LOG_PATH, MEDIA_PATH) + assert log_verbose is False + assert DATA_PATH is not None + assert IMG_DIR is not None + assert NLTK_DATA_PATH is not None + assert LOG_FORMAT is not None + assert LOG_PATH is not None + assert MEDIA_PATH is not None diff --git a/libs/model-providers/.env b/libs/model-providers/.env new file mode 100644 index 00000000..f78637c0 --- /dev/null +++ b/libs/model-providers/.env @@ -0,0 +1 @@ +PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring diff --git a/libs/model-providers/model_providers/extensions/ext_redis.py b/libs/model-providers/model_providers/extensions/ext_redis.py index 015706e3..7316464e 100644 --- a/libs/model-providers/model_providers/extensions/ext_redis.py +++ b/libs/model-providers/model_providers/extensions/ext_redis.py @@ -1,26 +1,26 @@ -import redis -from redis.connection import Connection, SSLConnection - -redis_client = redis.Redis() - - -def init_app(app): - connection_class = Connection - if app.config.get("REDIS_USE_SSL", False): - connection_class = SSLConnection - - redis_client.connection_pool = redis.ConnectionPool( - **{ - "host": app.config.get("REDIS_HOST", "localhost"), - "port": app.config.get("REDIS_PORT", 6379), - "username": app.config.get("REDIS_USERNAME", None), - "password": app.config.get("REDIS_PASSWORD", None), - "db": app.config.get("REDIS_DB", 0), - "encoding": "utf-8", - "encoding_errors": "strict", - "decode_responses": False, - }, - connection_class=connection_class, - ) - - app.extensions["redis"] = redis_client +# import redis +# from redis.connection import Connection, SSLConnection +# +# redis_client = redis.Redis() +# +# +# def init_app(app): +# connection_class = Connection +# if app.config.get("REDIS_USE_SSL", False): +# connection_class = SSLConnection +# +# redis_client.connection_pool = redis.ConnectionPool( +# **{ +# "host": app.config.get("REDIS_HOST", "localhost"), +# "port": app.config.get("REDIS_PORT", 6379), +# "username": app.config.get("REDIS_USERNAME", None), +# "password": app.config.get("REDIS_PASSWORD", None), +# "db": app.config.get("REDIS_DB", 0), +# "encoding": "utf-8", +# "encoding_errors": "strict", +# "decode_responses": False, +# }, +# connection_class=connection_class, +# ) +# +# app.extensions["redis"] = redis_client diff --git a/libs/model-providers/pyproject.toml b/libs/model-providers/pyproject.toml index 1d4652fc..14976752 100644 --- a/libs/model-providers/pyproject.toml +++ b/libs/model-providers/pyproject.toml @@ -7,20 +7,19 @@ readme = "README.md" [tool.poetry.dependencies] python = ">=3.8.1,<3.12,!=3.9.7" -transformers = "4.31.0" -fastapi = "^0.109.2" +transformers = "~4.31.0" +fastapi = "~0.109.2" uvicorn = ">=0.27.0.post1" -sse-starlette = "^1.8.2" -pyyaml = "6.0.1" -pydantic = "2.6.4" -redis = "4.5.4" +sse-starlette = "~1.8.2" +pyyaml = "~6.0.1" +pydantic ="~2.6.4" # config manage -omegaconf = "2.0.6" +omegaconf = "~2.0.6" # modle_runtime -openai = "1.13.3" -tiktoken = "0.5.2" -pydub = "0.25.1" -boto3 = "1.28.17" +openai = "~1.13.3" +tiktoken = "~0.5.2" +pydub = "~0.25.1" +boto3 = "~1.28.17" [tool.poetry.group.test.dependencies] # The only dependencies that should be added are @@ -207,7 +206,7 @@ ignore-words-list = 'momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy [tool.poetry.plugins.dotenv] ignore = "false" -location = ".env" +dotenv = "dotenv:plugin" # https://python-poetry.org/docs/repositories/ [[tool.poetry.source]] diff --git a/pyproject.toml b/pyproject.toml index e6d3f19c..23009743 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,19 @@ langchain-chatchat = { path = "libs/chatchat-server", develop = true } ipykernel = "^6.29.2" [tool.poetry.group.test.dependencies] - +pytest = "^7.3.0" +pytest-cov = "^4.0.0" +pytest-dotenv = "^0.5.2" +duckdb-engine = "^0.9.2" +pytest-watcher = "^0.2.6" +freezegun = "^1.2.2" +responses = "^0.22.0" +pytest-asyncio = "^0.23.2" +lark = "^1.1.5" +pytest-mock = "^3.10.0" +pytest-socket = "^0.6.0" +syrupy = "^4.0.2" +requests-mock = "^1.11.0" [tool.ruff] extend-include = ["*.ipynb"] @@ -48,11 +60,11 @@ extend-exclude = [ [tool.poetry.plugins.dotenv] ignore = "false" -location = ".env" +dotenv = "dotenv:plugin" # https://python-poetry.org/docs/repositories/ [[tool.poetry.source]] name = "tsinghua" url = "https://pypi.tuna.tsinghua.edu.cn/simple/" -priority = "default" \ No newline at end of file +priority = "default" diff --git a/tools/model_loaders/xinference_manager.py b/tools/model_loaders/xinference_manager.py index c2a86cc0..d628fcb0 100644 --- a/tools/model_loaders/xinference_manager.py +++ b/tools/model_loaders/xinference_manager.py @@ -133,6 +133,8 @@ model_names = list(regs[model_type].keys()) model_name = cols[1].selectbox("模型名称:", model_names) cur_reg = regs[model_type][model_name]["reg"] +model_format = None +model_quant = None if model_type == "LLM": cur_family = xf_llm.LLMFamilyV1.parse_obj(cur_reg) @@ -217,7 +219,7 @@ cols = st.columns(3) if cols[0].button("设置模型缓存"): if os.path.isabs(local_path) and os.path.isdir(local_path): - cur_spec.model_uri = local_path + cur_spec.__dict__["model_uri"] = local_path # embedding spec has no attribute model_uri if os.path.isdir(cache_dir): os.rmdir(cache_dir) if model_type == "LLM": @@ -250,10 +252,10 @@ if cols[2].button("注册为自定义模型"): if model_type == "LLM": cur_family.model_name = f"{cur_family.model_name}{model_name_suffix}" cur_family.model_family = "other" - model_definition = cur_family.json(indent=2, ensure_ascii=False) + model_definition = cur_family.model_dump_json(indent=2, ensure_ascii=False) else: cur_spec.model_name = f"{cur_spec.model_name}{model_name_suffix}" - model_definition = cur_spec.json(indent=2, ensure_ascii=False) + model_definition = cur_spec.model_dump_json(indent=2, ensure_ascii=False) client.register_model( model_type=model_type, model=model_definition, @@ -262,4 +264,3 @@ if cols[2].button("注册为自定义模型"): st.rerun() else: st.error("必须输入存在的绝对路径") -