Merge branch 'chatchat-space:dev' into dev

This commit is contained in:
zR 2024-06-10 13:28:15 +08:00 committed by GitHub
commit d189bd182f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
58 changed files with 1043 additions and 296 deletions

1
.env Normal file
View File

@ -0,0 +1 @@
PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring

View File

@ -15,9 +15,11 @@
Install Poetry: [documentation on how to install it.](https://python-poetry.org/docs/#installing-with-pipx) Install Poetry: [documentation on how to install it.](https://python-poetry.org/docs/#installing-with-pipx)
> 友情提示 不想安装pipx可以用pip安装poetryTips:如果你没有其它poetry的项目
> 注意: 如果您使用 Conda 或 Pyenv 作为您的环境/包管理器在安装Poetry之后 > 注意: 如果您使用 Conda 或 Pyenv 作为您的环境/包管理器在安装Poetry之后
> 使用如下命令使 Poetry 使用 virtualenv python environment (`poetry config virtualenvs.prefer-active-python true`) > 使用如下命令使 Poetry 使用 virtualenv python environment (`poetry config virtualenvs.prefer-active-python true`)
#### 本地开发环境安装 #### 本地开发环境安装
- 选择主项目目录 - 选择主项目目录

View File

@ -0,0 +1 @@
PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring

View File

@ -29,6 +29,11 @@ vim model_providers.yaml
> >
> 详细配置请参考[README.md](../model-providers/README.md) > 详细配置请参考[README.md](../model-providers/README.md)
- 初始化知识库
```shell
chatchat-kb -r
```
- 启动服务 - 启动服务
```shell ```shell
chatchat -a chatchat -a

View File

@ -0,0 +1,47 @@
from chatchat.configs import config_workspace as workspace
def main():
import argparse
parser = argparse.ArgumentParser(description="指令` chatchat-config` 工作空间配置")
# 只能选择true或false
parser.add_argument(
"-v",
"--verbose",
choices=["true", "false"],
help="是否开启详细日志"
)
parser.add_argument(
"-d",
"--data",
help="数据存放路径"
)
parser.add_argument(
"-f",
"--format",
help="日志格式"
)
parser.add_argument(
"--clear",
action="store_true",
help="清除配置"
)
args = parser.parse_args()
if args.verbose:
if args.verbose.lower() == "true":
workspace.set_log_verbose(True)
else:
workspace.set_log_verbose(False)
if args.data:
workspace.set_data_path(args.data)
if args.format:
workspace.set_log_format(args.format)
if args.clear:
workspace.clear()
print(workspace.get_config())
if __name__ == "__main__":
main()

View File

@ -1,8 +1,10 @@
import importlib import importlib
import importlib.util import importlib.util
import os import os
from pathlib import Path
from typing import Dict, Any from typing import Dict, Any
import json
import logging import logging
logger = logging.getLogger() logger = logging.getLogger()
@ -38,7 +40,6 @@ def _import_config_mod_load(import_config_mod: str) -> Dict:
) )
user_import = False user_import = False
if user_import: if user_import:
# Dynamic loading {config}.py file # Dynamic loading {config}.py file
py_path = os.path.join(user_config_path, import_config_mod + ".py") py_path = os.path.join(user_config_path, import_config_mod + ".py")
spec = importlib.util.spec_from_file_location( spec = importlib.util.spec_from_file_location(
@ -69,7 +70,7 @@ def _import_config_mod_load(import_config_mod: str) -> Dict:
) )
raise RuntimeError(f"Failed to load user config from {user_config_path}") raise RuntimeError(f"Failed to load user config from {user_config_path}")
# 当前文件路径 # 当前文件路径
py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py") py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py")
spec = importlib.util.spec_from_file_location(f"*", spec = importlib.util.spec_from_file_location(f"*",
py_path) py_path)
@ -95,75 +96,108 @@ CONFIG_IMPORTS = {
} }
def _import_ConfigBasic() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
ConfigBasic = load_mod(basic_config_load.get("module"), "ConfigBasic")
return ConfigBasic
def _import_ConfigBasicFactory() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
ConfigBasicFactory = load_mod(basic_config_load.get("module"), "ConfigBasicFactory")
return ConfigBasicFactory
def _import_ConfigWorkSpace() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
ConfigWorkSpace = load_mod(basic_config_load.get("module"), "ConfigWorkSpace")
return ConfigWorkSpace
def _import_config_workspace() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod")
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace
def _import_log_verbose() -> Any: def _import_log_verbose() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod") load_mod = basic_config_load.get("load_mod")
log_verbose = load_mod(basic_config_load.get("module"), "log_verbose") config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace.get_config().log_verbose
return log_verbose
def _import_chatchat_root() -> Any: def _import_chatchat_root() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod") load_mod = basic_config_load.get("load_mod")
CHATCHAT_ROOT = load_mod(basic_config_load.get("module"), "CHATCHAT_ROOT") config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return CHATCHAT_ROOT return config_workspace.get_config().CHATCHAT_ROOT
def _import_data_path() -> Any: def _import_data_path() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod") load_mod = basic_config_load.get("load_mod")
DATA_PATH = load_mod(basic_config_load.get("module"), "DATA_PATH")
return DATA_PATH config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace.get_config().DATA_PATH
def _import_img_dir() -> Any: def _import_img_dir() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod") load_mod = basic_config_load.get("load_mod")
IMG_DIR = load_mod(basic_config_load.get("module"), "IMG_DIR")
return IMG_DIR config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace.get_config().IMG_DIR
def _import_nltk_data_path() -> Any: def _import_nltk_data_path() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod") load_mod = basic_config_load.get("load_mod")
NLTK_DATA_PATH = load_mod(basic_config_load.get("module"), "NLTK_DATA_PATH") config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return NLTK_DATA_PATH return config_workspace.get_config().NLTK_DATA_PATH
def _import_log_format() -> Any: def _import_log_format() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod") load_mod = basic_config_load.get("load_mod")
LOG_FORMAT = load_mod(basic_config_load.get("module"), "LOG_FORMAT")
return LOG_FORMAT config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace.get_config().LOG_FORMAT
def _import_log_path() -> Any: def _import_log_path() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod") load_mod = basic_config_load.get("load_mod")
LOG_PATH = load_mod(basic_config_load.get("module"), "LOG_PATH")
return LOG_PATH config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace.get_config().LOG_PATH
def _import_media_path() -> Any: def _import_media_path() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod") load_mod = basic_config_load.get("load_mod")
MEDIA_PATH = load_mod(basic_config_load.get("module"), "MEDIA_PATH")
return MEDIA_PATH config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace.get_config().MEDIA_PATH
def _import_base_temp_dir() -> Any: def _import_base_temp_dir() -> Any:
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py") basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
load_mod = basic_config_load.get("load_mod") load_mod = basic_config_load.get("load_mod")
BASE_TEMP_DIR = load_mod(basic_config_load.get("module"), "BASE_TEMP_DIR") config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
return config_workspace.get_config().BASE_TEMP_DIR
return BASE_TEMP_DIR
def _import_default_knowledge_base() -> Any: def _import_default_knowledge_base() -> Any:
@ -285,6 +319,7 @@ def _import_db_root_path() -> Any:
return DB_ROOT_PATH return DB_ROOT_PATH
def _import_sqlalchemy_database_uri() -> Any: def _import_sqlalchemy_database_uri() -> Any:
kb_config_load = CONFIG_IMPORTS.get("_kb_config.py") kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
load_mod = kb_config_load.get("load_mod") load_mod = kb_config_load.get("load_mod")
@ -478,7 +513,15 @@ def _import_api_server() -> Any:
def __getattr__(name: str) -> Any: def __getattr__(name: str) -> Any:
if name == "log_verbose": if name == "ConfigBasic":
return _import_ConfigBasic()
elif name == "ConfigBasicFactory":
return _import_ConfigBasicFactory()
elif name == "ConfigWorkSpace":
return _import_ConfigWorkSpace()
elif name == "config_workspace":
return _import_config_workspace()
elif name == "log_verbose":
return _import_log_verbose() return _import_log_verbose()
elif name == "CHATCHAT_ROOT": elif name == "CHATCHAT_ROOT":
return _import_chatchat_root() return _import_chatchat_root()
@ -578,6 +621,7 @@ VERSION = "v0.3.0-preview"
__all__ = [ __all__ = [
"VERSION", "VERSION",
"config_workspace",
"log_verbose", "log_verbose",
"CHATCHAT_ROOT", "CHATCHAT_ROOT",
"DATA_PATH", "DATA_PATH",
@ -626,5 +670,8 @@ __all__ = [
"WEBUI_SERVER", "WEBUI_SERVER",
"API_SERVER", "API_SERVER",
"ConfigBasic",
"ConfigBasicFactory",
"ConfigWorkSpace",
] ]

View File

@ -1,55 +1,180 @@
import logging
import os import os
import json
from pathlib import Path from pathlib import Path
import logging
import langchain
# 是否显示详细日志
log_verbose = False
langchain.verbose = False
# 通常情况下不需要更改以下内容
# chatchat 项目根目录
CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent)
# 用户数据根目录
DATA_PATH = os.path.join(CHATCHAT_ROOT, "data")
if not os.path.exists(DATA_PATH):
os.mkdir(DATA_PATH)
# 项目相关图片
IMG_DIR = os.path.join(CHATCHAT_ROOT, "img")
if not os.path.exists(IMG_DIR):
os.mkdir(IMG_DIR)
# nltk 模型存储路径
NLTK_DATA_PATH = os.path.join(DATA_PATH, "nltk_data")
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
# 日志格式
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
logger = logging.getLogger() logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT)
# 日志存储路径 class ConfigBasic:
LOG_PATH = os.path.join(DATA_PATH, "logs") log_verbose: bool
if not os.path.exists(LOG_PATH): """是否开启日志详细信息"""
os.mkdir(LOG_PATH) CHATCHAT_ROOT: str
"""项目根目录"""
DATA_PATH: str
"""用户数据根目录"""
IMG_DIR: str
"""项目相关图片"""
NLTK_DATA_PATH: str
"""nltk 模型存储路径"""
LOG_FORMAT: str
"""日志格式"""
LOG_PATH: str
"""日志存储路径"""
MEDIA_PATH: str
"""模型生成内容(图片、视频、音频等)保存位置"""
BASE_TEMP_DIR: str
"""临时文件目录,主要用于文件对话"""
# 模型生成内容(图片、视频、音频等)保存位置 def __str__(self):
MEDIA_PATH = os.path.join(DATA_PATH, "media") return f"ConfigBasic(log_verbose={self.log_verbose}, CHATCHAT_ROOT={self.CHATCHAT_ROOT}, DATA_PATH={self.DATA_PATH}, IMG_DIR={self.IMG_DIR}, NLTK_DATA_PATH={self.NLTK_DATA_PATH}, LOG_FORMAT={self.LOG_FORMAT}, LOG_PATH={self.LOG_PATH}, MEDIA_PATH={self.MEDIA_PATH}, BASE_TEMP_DIR={self.BASE_TEMP_DIR})"
if not os.path.exists(MEDIA_PATH):
os.mkdir(MEDIA_PATH)
os.mkdir(os.path.join(MEDIA_PATH, "image"))
os.mkdir(os.path.join(MEDIA_PATH, "audio"))
os.mkdir(os.path.join(MEDIA_PATH, "video"))
# 临时文件目录,主要用于文件对话
BASE_TEMP_DIR = os.path.join(DATA_PATH, "temp") class ConfigBasicFactory:
if not os.path.exists(BASE_TEMP_DIR): """Basic config for ChatChat """
os.mkdir(BASE_TEMP_DIR)
def __init__(self):
# 日志格式
self.LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
logging.basicConfig(format=self.LOG_FORMAT)
self.LOG_VERBOSE = False
self.CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent)
# 用户数据根目录
self.DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data")
self._DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data")
if not os.path.exists(self._DATA_PATH):
os.makedirs(self._DATA_PATH, exist_ok=True)
self._init_data_dir()
# 项目相关图片
self.IMG_DIR = os.path.join(self.CHATCHAT_ROOT, "img")
if not os.path.exists(self.IMG_DIR):
os.makedirs(self.IMG_DIR, exist_ok=True)
def log_verbose(self, verbose: bool):
self.LOG_VERBOSE = verbose
def data_path(self, path: str):
self.DATA_PATH = path
if not os.path.exists(self.DATA_PATH):
os.makedirs(self.DATA_PATH, exist_ok=True)
# 复制_DATA_PATH数据到DATA_PATH
if self._DATA_PATH != self.DATA_PATH:
os.system(f"cp -r {self._DATA_PATH}/* {self.DATA_PATH}")
self._init_data_dir()
def log_format(self, log_format: str):
self.LOG_FORMAT = log_format
logging.basicConfig(format=self.LOG_FORMAT)
def _init_data_dir(self):
logger.info(f"Init data dir: {self.DATA_PATH}")
# nltk 模型存储路径
self.NLTK_DATA_PATH = os.path.join(self.DATA_PATH, "nltk_data")
import nltk
nltk.data.path = [self.NLTK_DATA_PATH] + nltk.data.path
# 日志存储路径
self.LOG_PATH = os.path.join(self.DATA_PATH, "logs")
if not os.path.exists(self.LOG_PATH):
os.makedirs(self.LOG_PATH, exist_ok=True)
# 模型生成内容(图片、视频、音频等)保存位置
self.MEDIA_PATH = os.path.join(self.DATA_PATH, "media")
if not os.path.exists(self.MEDIA_PATH):
os.makedirs(self.MEDIA_PATH, exist_ok=True)
os.makedirs(os.path.join(self.MEDIA_PATH, "image"), exist_ok=True)
os.makedirs(os.path.join(self.MEDIA_PATH, "audio"), exist_ok=True)
os.makedirs(os.path.join(self.MEDIA_PATH, "video"), exist_ok=True)
# 临时文件目录,主要用于文件对话
self.BASE_TEMP_DIR = os.path.join(self.DATA_PATH, "temp")
if not os.path.exists(self.BASE_TEMP_DIR):
os.makedirs(self.BASE_TEMP_DIR, exist_ok=True)
logger.info(f"Init data dir: {self.DATA_PATH} success.")
def get_config(self) -> ConfigBasic:
config = ConfigBasic()
config.log_verbose = self.LOG_VERBOSE
config.CHATCHAT_ROOT = self.CHATCHAT_ROOT
config.DATA_PATH = self.DATA_PATH
config.IMG_DIR = self.IMG_DIR
config.NLTK_DATA_PATH = self.NLTK_DATA_PATH
config.LOG_FORMAT = self.LOG_FORMAT
config.LOG_PATH = self.LOG_PATH
config.MEDIA_PATH = self.MEDIA_PATH
config.BASE_TEMP_DIR = self.BASE_TEMP_DIR
return config
class ConfigWorkSpace:
"""
工作空间的配置预设提供ConfigBasic建造方法产生实例
该类的实例对象用于存储工作空间的配置信息如工作空间的路径等
工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中
注意不存在则读取默认
"""
_config_factory: ConfigBasicFactory = ConfigBasicFactory()
def __init__(self):
self.workspace = os.path.join(os.path.expanduser("~"), ".config", "chatchat/workspace")
if not os.path.exists(self.workspace):
os.makedirs(self.workspace, exist_ok=True)
self.workspace_config = os.path.join(self.workspace, "workspace_config.json")
# 初始化工作空间配置转换成json格式实现ConfigBasic的实例化
config_json = self._load_config()
if config_json:
_config_factory = ConfigBasicFactory()
if config_json.get("log_verbose"):
_config_factory.log_verbose(config_json.get("log_verbose"))
if config_json.get("DATA_PATH"):
_config_factory.data_path(config_json.get("DATA_PATH"))
if config_json.get("LOG_FORMAT"):
_config_factory.log_format(config_json.get("LOG_FORMAT"))
self._config_factory = _config_factory
def get_config(self) -> ConfigBasic:
return self._config_factory.get_config()
def set_log_verbose(self, verbose: bool):
self._config_factory.log_verbose(verbose)
self._store_config()
def set_data_path(self, path: str):
self._config_factory.data_path(path)
self._store_config()
def set_log_format(self, log_format: str):
self._config_factory.log_format(log_format)
self._store_config()
def clear(self):
logger.info("Clear workspace config.")
os.remove(self.workspace_config)
def _load_config(self):
try:
with open(self.workspace_config, "r") as f:
return json.loads(f.read())
except FileNotFoundError:
return None
def _store_config(self):
with open(self.workspace_config, "w") as f:
config = self._config_factory.get_config()
config_json = {
"log_verbose": config.log_verbose,
"CHATCHAT_ROOT": config.CHATCHAT_ROOT,
"DATA_PATH": config.DATA_PATH,
"LOG_FORMAT": config.LOG_FORMAT
}
f.write(json.dumps(config_json, indent=4, ensure_ascii=False))
config_workspace: ConfigWorkSpace = ConfigWorkSpace()

View File

@ -1,11 +1,13 @@
import os import os
from pathlib import Path from pathlib import Path
# chatchat 项目根目录 import sys
CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent) sys.path.append(str(Path(__file__).parent))
from _basic_config import config_workspace
# 用户数据根目录 # 用户数据根目录
DATA_PATH = os.path.join(CHATCHAT_ROOT, "data") DATA_PATH = config_workspace.get_config().DATA_PATH
# 默认使用的知识库 # 默认使用的知识库
DEFAULT_KNOWLEDGE_BASE = "samples" DEFAULT_KNOWLEDGE_BASE = "samples"

View File

@ -118,6 +118,25 @@ MODEL_PLATFORMS = [
"tts_models": [], "tts_models": [],
}, },
{
"platform_name": "xinference",
"platform_type": "xinference",
"api_base_url": "http://127.0.0.1:9997/v1",
"api_key": "EMPTY",
"api_concurrencies": 5,
"llm_models": [
"glm-4",
"qwen2-instruct",
"qwen1.5-chat",
],
"embed_models": [
"bge-large-zh-v1.5",
],
"image_models": [],
"reranking_models": [],
"speech2text_models": [],
"tts_models": [],
},
] ]
@ -130,7 +149,7 @@ TOOL_CONFIG = {
"search_local_knowledgebase": { "search_local_knowledgebase": {
"use": False, "use": False,
"top_k": 3, "top_k": 3,
"score_threshold": 1, "score_threshold": 1.0,
"conclude_prompt": { "conclude_prompt": {
"with_result": "with_result":
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题"' '<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题"'
@ -208,5 +227,34 @@ TOOL_CONFIG = {
"text2images": { "text2images": {
"use": False, "use": False,
}, },
# text2sql使用建议
# 1、因大模型生成的sql可能与预期有偏差请务必在测试环境中进行充分测试、评估
# 2、生产环境中对于查询操作由于不确定查询效率推荐数据库采用主从数据库架构让text2sql连接从数据库防止可能的慢查询影响主业务
# 3、对于写操作应保持谨慎如不需要写操作设置read_only为True,最好再从数据库层面收回数据库用户的写权限,防止用户通过自然语言对数据库进行修改操作;
# 4、text2sql与大模型在意图理解、sql转换等方面的能力有关可切换不同大模型进行测试
# 5、数据库表名、字段名应与其实际作用保持一致、容易理解且应对数据库表名、字段进行详细的备注说明帮助大模型更好理解数据库结构
# 6、若现有数据库表名难于让大模型理解可配置下面table_comments字段补充说明某些表的作用。
"text2sql": {
"use": False,
# SQLAlchemy连接字符串支持的数据库有
# crate、duckdb、googlesql、mssql、mysql、mariadb、oracle、postgresql、sqlite、clickhouse、prestodb
# 不同的数据库请查询SQLAlchemy修改sqlalchemy_connect_str配置对应的数据库连接如sqlite为sqlite:///数据库文件路径下面示例为mysql
# 如提示缺少对应数据库的驱动请自行通过poetry安装
"sqlalchemy_connect_str": "mysql+pymysql://用户名:密码@主机地址/数据库名称e",
# 务必评估是否需要开启read_only,开启后会对sql语句进行检查请确认text2sql.py中的intercept_sql拦截器是否满足你使用的数据库只读要求
# 优先推荐从数据库层面对用户权限进行限制
"read_only": False,
#限定返回的行数
"top_k":50,
#是否返回中间步骤
"return_intermediate_steps": True,
#如果想指定特定表,请填写表名称,如["sys_user","sys_dept"],不填写走智能判断应该使用哪些表
"table_names":[],
#对表名进行额外说明辅助大模型更好的判断应该使用哪些表尤其是SQLDatabaseSequentialChain模式下,是根据表名做的预测,很容易误判。
"table_comments":{
# 如果出现大模型选错表的情况,可尝试根据实际情况填写表名和说明
# "tableA":"这是一个用户表,存储了用户的基本信息",
# "tanleB":"角色表",
}
},
} }

View File

@ -20,16 +20,16 @@
xinference: xinference:
model_credential: model_credential:
- model: 'chatglm3-6b' - model: 'glm-4'
model_type: 'llm' model_type: 'llm'
model_credentials: model_credentials:
server_url: 'http://127.0.0.1:9997/' server_url: 'http://127.0.0.1:9997/'
model_uid: 'chatglm3-6b' model_uid: 'glm-4'
- model: 'Qwen1.5-14B-Chat' - model: 'qwen1.5-chat'
model_type: 'llm' model_type: 'llm'
model_credentials: model_credentials:
server_url: 'http://127.0.0.1:9997/' server_url: 'http://127.0.0.1:9997/'
model_uid: 'Qwen1.5-14B-Chat' model_uid: 'qwen1.5-chat'
- model: 'bge-large-zh-v1.5' - model: 'bge-large-zh-v1.5'
model_type: 'embeddings' model_type: 'embeddings'
model_credentials: model_credentials:

View File

@ -1,13 +1,12 @@
# Description: 初始化数据库,包括创建表、导入数据、更新向量空间等操作 # Description: 初始化数据库,包括创建表、导入数据、更新向量空间等操作
from datetime import datetime
import multiprocessing as mp
from typing import Dict from typing import Dict
from chatchat.server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db, from chatchat.server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
folder2db, prune_db_docs, prune_folder_files) folder2db, prune_db_docs, prune_folder_files)
from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS, logger
import multiprocessing as mp
import logging
logger = logging.getLogger(__name__)
from datetime import datetime
def run_init_model_provider( def run_init_model_provider(
@ -34,7 +33,7 @@ def run_init_model_provider(
provider_port=provider_port) provider_port=provider_port)
if __name__ == "__main__": def main():
import argparse import argparse
parser = argparse.ArgumentParser(description="please specify only one operate method once time.") parser = argparse.ArgumentParser(description="please specify only one operate method once time.")
@ -186,3 +185,7 @@ if __name__ == "__main__":
for p in processes.values(): for p in processes.values():
logger.info("Process status: %s", p) logger.info("Process status: %s", p)
if __name__ == "__main__":
main()

View File

@ -10,3 +10,4 @@ from .text2image import text2images
from .vqa_processor import vqa_processor from .vqa_processor import vqa_processor
from .aqa_processor import aqa_processor from .aqa_processor import aqa_processor
from .text2sql import text2sql

View File

@ -1,13 +1,14 @@
from urllib.parse import urlencode from urllib.parse import urlencode
from chatchat.server.utils import get_tool_config from chatchat.server.utils import get_tool_config
from chatchat.server.pydantic_v1 import Field from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool, BaseToolOutput from chatchat.server.agent.tools_factory.tools_registry import regist_tool, BaseToolOutput
from chatchat.server.knowledge_base.kb_api import list_kbs from chatchat.server.knowledge_base.kb_api import list_kbs
from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId
from chatchat.configs import KB_INFO from chatchat.configs import KB_INFO
template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get informationOnly local data on this knowledge use this tool. The 'database' should be one of the above [{key}]." template = ("Use local knowledgebase from one or more of these:\n{KB_info}\n to get informationOnly local data on "
"this knowledge use this tool. The 'database' should be one of the above [{key}].")
KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()]) KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
template_knowledge = template.format(KB_info=KB_info_str, key="samples") template_knowledge = template.format(KB_info=KB_info_str, key="samples")
@ -49,7 +50,7 @@ def search_local_knowledgebase(
database: str = Field(description="Database for Knowledge Search", choices=[kb.kb_name for kb in list_kbs().data]), database: str = Field(description="Database for Knowledge Search", choices=[kb.kb_name for kb in list_kbs().data]),
query: str = Field(description="Query for Knowledge Search"), query: str = Field(description="Query for Knowledge Search"),
): ):
'''''' """"""
tool_config = get_tool_config("search_local_knowledgebase") tool_config = get_tool_config("search_local_knowledgebase")
ret = search_knowledgebase(query=query, database=database, config=tool_config) ret = search_knowledgebase(query=query, database=database, config=tool_config)
return KBToolOutput(ret, database=database) return KBToolOutput(ret, database=database)

View File

@ -0,0 +1,103 @@
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain,SQLDatabaseSequentialChain
from chatchat.server.utils import get_tool_config
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool, BaseToolOutput
from sqlalchemy.exc import OperationalError
from sqlalchemy import event
from langchain_core.prompts.prompt import PromptTemplate
from langchain.chains import LLMChain
READ_ONLY_PROMPT_TEMPLATE="""You are a MySQL expert. The database is currently in read-only mode.
Given an input question, determine if the related SQL can be executed in read-only mode.
If the SQL can be executed normally, return Answer:'SQL can be executed normally'.
If the SQL cannot be executed normally, return Answer: 'SQL cannot be executed normally'.
Use the following format:
Answer: Final answer here
Question: {query}
"""
# 定义一个拦截器函数来检查SQL语句以支持read-only,可修改下面的write_operations以匹配你使用的数据库写操作关键字
def intercept_sql(conn, cursor, statement, parameters, context, executemany):
# List of SQL keywords that indicate a write operation
write_operations = ("insert", "update", "delete", "create", "drop", "alter", "truncate", "rename")
# Check if the statement starts with any of the write operation keywords
if any(statement.strip().lower().startswith(op) for op in write_operations):
raise OperationalError("Database is read-only. Write operations are not allowed.", params=None, orig=None)
def query_database(query: str,
config: dict):
top_k = config["top_k"]
return_intermediate_steps = config["return_intermediate_steps"]
sqlalchemy_connect_str = config["sqlalchemy_connect_str"]
read_only = config["read_only"]
db = SQLDatabase.from_uri(sqlalchemy_connect_str)
from chatchat.server.api_server.chat_routes import global_model_name
from chatchat.server.utils import get_ChatOpenAI
llm = get_ChatOpenAI(
model_name=global_model_name,
temperature=0,
streaming=True,
local_wrap=True,
verbose=True
)
table_names=config["table_names"]
table_comments=config["table_comments"]
result = None
#如果发现大模型判断用什么表出现问题尝试给langchain提供额外的表说明辅助大模型更好的判断应该使用哪些表尤其是SQLDatabaseSequentialChain模式下,是根据表名做的预测,很容易误判
#由于langchain固定了输入参数所以只能通过query传递额外的表说明
if table_comments:
TABLE_COMMNET_PROMPT="\n\nI will provide some special notes for a few tables:\n\n"
table_comments_str="\n".join([f"{k}:{v}" for k,v in table_comments.items()])
query=query+TABLE_COMMNET_PROMPT+table_comments_str+"\n\n"
if read_only:
# 在read_only下先让大模型判断只读模式是否能满足需求避免后续执行过程报错返回友好提示。
READ_ONLY_PROMPT = PromptTemplate(
input_variables=["query"],
template=READ_ONLY_PROMPT_TEMPLATE,
)
read_only_chain = LLMChain(
prompt=READ_ONLY_PROMPT,
llm=llm,
)
read_only_result = read_only_chain.invoke(query)
if "SQL cannot be executed normally" in read_only_result["text"]:
return "当前数据库为只读状态,无法满足您的需求!"
# 当然大模型不能保证完全判断准确,为防止大模型判断有误,再从拦截器层面拒绝写操作
event.listen(db._engine, "before_cursor_execute", intercept_sql)
#如果不指定table_names优先走SQLDatabaseSequentialChain这个链会先预测需要哪些表然后再将相关表输入SQLDatabaseChain
#这是因为如果不指定table_names直接走SQLDatabaseChainLangchain会将全量表结构传递给大模型可能会因token太长从而引发错误也浪费资源
#如果指定了table_names直接走SQLDatabaseChain将特定表结构传递给大模型进行判断
if len(table_names) > 0:
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True,top_k=top_k,return_intermediate_steps=return_intermediate_steps)
result = db_chain.invoke({"query":query,"table_names_to_use":table_names})
else:
#先预测会使用哪些表,然后再将问题和预测的表给大模型
db_chain = SQLDatabaseSequentialChain.from_llm(llm, db, verbose=True,top_k=top_k,return_intermediate_steps=return_intermediate_steps)
result = db_chain.invoke(query)
context = f"""查询结果:{result['result']}\n\n"""
intermediate_steps=result["intermediate_steps"]
#如果存在intermediate_steps且这个数组的长度大于2则保留最后两个元素因为前面几个步骤存在示例数据容易引起误解
if intermediate_steps:
if len(intermediate_steps)>2:
sql_detail=intermediate_steps[-2:-1][0]["input"]
# sql_detail截取从SQLQuery到Answer:之间的内容
sql_detail=sql_detail[sql_detail.find("SQLQuery:")+9:sql_detail.find("Answer:")]
context = context+"执行的sql:'"+sql_detail+"'\n\n"
return context
@regist_tool(title="Text2Sql")
def text2sql(query: str = Field(description="No need for SQL statements,just input the natural language that you want to chat with database")):
'''Use this tool to chat with database,Input natural language, then it will convert it into SQL and execute it in the database, then return the execution result.'''
tool_config = get_tool_config("text2sql")
return BaseToolOutput(query_database(query=query, config=tool_config))

View File

@ -28,6 +28,8 @@ chat_router.post("/file_chat",
summary="文件对话" summary="文件对话"
)(file_chat) )(file_chat)
#定义全局model信息用于给Text2Sql中的get_ChatOpenAI提供model_name
global_model_name=None
@chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口") @chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口")
async def chat_completions( async def chat_completions(
@ -51,6 +53,8 @@ async def chat_completions(
for key in list(extra): for key in list(extra):
delattr(body, key) delattr(body, key)
global global_model_name
global_model_name=body.model
# check tools & tool_choice in request body # check tools & tool_choice in request body
if isinstance(body.tool_choice, str): if isinstance(body.tool_choice, str):
if t := get_tool(body.tool_choice): if t := get_tool(body.tool_choice):

View File

@ -63,10 +63,10 @@ def upload_temp_docs(
chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
) -> BaseResponse: ) -> BaseResponse:
''' """
将文件保存到临时目录并进行向量化 将文件保存到临时目录并进行向量化
返回临时目录名称作为ID同时也是临时向量库的ID 返回临时目录名称作为ID同时也是临时向量库的ID
''' """
if prev_id is not None: if prev_id is not None:
memo_faiss_pool.pop(prev_id) memo_faiss_pool.pop(prev_id)
@ -134,7 +134,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
docs = [x[0] for x in docs] docs = [x[0] for x in docs]
context = "\n".join([doc.page_content for doc in docs]) context = "\n".join([doc.page_content for doc in docs])
if len(docs) == 0: ## 如果没有找到相关文档使用Empty模板 if len(docs) == 0: # 如果没有找到相关文档使用Empty模板
prompt_template = get_prompt_template("knowledge_base_chat", "empty") prompt_template = get_prompt_template("knowledge_base_chat", "empty")
else: else:
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name) prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)

View File

@ -1,6 +1,6 @@
from typing import List from typing import List
from langchain_community.document_loaders.unstructured import UnstructuredFileLoader from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
from chatchat.server.document_loaders.ocr import get_ocr from chatchat.server.file_rag.document_loaders.ocr import get_ocr
class RapidOCRLoader(UnstructuredFileLoader): class RapidOCRLoader(UnstructuredFileLoader):

View File

@ -4,7 +4,7 @@ import cv2
from PIL import Image from PIL import Image
import numpy as np import numpy as np
from chatchat.configs import PDF_OCR_THRESHOLD from chatchat.configs import PDF_OCR_THRESHOLD
from chatchat.server.document_loaders.ocr import get_ocr from chatchat.server.file_rag.document_loaders.ocr import get_ocr
import tqdm import tqdm

View File

@ -0,0 +1,3 @@
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
from chatchat.server.file_rag.retrievers.vectorstore import VectorstoreRetrieverService
from chatchat.server.file_rag.retrievers.ensemble import EnsembleRetrieverService

View File

@ -0,0 +1,24 @@
from langchain.vectorstores import VectorStore
from abc import ABCMeta, abstractmethod
class BaseRetrieverService(metaclass=ABCMeta):
def __init__(self, **kwargs):
self.do_init(**kwargs)
@abstractmethod
def do_init(self, **kwargs):
pass
@abstractmethod
def from_vectorstore(
vectorstore: VectorStore,
top_k: int,
score_threshold: int or float,
):
pass
@abstractmethod
def get_relevant_documents(self, query: str):
pass

View File

@ -0,0 +1,47 @@
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
from langchain.vectorstores import VectorStore
from langchain_core.retrievers import BaseRetriever
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
class EnsembleRetrieverService(BaseRetrieverService):
def do_init(
self,
retriever: BaseRetriever = None,
top_k: int = 5,
):
self.vs = None
self.top_k = top_k
self.retriever = retriever
@staticmethod
def from_vectorstore(
vectorstore: VectorStore,
top_k: int,
score_threshold: int or float,
):
faiss_retriever = vectorstore.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={
"score_threshold": score_threshold,
"k": top_k
}
)
# TODO: 换个不用torch的实现方式
from cutword.cutword import Cutter
cutter = Cutter()
docs = list(vectorstore.docstore._dict.values())
bm25_retriever = BM25Retriever.from_documents(
docs,
preprocess_func=cutter.cutword
)
bm25_retriever.k = top_k
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
)
return EnsembleRetrieverService(retriever=ensemble_retriever)
def get_relevant_documents(self, query: str):
return self.retriever.get_relevant_documents(query)[:self.top_k]

View File

@ -0,0 +1,33 @@
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
from langchain.vectorstores import VectorStore
from langchain_core.retrievers import BaseRetriever
class VectorstoreRetrieverService(BaseRetrieverService):
def do_init(
self,
retriever: BaseRetriever = None,
top_k: int = 5,
):
self.vs = None
self.top_k = top_k
self.retriever = retriever
@staticmethod
def from_vectorstore(
vectorstore: VectorStore,
top_k: int,
score_threshold: int or float,
):
retriever = vectorstore.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={
"score_threshold": score_threshold,
"k": top_k
}
)
return VectorstoreRetrieverService(retriever=retriever)
def get_relevant_documents(self, query: str):
return self.retriever.get_relevant_documents(query)[:self.top_k]

View File

@ -0,0 +1,13 @@
from chatchat.server.file_rag.retrievers import (
BaseRetrieverService,
VectorstoreRetrieverService,
EnsembleRetrieverService,
)
Retrivals = {
"vectorstore": VectorstoreRetrieverService,
"ensemble": EnsembleRetrieverService,
}
def get_Retriever(type: str = "vectorstore") -> BaseRetrieverService:
return Retrivals[type]

View File

@ -14,6 +14,7 @@ def list_kbs():
def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
vector_store_type: str = Body("faiss"), vector_store_type: str = Body("faiss"),
kb_info: str = Body("", description="知识库内容简介用于Agent选择知识库。"),
embed_model: str = Body(DEFAULT_EMBEDDING_MODEL), embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
) -> BaseResponse: ) -> BaseResponse:
# Create selected knowledge base # Create selected knowledge base
@ -26,7 +27,7 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
if kb is not None: if kb is not None:
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}") return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model) kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model, kb_info=kb_info)
try: try:
kb.create_kb() kb.create_kb()
except Exception as e: except Exception as e:

View File

@ -95,6 +95,7 @@ class KBFaissPool(_FaissPool):
embed_model: str = DEFAULT_EMBEDDING_MODEL, embed_model: str = DEFAULT_EMBEDDING_MODEL,
) -> ThreadSafeFaiss: ) -> ThreadSafeFaiss:
self.atomic.acquire() self.atomic.acquire()
locked = True
vector_name = vector_name or embed_model vector_name = vector_name or embed_model
cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些 cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
try: try:
@ -103,12 +104,14 @@ class KBFaissPool(_FaissPool):
self.set((kb_name, vector_name), item) self.set((kb_name, vector_name), item)
with item.acquire(msg="初始化"): with item.acquire(msg="初始化"):
self.atomic.release() self.atomic.release()
locked = False
logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.") logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
vs_path = get_vs_path(kb_name, vector_name) vs_path = get_vs_path(kb_name, vector_name)
if os.path.isfile(os.path.join(vs_path, "index.faiss")): if os.path.isfile(os.path.join(vs_path, "index.faiss")):
embeddings = get_Embeddings(embed_model=embed_model) embeddings = get_Embeddings(embed_model=embed_model)
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True) vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True,
allow_dangerous_deserialization=True)
elif create: elif create:
# create an empty vector store # create an empty vector store
if not os.path.exists(vs_path): if not os.path.exists(vs_path):
@ -121,8 +124,10 @@ class KBFaissPool(_FaissPool):
item.finish_loading() item.finish_loading()
else: else:
self.atomic.release() self.atomic.release()
locked = False
except Exception as e: except Exception as e:
self.atomic.release() if locked: # we don't know exception raised before or after atomic.release
self.atomic.release()
logger.error(e) logger.error(e)
raise RuntimeError(f"向量库 {kb_name} 加载失败。") raise RuntimeError(f"向量库 {kb_name} 加载失败。")
return self.get((kb_name, vector_name)) return self.get((kb_name, vector_name))

View File

@ -1,21 +1,23 @@
import json
import os import os
import urllib import urllib
from typing import List, Dict
from fastapi import File, Form, Body, Query, UploadFile from fastapi import File, Form, Body, Query, UploadFile
from fastapi.responses import FileResponse
from langchain.docstore.document import Document
from sse_starlette import EventSourceResponse
from chatchat.configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL, from chatchat.configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL,
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
logger, log_verbose, ) logger, log_verbose, )
from chatchat.server.utils import BaseResponse, ListResponse, run_in_thread_pool from chatchat.server.db.repository.knowledge_file_repository import get_file_detail
from chatchat.server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path, from chatchat.server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
files2docs_in_thread, KnowledgeFile) files2docs_in_thread, KnowledgeFile)
from fastapi.responses import FileResponse from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory, get_kb_file_details
from sse_starlette import EventSourceResponse
import json
from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory
from chatchat.server.db.repository.knowledge_file_repository import get_file_detail
from langchain.docstore.document import Document
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
from typing import List, Dict from chatchat.server.utils import BaseResponse, ListResponse, run_in_thread_pool, check_embed_model
def search_docs( def search_docs(
@ -35,7 +37,8 @@ def search_docs(
if kb is not None: if kb is not None:
if query: if query:
docs = kb.search_docs(query, top_k, score_threshold) docs = kb.search_docs(query, top_k, score_threshold)
data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs] # data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
data = [DocumentWithVSId(**x.dict(), id=x.metadata.get("id")) for x in docs]
elif file_name or metadata: elif file_name or metadata:
data = kb.list_docs(file_name=file_name, metadata=metadata) data = kb.list_docs(file_name=file_name, metadata=metadata)
for d in data: for d in data:
@ -55,8 +58,8 @@ def list_files(
if kb is None: if kb is None:
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[]) return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
else: else:
all_doc_names = kb.list_files() all_docs = get_kb_file_details(knowledge_base_name)
return ListResponse(data=all_doc_names) return ListResponse(data=all_docs)
def _save_files_in_thread(files: List[UploadFile], def _save_files_in_thread(files: List[UploadFile],
@ -352,38 +355,42 @@ def recreate_vector_store(
if not kb.exists() and not allow_empty_kb: if not kb.exists() and not allow_empty_kb:
yield {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"} yield {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"}
else: else:
if kb.exists(): error_msg = f"could not recreate vector store because failed to access embed model."
kb.clear_vs() if not kb.check_embed_model(error_msg):
kb.create_kb() yield {"code": 404, "msg": error_msg}
files = list_files_from_folder(knowledge_base_name) else:
kb_files = [(file, knowledge_base_name) for file in files] if kb.exists():
i = 0 kb.clear_vs()
for status, result in files2docs_in_thread(kb_files, kb.create_kb()
chunk_size=chunk_size, files = list_files_from_folder(knowledge_base_name)
chunk_overlap=chunk_overlap, kb_files = [(file, knowledge_base_name) for file in files]
zh_title_enhance=zh_title_enhance): i = 0
if status: for status, result in files2docs_in_thread(kb_files,
kb_name, file_name, docs = result chunk_size=chunk_size,
kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name) chunk_overlap=chunk_overlap,
kb_file.splited_docs = docs zh_title_enhance=zh_title_enhance):
yield json.dumps({ if status:
"code": 200, kb_name, file_name, docs = result
"msg": f"({i + 1} / {len(files)}): {file_name}", kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
"total": len(files), kb_file.splited_docs = docs
"finished": i + 1, yield json.dumps({
"doc": file_name, "code": 200,
}, ensure_ascii=False) "msg": f"({i + 1} / {len(files)}): {file_name}",
kb.add_doc(kb_file, not_refresh_vs_cache=True) "total": len(files),
else: "finished": i + 1,
kb_name, file_name, error = result "doc": file_name,
msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。" }, ensure_ascii=False)
logger.error(msg) kb.add_doc(kb_file, not_refresh_vs_cache=True)
yield json.dumps({ else:
"code": 500, kb_name, file_name, error = result
"msg": msg, msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
}) logger.error(msg)
i += 1 yield json.dumps({
if not not_refresh_vs_cache: "code": 500,
kb.save_vector_store() "msg": msg,
})
i += 1
if not not_refresh_vs_cache:
kb.save_vector_store()
return EventSourceResponse(output()) return EventSourceResponse(output())

View File

@ -5,6 +5,11 @@ import os
from pathlib import Path from pathlib import Path
from langchain.docstore.document import Document from langchain.docstore.document import Document
from typing import List, Union, Dict, Optional, Tuple
from chatchat.configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
DEFAULT_EMBEDDING_MODEL, KB_INFO, logger)
from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema
from chatchat.server.db.repository.knowledge_base_repository import ( from chatchat.server.db.repository.knowledge_base_repository import (
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists, add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
load_kb_from_db, get_kb_detail, load_kb_from_db, get_kb_detail,
@ -14,18 +19,12 @@ from chatchat.server.db.repository.knowledge_file_repository import (
count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db, count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
list_docs_from_db, list_docs_from_db,
) )
from chatchat.configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
DEFAULT_EMBEDDING_MODEL, KB_INFO)
from chatchat.server.knowledge_base.utils import ( from chatchat.server.knowledge_base.utils import (
get_kb_path, get_doc_path, KnowledgeFile, get_kb_path, get_doc_path, KnowledgeFile,
list_kbs_from_folder, list_files_from_folder, list_kbs_from_folder, list_files_from_folder,
) )
from typing import List, Union, Dict, Optional, Tuple
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema from chatchat.server.utils import check_embed_model as _check_embed_model
class SupportedVSType: class SupportedVSType:
FAISS = 'faiss' FAISS = 'faiss'
@ -41,10 +40,11 @@ class KBService(ABC):
def __init__(self, def __init__(self,
knowledge_base_name: str, knowledge_base_name: str,
kb_info: str = None,
embed_model: str = DEFAULT_EMBEDDING_MODEL, embed_model: str = DEFAULT_EMBEDDING_MODEL,
): ):
self.kb_name = knowledge_base_name self.kb_name = knowledge_base_name
self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库") self.kb_info = kb_info or KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
self.embed_model = embed_model self.embed_model = embed_model
self.kb_path = get_kb_path(self.kb_name) self.kb_path = get_kb_path(self.kb_name)
self.doc_path = get_doc_path(self.kb_name) self.doc_path = get_doc_path(self.kb_name)
@ -59,6 +59,13 @@ class KBService(ABC):
''' '''
pass pass
def check_embed_model(self, error_msg: str) -> bool:
if not _check_embed_model(self.embed_model):
logger.error(error_msg, exc_info=True)
return False
else:
return True
def create_kb(self): def create_kb(self):
""" """
创建知识库 创建知识库
@ -93,6 +100,9 @@ class KBService(ABC):
向知识库添加文件 向知识库添加文件
如果指定了docs则不再将文本向量化并将数据库对应条目标为custom_docs=True 如果指定了docs则不再将文本向量化并将数据库对应条目标为custom_docs=True
""" """
if not self.check_embed_model(f"could not add docs because failed to access embed model."):
return False
if docs: if docs:
custom_docs = True custom_docs = True
else: else:
@ -143,6 +153,9 @@ class KBService(ABC):
使用content中的文件更新向量库 使用content中的文件更新向量库
如果指定了docs则使用自定义docs并将数据库对应条目标为custom_docs=True 如果指定了docs则使用自定义docs并将数据库对应条目标为custom_docs=True
""" """
if not self.check_embed_model(f"could not update docs because failed to access embed model."):
return False
if os.path.exists(kb_file.filepath): if os.path.exists(kb_file.filepath):
self.delete_doc(kb_file, **kwargs) self.delete_doc(kb_file, **kwargs)
return self.add_doc(kb_file, docs=docs, **kwargs) return self.add_doc(kb_file, docs=docs, **kwargs)
@ -162,6 +175,8 @@ class KBService(ABC):
top_k: int = VECTOR_SEARCH_TOP_K, top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD, score_threshold: float = SCORE_THRESHOLD,
) ->List[Document]: ) ->List[Document]:
if not self.check_embed_model(f"could not search docs because failed to access embed model."):
return []
docs = self.do_search(query, top_k, score_threshold) docs = self.do_search(query, top_k, score_threshold)
return docs return docs
@ -176,6 +191,9 @@ class KBService(ABC):
传入参数为 {doc_id: Document, ...} 传入参数为 {doc_id: Document, ...}
如果对应 doc_id 的值为 None或其 page_content 为空则删除该文档 如果对应 doc_id 的值为 None或其 page_content 为空则删除该文档
''' '''
if not self.check_embed_model(f"could not update docs because failed to access embed model."):
return False
self.del_doc_by_ids(list(docs.keys())) self.del_doc_by_ids(list(docs.keys()))
docs = [] docs = []
ids = [] ids = []
@ -282,31 +300,32 @@ class KBServiceFactory:
def get_service(kb_name: str, def get_service(kb_name: str,
vector_store_type: Union[str, SupportedVSType], vector_store_type: Union[str, SupportedVSType],
embed_model: str = DEFAULT_EMBEDDING_MODEL, embed_model: str = DEFAULT_EMBEDDING_MODEL,
kb_info: str = None,
) -> KBService: ) -> KBService:
if isinstance(vector_store_type, str): if isinstance(vector_store_type, str):
vector_store_type = getattr(SupportedVSType, vector_store_type.upper()) vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
params = {"knowledge_base_name": kb_name, "embed_model": embed_model, "kb_info": kb_info}
if SupportedVSType.FAISS == vector_store_type: if SupportedVSType.FAISS == vector_store_type:
from chatchat.server.knowledge_base.kb_service.faiss_kb_service import FaissKBService from chatchat.server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
return FaissKBService(kb_name, embed_model=embed_model) return FaissKBService(**params)
elif SupportedVSType.PG == vector_store_type: elif SupportedVSType.PG == vector_store_type:
from chatchat.server.knowledge_base.kb_service.pg_kb_service import PGKBService from chatchat.server.knowledge_base.kb_service.pg_kb_service import PGKBService
return PGKBService(kb_name, embed_model=embed_model) return PGKBService(**params)
elif SupportedVSType.MILVUS == vector_store_type: elif SupportedVSType.MILVUS == vector_store_type:
from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
return MilvusKBService(kb_name, embed_model=embed_model) return MilvusKBService(**params)
elif SupportedVSType.ZILLIZ == vector_store_type: elif SupportedVSType.ZILLIZ == vector_store_type:
from chatchat.server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService from chatchat.server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
return ZillizKBService(kb_name, embed_model=embed_model) return ZillizKBService(**params)
elif SupportedVSType.DEFAULT == vector_store_type: elif SupportedVSType.DEFAULT == vector_store_type:
from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
return MilvusKBService(kb_name, return MilvusKBService(**params) # other milvus parameters are set in model_config.kbs_config
embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
elif SupportedVSType.ES == vector_store_type: elif SupportedVSType.ES == vector_store_type:
from chatchat.server.knowledge_base.kb_service.es_kb_service import ESKBService from chatchat.server.knowledge_base.kb_service.es_kb_service import ESKBService
return ESKBService(kb_name, embed_model=embed_model) return ESKBService(**params)
elif SupportedVSType.CHROMADB == vector_store_type: elif SupportedVSType.CHROMADB == vector_store_type:
from chatchat.server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService from chatchat.server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
return ChromaKBService(kb_name, embed_model=embed_model) return ChromaKBService(**params)
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier. elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
from chatchat.server.knowledge_base.kb_service.default_kb_service import DefaultKBService from chatchat.server.knowledge_base.kb_service.default_kb_service import DefaultKBService
return DefaultKBService(kb_name) return DefaultKBService(kb_name)

View File

@ -9,6 +9,7 @@ from chatchat.configs import SCORE_THRESHOLD
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType
from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
from chatchat.server.utils import get_Embeddings from chatchat.server.utils import get_Embeddings
from chatchat.server.file_rag.utils import get_Retriever
def _get_result_to_documents(get_result: GetResult) -> List[Document]: def _get_result_to_documents(get_result: GetResult) -> List[Document]:
@ -75,10 +76,13 @@ class ChromaKBService(KBService):
def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[ def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[
Tuple[Document, float]]: Tuple[Document, float]]:
embed_func = get_Embeddings(self.embed_model) retriever = get_Retriever("vectorstore").from_vectorstore(
embeddings = embed_func.embed_query(query) self.collection,
query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k) top_k=top_k,
return _results_to_docs_and_scores(query_result) score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
doc_infos = [] doc_infos = []

View File

@ -8,6 +8,7 @@ from chatchat.server.knowledge_base.utils import KnowledgeFile
from chatchat.server.utils import get_Embeddings from chatchat.server.utils import get_Embeddings
from elasticsearch import Elasticsearch, BadRequestError from elasticsearch import Elasticsearch, BadRequestError
from chatchat.configs import kbs_config, KB_ROOT_PATH from chatchat.configs import kbs_config, KB_ROOT_PATH
from chatchat.server.file_rag.utils import get_Retriever
import logging import logging
@ -107,8 +108,12 @@ class ESKBService(KBService):
def do_search(self, query:str, top_k: int, score_threshold: float): def do_search(self, query:str, top_k: int, score_threshold: float):
# 文本相似性检索 # 文本相似性检索
docs = self.db.similarity_search_with_score(query=query, retriever = get_Retriever("vectorstore").from_vectorstore(
k=top_k) self.db,
top_k=top_k,
score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs return docs
def get_doc_by_ids(self, ids: List[str]) -> List[Document]: def get_doc_by_ids(self, ids: List[str]) -> List[Document]:

View File

@ -5,9 +5,9 @@ from chatchat.configs import SCORE_THRESHOLD
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType
from chatchat.server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss from chatchat.server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
from chatchat.server.utils import get_Embeddings
from langchain.docstore.document import Document from langchain.docstore.document import Document
from typing import List, Dict, Optional, Tuple from typing import List, Dict, Tuple
from chatchat.server.file_rag.utils import get_Retriever
class FaissKBService(KBService): class FaissKBService(KBService):
@ -62,10 +62,13 @@ class FaissKBService(KBService):
top_k: int, top_k: int,
score_threshold: float = SCORE_THRESHOLD, score_threshold: float = SCORE_THRESHOLD,
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
embed_func = get_Embeddings(self.embed_model)
embeddings = embed_func.embed_query(query)
with self.load_vector_store().acquire() as vs: with self.load_vector_store().acquire() as vs:
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold) retriever = get_Retriever("ensemble").from_vectorstore(
vs,
top_k=top_k,
score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs return docs
def do_add_doc(self, def do_add_doc(self,

View File

@ -10,7 +10,7 @@ from chatchat.server.db.repository import list_file_num_docs_id_by_kb_name_and_f
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType, \ from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType, \
score_threshold_process score_threshold_process
from chatchat.server.knowledge_base.utils import KnowledgeFile from chatchat.server.knowledge_base.utils import KnowledgeFile
from chatchat.server.utils import get_Embeddings from chatchat.server.file_rag.utils import get_Retriever
class MilvusKBService(KBService): class MilvusKBService(KBService):
@ -67,10 +67,16 @@ class MilvusKBService(KBService):
def do_search(self, query: str, top_k: int, score_threshold: float): def do_search(self, query: str, top_k: int, score_threshold: float):
self._load_milvus() self._load_milvus()
embed_func = get_Embeddings(self.embed_model) # embed_func = get_Embeddings(self.embed_model)
embeddings = embed_func.embed_query(query) # embeddings = embed_func.embed_query(query)
docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k) # docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
return score_threshold_process(score_threshold, top_k, docs) retriever = get_Retriever("vectorstore").from_vectorstore(
self.milvus,
top_k=top_k,
score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
for doc in docs: for doc in docs:

View File

@ -15,6 +15,7 @@ import shutil
import sqlalchemy import sqlalchemy
from sqlalchemy.engine.base import Engine from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from chatchat.server.file_rag.utils import get_Retriever
class PGKBService(KBService): class PGKBService(KBService):
@ -60,10 +61,13 @@ class PGKBService(KBService):
shutil.rmtree(self.kb_path) shutil.rmtree(self.kb_path)
def do_search(self, query: str, top_k: int, score_threshold: float): def do_search(self, query: str, top_k: int, score_threshold: float):
embed_func = get_Embeddings(self.embed_model) retriever = get_Retriever("vectorstore").from_vectorstore(
embeddings = embed_func.embed_query(query) self.pg_vector,
docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k) top_k=top_k,
return score_threshold_process(score_threshold, top_k, docs) score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
ids = self.pg_vector.add_documents(docs) ids = self.pg_vector.add_documents(docs)

View File

@ -1,5 +1,4 @@
from typing import List, Dict, Optional from typing import List, Dict
from langchain.embeddings.base import Embeddings
from langchain.schema import Document from langchain.schema import Document
from langchain.vectorstores import Zilliz from langchain.vectorstores import Zilliz
from chatchat.configs import kbs_config from chatchat.configs import kbs_config
@ -7,6 +6,7 @@ from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedV
score_threshold_process score_threshold_process
from chatchat.server.knowledge_base.utils import KnowledgeFile from chatchat.server.knowledge_base.utils import KnowledgeFile
from chatchat.server.utils import get_Embeddings from chatchat.server.utils import get_Embeddings
from chatchat.server.file_rag.utils import get_Retriever
class ZillizKBService(KBService): class ZillizKBService(KBService):
@ -60,10 +60,13 @@ class ZillizKBService(KBService):
def do_search(self, query: str, top_k: int, score_threshold: float): def do_search(self, query: str, top_k: int, score_threshold: float):
self._load_zilliz() self._load_zilliz()
embed_func = get_Embeddings(self.embed_model) retriever = get_Retriever("vectorstore").from_vectorstore(
embeddings = embed_func.embed_query(query) self.zilliz,
docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k) top_k=top_k,
return score_threshold_process(score_threshold, top_k, docs) score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
for doc in docs: for doc in docs:

View File

@ -42,55 +42,59 @@ def recreate_summary_vector_store(
if not kb.exists() and not allow_empty_kb: if not kb.exists() and not allow_empty_kb:
yield {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"} yield {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"}
else: else:
# 重新创建知识库 error_msg = f"could not recreate summary vector store because failed to access embed model."
kb_summary = KBSummaryService(knowledge_base_name, embed_model) if not kb.check_embed_model(error_msg):
kb_summary.drop_kb_summary() yield {"code": 404, "msg": error_msg}
kb_summary.create_kb_summary() else:
# 重新创建知识库
kb_summary = KBSummaryService(knowledge_base_name, embed_model)
kb_summary.drop_kb_summary()
kb_summary.create_kb_summary()
llm = get_ChatOpenAI( llm = get_ChatOpenAI(
model_name=model_name, model_name=model_name,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
local_wrap=True, local_wrap=True,
) )
reduce_llm = get_ChatOpenAI( reduce_llm = get_ChatOpenAI(
model_name=model_name, model_name=model_name,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
local_wrap=True, local_wrap=True,
) )
# 文本摘要适配器 # 文本摘要适配器
summary = SummaryAdapter.form_summary(llm=llm, summary = SummaryAdapter.form_summary(llm=llm,
reduce_llm=reduce_llm, reduce_llm=reduce_llm,
overlap_size=OVERLAP_SIZE) overlap_size=OVERLAP_SIZE)
files = list_files_from_folder(knowledge_base_name) files = list_files_from_folder(knowledge_base_name)
i = 0 i = 0
for i, file_name in enumerate(files): for i, file_name in enumerate(files):
doc_infos = kb.list_docs(file_name=file_name) doc_infos = kb.list_docs(file_name=file_name)
docs = summary.summarize(file_description=file_description, docs = summary.summarize(file_description=file_description,
docs=doc_infos) docs=doc_infos)
status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs) status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
if status_kb_summary: if status_kb_summary:
logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成") logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成")
yield json.dumps({ yield json.dumps({
"code": 200, "code": 200,
"msg": f"({i + 1} / {len(files)}): {file_name}", "msg": f"({i + 1} / {len(files)}): {file_name}",
"total": len(files), "total": len(files),
"finished": i + 1, "finished": i + 1,
"doc": file_name, "doc": file_name,
}, ensure_ascii=False) }, ensure_ascii=False)
else: else:
msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。" msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
logger.error(msg) logger.error(msg)
yield json.dumps({ yield json.dumps({
"code": 500, "code": 500,
"msg": msg, "msg": msg,
}) })
i += 1 i += 1
return EventSourceResponse(output()) return EventSourceResponse(output())

View File

@ -1,6 +1,6 @@
from chatchat.configs import ( from chatchat.configs import (
DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE, DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
CHUNK_SIZE, OVERLAP_SIZE CHUNK_SIZE, OVERLAP_SIZE, logger, log_verbose
) )
from chatchat.server.knowledge_base.utils import ( from chatchat.server.knowledge_base.utils import (
get_file_path, list_kbs_from_folder, get_file_path, list_kbs_from_folder,

View File

@ -10,7 +10,7 @@ from chatchat.configs import (
TEXT_SPLITTER_NAME, TEXT_SPLITTER_NAME,
) )
import importlib import importlib
from chatchat.server.text_splitter import zh_title_enhance as func_zh_title_enhance from chatchat.server.file_rag.text_splitter import zh_title_enhance as func_zh_title_enhance
import langchain_community.document_loaders import langchain_community.document_loaders
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter

View File

@ -244,6 +244,16 @@ def get_Embeddings(
logger.error(f"failed to create Embeddings for model: {embed_model}.", exc_info=True) logger.error(f"failed to create Embeddings for model: {embed_model}.", exc_info=True)
def check_embed_model(embed_model: str=DEFAULT_EMBEDDING_MODEL) -> bool:
embeddings = get_Embeddings(embed_model=embed_model)
try:
embeddings.embed_query("this is a test")
return True
except Exception as e:
logger.error(f"failed to access embed model '{embed_model}': {e}", exc_info=True)
return False
def get_OpenAIClient( def get_OpenAIClient(
platform_name: str = None, platform_name: str = None,
model_name: str = None, model_name: str = None,
@ -696,7 +706,7 @@ def get_temp_dir(id: str = None) -> Tuple[str, str]:
''' '''
创建一个临时目录返回路径文件夹名称 创建一个临时目录返回路径文件夹名称
''' '''
from chatchat.configs.basic_config import BASE_TEMP_DIR from chatchat.configs import BASE_TEMP_DIR
import uuid import uuid
if id is not None: # 如果指定的临时目录已存在,直接返回 if id is not None: # 如果指定的临时目录已存在,直接返回

View File

@ -24,28 +24,28 @@ chat_box = ChatBox(
def save_session(): def save_session():
'''save session state to chat context''' """save session state to chat context"""
chat_box.context_from_session(exclude=["selected_page", "prompt", "cur_conv_name"]) chat_box.context_from_session(exclude=["selected_page", "prompt", "cur_conv_name"])
def restore_session(): def restore_session():
'''restore sesstion state from chat context''' """restore sesstion state from chat context"""
chat_box.context_to_session(exclude=["selected_page", "prompt", "cur_conv_name"]) chat_box.context_to_session(exclude=["selected_page", "prompt", "cur_conv_name"])
def rerun(): def rerun():
''' """
save chat context before rerun save chat context before rerun
''' """
save_session() save_session()
st.rerun() st.rerun()
def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]: def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
''' """
返回消息历史 返回消息历史
content_in_expander控制是否返回expander元素中的内容一般导出的时候可以选上传入LLM的history不需要 content_in_expander控制是否返回expander元素中的内容一般导出的时候可以选上传入LLM的history不需要
''' """
def filter(msg): def filter(msg):
content = [x for x in msg["elements"] if x._output_method in ["markdown", "text"]] content = [x for x in msg["elements"] if x._output_method in ["markdown", "text"]]
@ -66,10 +66,10 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) ->
@st.cache_data @st.cache_data
def upload_temp_docs(files, _api: ApiRequest) -> str: def upload_temp_docs(files, _api: ApiRequest) -> str:
''' """
将文件上传到临时目录用于文件对话 将文件上传到临时目录用于文件对话
返回临时向量库ID 返回临时向量库ID
''' """
return _api.upload_temp_docs(files).get("data", {}).get("id") return _api.upload_temp_docs(files).get("data", {}).get("id")
@ -157,11 +157,13 @@ def dialogue_page(
tools = list_tools(api) tools = list_tools(api)
tool_names = ["None"] + list(tools) tool_names = ["None"] + list(tools)
if use_agent: if use_agent:
# selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", check_all=True, key="selected_tools") # selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具",
# check_all=True, key="selected_tools")
selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"], selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"],
key="selected_tools") key="selected_tools")
else: else:
# selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", key="selected_tool") # selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具",
# key="selected_tool")
selected_tool = st.selectbox("选择工具", tool_names, selected_tool = st.selectbox("选择工具", tool_names,
format_func=lambda x: tools.get(x, {"title": "None"})["title"], format_func=lambda x: tools.get(x, {"title": "None"})["title"],
key="selected_tool") key="selected_tool")
@ -338,7 +340,7 @@ def dialogue_page(
elif d.status == AgentStatus.agent_finish: elif d.status == AgentStatus.agent_finish:
text = d.choices[0].delta.content or "" text = d.choices[0].delta.content or ""
chat_box.update_msg(text.replace("\n", "\n\n")) chat_box.update_msg(text.replace("\n", "\n\n"))
elif d.status == None: # not agent chat elif d.status is None: # not agent chat
if getattr(d, "is_ref", False): if getattr(d, "is_ref", False):
chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete", chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete",
title="参考资料")) title="参考资料"))

View File

@ -10,16 +10,18 @@ packages = [
[tool.poetry.scripts] [tool.poetry.scripts]
chatchat = 'chatchat.startup:main' chatchat = 'chatchat.startup:main'
chatchat-kb = 'chatchat.init_database:main'
chatchat-config = 'chatchat.config_work_space:main'
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.8.1,<3.12,!=3.9.7" python = ">=3.8.1,<3.12,!=3.9.7"
model-providers = "^0.3.0" model-providers = "^0.3.0"
langchain = "0.1.5" langchain = "0.1.17"
langchainhub = "0.1.14" langchainhub = "0.1.14"
langchain-community = "0.0.17" langchain-community = "0.0.36"
langchain-openai = "0.0.5" langchain-openai = "0.0.5"
langchain-experimental = "0.0.50" langchain-experimental = "0.0.58"
fastapi = "0.109.2" fastapi = "~0.109.2"
sse_starlette = "~1.8.2" sse_starlette = "~1.8.2"
nltk = "~3.8.1" nltk = "~3.8.1"
uvicorn = ">=0.27.0.post1" uvicorn = ">=0.27.0.post1"
@ -27,6 +29,8 @@ unstructured = "~0.11.0"
python-magic-bin = {version = "*", platform = "win32"} python-magic-bin = {version = "*", platform = "win32"}
SQLAlchemy = "~2.0.25" SQLAlchemy = "~2.0.25"
faiss-cpu = "~1.7.4" faiss-cpu = "~1.7.4"
cutword = "0.1.0"
rank_bm25 = "0.2.2"
# accelerate = "~0.24.1" # accelerate = "~0.24.1"
# spacy = "~3.7.2" # spacy = "~3.7.2"
PyMuPDF = "~1.23.16" PyMuPDF = "~1.23.16"
@ -34,28 +38,29 @@ rapidocr_onnxruntime = "~1.3.8"
requests = "~2.31.0" requests = "~2.31.0"
pathlib = "~1.0.1" pathlib = "~1.0.1"
pytest = "~7.4.3" pytest = "~7.4.3"
pyjwt = "2.8.0" pyjwt = "~2.8.0"
elasticsearch = "*" elasticsearch = "*"
numexpr = ">=1" #test numexpr = ">=1" #test
strsimpy = ">=0.2.1" strsimpy = ">=0.2.1"
markdownify = ">=0.11.6" markdownify = ">=0.11.6"
tqdm = ">=4.66.1" tqdm = ">=4.66.1"
websockets = ">=12.0" websockets = ">=12.0"
numpy = "1.24.4" numpy = "~1.24.4"
pandas = "~1" # test pandas = "~1" # test
pydantic = "2.6.4" pydantic = "~2.6.4"
httpx = {version = ">=0.25.2", extras = ["brotli", "http2", "socks"]} httpx = {version = ">=0.25.2", extras = ["brotli", "http2", "socks"]}
python-multipart = "0.0.9" python-multipart = "0.0.9"
# webui # webui
streamlit = "1.34.0" streamlit = "1.34.0"
streamlit-option-menu = "0.3.12" streamlit-option-menu = "0.3.12"
streamlit-antd-components = "0.3.1" streamlit-antd-components = "0.3.1"
streamlit-chatbox = "1.1.12" streamlit-chatbox = "1.1.12.post2"
streamlit-modal = "0.1.0" streamlit-modal = "0.1.0"
streamlit-aggrid = "0.3.4.post3" streamlit-aggrid = "0.3.4.post3"
streamlit-extras = "0.4.2" streamlit-extras = "0.4.2"
xinference_client = { version = "^0.11.1", optional = true } xinference_client = { version = "^0.11.1", optional = true }
zhipuai = { version = "^2.1.0", optional = true } zhipuai = { version = "^2.1.0", optional = true }
pymysql = "^1.1.0"
[tool.poetry.extras] [tool.poetry.extras]
xinference = ["xinference_client"] xinference = ["xinference_client"]
@ -254,7 +259,7 @@ ignore-words-list = 'momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy
[tool.poetry.plugins.dotenv] [tool.poetry.plugins.dotenv]
ignore = "false" ignore = "false"
location = ".env" dotenv = "dotenv:plugin"
# https://python-poetry.org/docs/repositories/ # https://python-poetry.org/docs/repositories/

View File

@ -0,0 +1,90 @@
"""Configuration for unit tests."""
import logging
from importlib import util
from typing import Dict, List, Sequence
import pytest
from pytest import Config, Function, Parser
def pytest_addoption(parser: Parser) -> None:
"""Add custom command line options to pytest."""
parser.addoption(
"--only-extended",
action="store_true",
help="Only run extended tests. Does not allow skipping any extended tests.",
)
parser.addoption(
"--only-core",
action="store_true",
help="Only run core tests. Never runs any extended tests.",
)
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
"""Add implementations for handling custom markers.
At the moment, this adds support for a custom `requires` marker.
The `requires` marker is used to denote tests that require one or more packages
to be installed to run. If the package is not installed, the test is skipped.
The `requires` marker syntax is:
.. code-block:: python
@pytest.mark.requires("package1", "package2")
def test_something():
...
"""
# Mapping from the name of a package to whether it is installed or not.
# Used to avoid repeated calls to `util.find_spec`
required_pkgs_info: Dict[str, bool] = {}
only_extended = config.getoption("--only-extended") or False
only_core = config.getoption("--only-core") or False
if only_extended and only_core:
raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
for item in items:
requires_marker = item.get_closest_marker("requires")
if requires_marker is not None:
if only_core:
item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
continue
# Iterate through the list of required packages
required_pkgs = requires_marker.args
for pkg in required_pkgs:
# If we haven't yet checked whether the pkg is installed
# let's check it and store the result.
if pkg not in required_pkgs_info:
try:
installed = util.find_spec(pkg) is not None
except Exception:
installed = False
required_pkgs_info[pkg] = installed
if not required_pkgs_info[pkg]:
if only_extended:
pytest.fail(
f"Package `{pkg}` is not installed but is required for "
f"extended tests. Please install the given package and "
f"try again.",
)
else:
# If the package is not installed, we immediately break
# and mark the test as skipped.
item.add_marker(
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
)
break
else:
if only_extended:
item.add_marker(
pytest.mark.skip(reason="Skipping not an extended test.")
)

View File

@ -0,0 +1,56 @@
from pathlib import Path
from chatchat.configs import ConfigBasicFactory, ConfigBasic, ConfigWorkSpace
import os
def test_config_factory_def():
test_config_factory = ConfigBasicFactory()
config: ConfigBasic = test_config_factory.get_config()
assert config is not None
assert config.log_verbose is False
assert config.CHATCHAT_ROOT is not None
assert config.DATA_PATH is not None
assert config.IMG_DIR is not None
assert config.NLTK_DATA_PATH is not None
assert config.LOG_FORMAT is not None
assert config.LOG_PATH is not None
assert config.MEDIA_PATH is not None
assert os.path.exists(os.path.join(config.MEDIA_PATH, "image"))
assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio"))
assert os.path.exists(os.path.join(config.MEDIA_PATH, "video"))
def test_workspace():
config_workspace = ConfigWorkSpace()
assert config_workspace.get_config() is not None
base_root = os.path.join(Path(__file__).absolute().parent, "chatchat")
config_workspace.set_data_path(os.path.join(base_root, "data"))
config_workspace.set_log_verbose(True)
config_workspace.set_log_format(" %(message)s")
config: ConfigBasic = config_workspace.get_config()
assert config.log_verbose is True
assert config.DATA_PATH == os.path.join(base_root, "data")
assert config.IMG_DIR is not None
assert config.NLTK_DATA_PATH == os.path.join(base_root, "data", "nltk_data")
assert config.LOG_FORMAT == " %(message)s"
assert config.LOG_PATH == os.path.join(base_root, "data", "logs")
assert config.MEDIA_PATH == os.path.join(base_root, "data", "media")
assert os.path.exists(os.path.join(config.MEDIA_PATH, "image"))
assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio"))
assert os.path.exists(os.path.join(config.MEDIA_PATH, "video"))
config_workspace.clear()
def test_workspace_default():
from chatchat.configs import (log_verbose, DATA_PATH, IMG_DIR, NLTK_DATA_PATH, LOG_FORMAT, LOG_PATH, MEDIA_PATH)
assert log_verbose is False
assert DATA_PATH is not None
assert IMG_DIR is not None
assert NLTK_DATA_PATH is not None
assert LOG_FORMAT is not None
assert LOG_PATH is not None
assert MEDIA_PATH is not None

View File

@ -0,0 +1 @@
PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring

View File

@ -1,26 +1,26 @@
import redis # import redis
from redis.connection import Connection, SSLConnection # from redis.connection import Connection, SSLConnection
#
redis_client = redis.Redis() # redis_client = redis.Redis()
#
#
def init_app(app): # def init_app(app):
connection_class = Connection # connection_class = Connection
if app.config.get("REDIS_USE_SSL", False): # if app.config.get("REDIS_USE_SSL", False):
connection_class = SSLConnection # connection_class = SSLConnection
#
redis_client.connection_pool = redis.ConnectionPool( # redis_client.connection_pool = redis.ConnectionPool(
**{ # **{
"host": app.config.get("REDIS_HOST", "localhost"), # "host": app.config.get("REDIS_HOST", "localhost"),
"port": app.config.get("REDIS_PORT", 6379), # "port": app.config.get("REDIS_PORT", 6379),
"username": app.config.get("REDIS_USERNAME", None), # "username": app.config.get("REDIS_USERNAME", None),
"password": app.config.get("REDIS_PASSWORD", None), # "password": app.config.get("REDIS_PASSWORD", None),
"db": app.config.get("REDIS_DB", 0), # "db": app.config.get("REDIS_DB", 0),
"encoding": "utf-8", # "encoding": "utf-8",
"encoding_errors": "strict", # "encoding_errors": "strict",
"decode_responses": False, # "decode_responses": False,
}, # },
connection_class=connection_class, # connection_class=connection_class,
) # )
#
app.extensions["redis"] = redis_client # app.extensions["redis"] = redis_client

View File

@ -7,20 +7,19 @@ readme = "README.md"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.8.1,<3.12,!=3.9.7" python = ">=3.8.1,<3.12,!=3.9.7"
transformers = "4.31.0" transformers = "~4.31.0"
fastapi = "^0.109.2" fastapi = "~0.109.2"
uvicorn = ">=0.27.0.post1" uvicorn = ">=0.27.0.post1"
sse-starlette = "^1.8.2" sse-starlette = "~1.8.2"
pyyaml = "6.0.1" pyyaml = "~6.0.1"
pydantic = "2.6.4" pydantic ="~2.6.4"
redis = "4.5.4"
# config manage # config manage
omegaconf = "2.0.6" omegaconf = "~2.0.6"
# modle_runtime # modle_runtime
openai = "1.13.3" openai = "~1.13.3"
tiktoken = "0.5.2" tiktoken = "~0.5.2"
pydub = "0.25.1" pydub = "~0.25.1"
boto3 = "1.28.17" boto3 = "~1.28.17"
[tool.poetry.group.test.dependencies] [tool.poetry.group.test.dependencies]
# The only dependencies that should be added are # The only dependencies that should be added are
@ -207,7 +206,7 @@ ignore-words-list = 'momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy
[tool.poetry.plugins.dotenv] [tool.poetry.plugins.dotenv]
ignore = "false" ignore = "false"
location = ".env" dotenv = "dotenv:plugin"
# https://python-poetry.org/docs/repositories/ # https://python-poetry.org/docs/repositories/
[[tool.poetry.source]] [[tool.poetry.source]]

View File

@ -25,7 +25,19 @@ langchain-chatchat = { path = "libs/chatchat-server", develop = true }
ipykernel = "^6.29.2" ipykernel = "^6.29.2"
[tool.poetry.group.test.dependencies] [tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
pytest-cov = "^4.0.0"
pytest-dotenv = "^0.5.2"
duckdb-engine = "^0.9.2"
pytest-watcher = "^0.2.6"
freezegun = "^1.2.2"
responses = "^0.22.0"
pytest-asyncio = "^0.23.2"
lark = "^1.1.5"
pytest-mock = "^3.10.0"
pytest-socket = "^0.6.0"
syrupy = "^4.0.2"
requests-mock = "^1.11.0"
[tool.ruff] [tool.ruff]
extend-include = ["*.ipynb"] extend-include = ["*.ipynb"]
@ -48,7 +60,7 @@ extend-exclude = [
[tool.poetry.plugins.dotenv] [tool.poetry.plugins.dotenv]
ignore = "false" ignore = "false"
location = ".env" dotenv = "dotenv:plugin"
# https://python-poetry.org/docs/repositories/ # https://python-poetry.org/docs/repositories/

View File

@ -133,6 +133,8 @@ model_names = list(regs[model_type].keys())
model_name = cols[1].selectbox("模型名称:", model_names) model_name = cols[1].selectbox("模型名称:", model_names)
cur_reg = regs[model_type][model_name]["reg"] cur_reg = regs[model_type][model_name]["reg"]
model_format = None
model_quant = None
if model_type == "LLM": if model_type == "LLM":
cur_family = xf_llm.LLMFamilyV1.parse_obj(cur_reg) cur_family = xf_llm.LLMFamilyV1.parse_obj(cur_reg)
@ -217,7 +219,7 @@ cols = st.columns(3)
if cols[0].button("设置模型缓存"): if cols[0].button("设置模型缓存"):
if os.path.isabs(local_path) and os.path.isdir(local_path): if os.path.isabs(local_path) and os.path.isdir(local_path):
cur_spec.model_uri = local_path cur_spec.__dict__["model_uri"] = local_path # embedding spec has no attribute model_uri
if os.path.isdir(cache_dir): if os.path.isdir(cache_dir):
os.rmdir(cache_dir) os.rmdir(cache_dir)
if model_type == "LLM": if model_type == "LLM":
@ -250,10 +252,10 @@ if cols[2].button("注册为自定义模型"):
if model_type == "LLM": if model_type == "LLM":
cur_family.model_name = f"{cur_family.model_name}{model_name_suffix}" cur_family.model_name = f"{cur_family.model_name}{model_name_suffix}"
cur_family.model_family = "other" cur_family.model_family = "other"
model_definition = cur_family.json(indent=2, ensure_ascii=False) model_definition = cur_family.model_dump_json(indent=2, ensure_ascii=False)
else: else:
cur_spec.model_name = f"{cur_spec.model_name}{model_name_suffix}" cur_spec.model_name = f"{cur_spec.model_name}{model_name_suffix}"
model_definition = cur_spec.json(indent=2, ensure_ascii=False) model_definition = cur_spec.model_dump_json(indent=2, ensure_ascii=False)
client.register_model( client.register_model(
model_type=model_type, model_type=model_type,
model=model_definition, model=model_definition,
@ -262,4 +264,3 @@ if cols[2].button("注册为自定义模型"):
st.rerun() st.rerun()
else: else:
st.error("必须输入存在的绝对路径") st.error("必须输入存在的绝对路径")