mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-07 07:23:29 +08:00
Merge branch 'chatchat-space:dev' into dev
This commit is contained in:
commit
d189bd182f
@ -15,9 +15,11 @@
|
|||||||
|
|
||||||
Install Poetry: [documentation on how to install it.](https://python-poetry.org/docs/#installing-with-pipx)
|
Install Poetry: [documentation on how to install it.](https://python-poetry.org/docs/#installing-with-pipx)
|
||||||
|
|
||||||
|
> 友情提示 不想安装pipx可以用pip安装poetry,(Tips:如果你没有其它poetry的项目
|
||||||
> 注意: 如果您使用 Conda 或 Pyenv 作为您的环境/包管理器,在安装Poetry之后,
|
> 注意: 如果您使用 Conda 或 Pyenv 作为您的环境/包管理器,在安装Poetry之后,
|
||||||
> 使用如下命令使 Poetry 使用 virtualenv python environment (`poetry config virtualenvs.prefer-active-python true`)
|
> 使用如下命令使 Poetry 使用 virtualenv python environment (`poetry config virtualenvs.prefer-active-python true`)
|
||||||
|
|
||||||
|
|
||||||
#### 本地开发环境安装
|
#### 本地开发环境安装
|
||||||
|
|
||||||
- 选择主项目目录
|
- 选择主项目目录
|
||||||
|
|||||||
1
libs/chatchat-server/.env
Normal file
1
libs/chatchat-server/.env
Normal file
@ -0,0 +1 @@
|
|||||||
|
PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring
|
||||||
@ -29,6 +29,11 @@ vim model_providers.yaml
|
|||||||
>
|
>
|
||||||
> 详细配置请参考[README.md](../model-providers/README.md)
|
> 详细配置请参考[README.md](../model-providers/README.md)
|
||||||
|
|
||||||
|
- 初始化知识库
|
||||||
|
```shell
|
||||||
|
chatchat-kb -r
|
||||||
|
```
|
||||||
|
|
||||||
- 启动服务
|
- 启动服务
|
||||||
```shell
|
```shell
|
||||||
chatchat -a
|
chatchat -a
|
||||||
|
|||||||
47
libs/chatchat-server/chatchat/config_work_space.py
Normal file
47
libs/chatchat-server/chatchat/config_work_space.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
from chatchat.configs import config_workspace as workspace
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="指令` chatchat-config` 工作空间配置")
|
||||||
|
# 只能选择true或false
|
||||||
|
parser.add_argument(
|
||||||
|
"-v",
|
||||||
|
"--verbose",
|
||||||
|
choices=["true", "false"],
|
||||||
|
help="是否开启详细日志"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-d",
|
||||||
|
"--data",
|
||||||
|
help="数据存放路径"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-f",
|
||||||
|
"--format",
|
||||||
|
help="日志格式"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--clear",
|
||||||
|
action="store_true",
|
||||||
|
help="清除配置"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.verbose:
|
||||||
|
if args.verbose.lower() == "true":
|
||||||
|
workspace.set_log_verbose(True)
|
||||||
|
else:
|
||||||
|
workspace.set_log_verbose(False)
|
||||||
|
if args.data:
|
||||||
|
workspace.set_data_path(args.data)
|
||||||
|
if args.format:
|
||||||
|
workspace.set_log_format(args.format)
|
||||||
|
if args.clear:
|
||||||
|
workspace.clear()
|
||||||
|
print(workspace.get_config())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -1,8 +1,10 @@
|
|||||||
import importlib
|
import importlib
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
@ -38,7 +40,6 @@ def _import_config_mod_load(import_config_mod: str) -> Dict:
|
|||||||
)
|
)
|
||||||
user_import = False
|
user_import = False
|
||||||
if user_import:
|
if user_import:
|
||||||
|
|
||||||
# Dynamic loading {config}.py file
|
# Dynamic loading {config}.py file
|
||||||
py_path = os.path.join(user_config_path, import_config_mod + ".py")
|
py_path = os.path.join(user_config_path, import_config_mod + ".py")
|
||||||
spec = importlib.util.spec_from_file_location(
|
spec = importlib.util.spec_from_file_location(
|
||||||
@ -69,7 +70,7 @@ def _import_config_mod_load(import_config_mod: str) -> Dict:
|
|||||||
)
|
)
|
||||||
raise RuntimeError(f"Failed to load user config from {user_config_path}")
|
raise RuntimeError(f"Failed to load user config from {user_config_path}")
|
||||||
# 当前文件路径
|
# 当前文件路径
|
||||||
py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py")
|
py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py")
|
||||||
|
|
||||||
spec = importlib.util.spec_from_file_location(f"*",
|
spec = importlib.util.spec_from_file_location(f"*",
|
||||||
py_path)
|
py_path)
|
||||||
@ -95,75 +96,108 @@ CONFIG_IMPORTS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _import_ConfigBasic() -> Any:
|
||||||
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
|
load_mod = basic_config_load.get("load_mod")
|
||||||
|
ConfigBasic = load_mod(basic_config_load.get("module"), "ConfigBasic")
|
||||||
|
|
||||||
|
return ConfigBasic
|
||||||
|
|
||||||
|
|
||||||
|
def _import_ConfigBasicFactory() -> Any:
|
||||||
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
|
load_mod = basic_config_load.get("load_mod")
|
||||||
|
ConfigBasicFactory = load_mod(basic_config_load.get("module"), "ConfigBasicFactory")
|
||||||
|
|
||||||
|
return ConfigBasicFactory
|
||||||
|
|
||||||
|
|
||||||
|
def _import_ConfigWorkSpace() -> Any:
|
||||||
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
|
load_mod = basic_config_load.get("load_mod")
|
||||||
|
ConfigWorkSpace = load_mod(basic_config_load.get("module"), "ConfigWorkSpace")
|
||||||
|
|
||||||
|
return ConfigWorkSpace
|
||||||
|
|
||||||
|
|
||||||
|
def _import_config_workspace() -> Any:
|
||||||
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
|
load_mod = basic_config_load.get("load_mod")
|
||||||
|
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||||
|
return config_workspace
|
||||||
|
|
||||||
def _import_log_verbose() -> Any:
|
def _import_log_verbose() -> Any:
|
||||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
load_mod = basic_config_load.get("load_mod")
|
load_mod = basic_config_load.get("load_mod")
|
||||||
log_verbose = load_mod(basic_config_load.get("module"), "log_verbose")
|
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||||
|
return config_workspace.get_config().log_verbose
|
||||||
return log_verbose
|
|
||||||
|
|
||||||
|
|
||||||
def _import_chatchat_root() -> Any:
|
def _import_chatchat_root() -> Any:
|
||||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
load_mod = basic_config_load.get("load_mod")
|
load_mod = basic_config_load.get("load_mod")
|
||||||
CHATCHAT_ROOT = load_mod(basic_config_load.get("module"), "CHATCHAT_ROOT")
|
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||||
|
|
||||||
return CHATCHAT_ROOT
|
return config_workspace.get_config().CHATCHAT_ROOT
|
||||||
|
|
||||||
|
|
||||||
def _import_data_path() -> Any:
|
def _import_data_path() -> Any:
|
||||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
load_mod = basic_config_load.get("load_mod")
|
load_mod = basic_config_load.get("load_mod")
|
||||||
DATA_PATH = load_mod(basic_config_load.get("module"), "DATA_PATH")
|
|
||||||
|
|
||||||
return DATA_PATH
|
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||||
|
return config_workspace.get_config().DATA_PATH
|
||||||
|
|
||||||
|
|
||||||
def _import_img_dir() -> Any:
|
def _import_img_dir() -> Any:
|
||||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
load_mod = basic_config_load.get("load_mod")
|
load_mod = basic_config_load.get("load_mod")
|
||||||
IMG_DIR = load_mod(basic_config_load.get("module"), "IMG_DIR")
|
|
||||||
|
|
||||||
return IMG_DIR
|
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||||
|
|
||||||
|
return config_workspace.get_config().IMG_DIR
|
||||||
|
|
||||||
|
|
||||||
def _import_nltk_data_path() -> Any:
|
def _import_nltk_data_path() -> Any:
|
||||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
load_mod = basic_config_load.get("load_mod")
|
load_mod = basic_config_load.get("load_mod")
|
||||||
NLTK_DATA_PATH = load_mod(basic_config_load.get("module"), "NLTK_DATA_PATH")
|
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||||
|
|
||||||
return NLTK_DATA_PATH
|
return config_workspace.get_config().NLTK_DATA_PATH
|
||||||
|
|
||||||
|
|
||||||
def _import_log_format() -> Any:
|
def _import_log_format() -> Any:
|
||||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
load_mod = basic_config_load.get("load_mod")
|
load_mod = basic_config_load.get("load_mod")
|
||||||
LOG_FORMAT = load_mod(basic_config_load.get("module"), "LOG_FORMAT")
|
|
||||||
|
|
||||||
return LOG_FORMAT
|
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||||
|
|
||||||
|
return config_workspace.get_config().LOG_FORMAT
|
||||||
|
|
||||||
|
|
||||||
def _import_log_path() -> Any:
|
def _import_log_path() -> Any:
|
||||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
load_mod = basic_config_load.get("load_mod")
|
load_mod = basic_config_load.get("load_mod")
|
||||||
LOG_PATH = load_mod(basic_config_load.get("module"), "LOG_PATH")
|
|
||||||
|
|
||||||
return LOG_PATH
|
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||||
|
|
||||||
|
return config_workspace.get_config().LOG_PATH
|
||||||
|
|
||||||
|
|
||||||
def _import_media_path() -> Any:
|
def _import_media_path() -> Any:
|
||||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
load_mod = basic_config_load.get("load_mod")
|
load_mod = basic_config_load.get("load_mod")
|
||||||
MEDIA_PATH = load_mod(basic_config_load.get("module"), "MEDIA_PATH")
|
|
||||||
|
|
||||||
return MEDIA_PATH
|
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||||
|
|
||||||
|
return config_workspace.get_config().MEDIA_PATH
|
||||||
|
|
||||||
|
|
||||||
def _import_base_temp_dir() -> Any:
|
def _import_base_temp_dir() -> Any:
|
||||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||||
load_mod = basic_config_load.get("load_mod")
|
load_mod = basic_config_load.get("load_mod")
|
||||||
BASE_TEMP_DIR = load_mod(basic_config_load.get("module"), "BASE_TEMP_DIR")
|
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||||
|
return config_workspace.get_config().BASE_TEMP_DIR
|
||||||
return BASE_TEMP_DIR
|
|
||||||
|
|
||||||
|
|
||||||
def _import_default_knowledge_base() -> Any:
|
def _import_default_knowledge_base() -> Any:
|
||||||
@ -285,6 +319,7 @@ def _import_db_root_path() -> Any:
|
|||||||
|
|
||||||
return DB_ROOT_PATH
|
return DB_ROOT_PATH
|
||||||
|
|
||||||
|
|
||||||
def _import_sqlalchemy_database_uri() -> Any:
|
def _import_sqlalchemy_database_uri() -> Any:
|
||||||
kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
|
kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
|
||||||
load_mod = kb_config_load.get("load_mod")
|
load_mod = kb_config_load.get("load_mod")
|
||||||
@ -478,7 +513,15 @@ def _import_api_server() -> Any:
|
|||||||
|
|
||||||
|
|
||||||
def __getattr__(name: str) -> Any:
|
def __getattr__(name: str) -> Any:
|
||||||
if name == "log_verbose":
|
if name == "ConfigBasic":
|
||||||
|
return _import_ConfigBasic()
|
||||||
|
elif name == "ConfigBasicFactory":
|
||||||
|
return _import_ConfigBasicFactory()
|
||||||
|
elif name == "ConfigWorkSpace":
|
||||||
|
return _import_ConfigWorkSpace()
|
||||||
|
elif name == "config_workspace":
|
||||||
|
return _import_config_workspace()
|
||||||
|
elif name == "log_verbose":
|
||||||
return _import_log_verbose()
|
return _import_log_verbose()
|
||||||
elif name == "CHATCHAT_ROOT":
|
elif name == "CHATCHAT_ROOT":
|
||||||
return _import_chatchat_root()
|
return _import_chatchat_root()
|
||||||
@ -578,6 +621,7 @@ VERSION = "v0.3.0-preview"
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"VERSION",
|
"VERSION",
|
||||||
|
"config_workspace",
|
||||||
"log_verbose",
|
"log_verbose",
|
||||||
"CHATCHAT_ROOT",
|
"CHATCHAT_ROOT",
|
||||||
"DATA_PATH",
|
"DATA_PATH",
|
||||||
@ -626,5 +670,8 @@ __all__ = [
|
|||||||
"WEBUI_SERVER",
|
"WEBUI_SERVER",
|
||||||
"API_SERVER",
|
"API_SERVER",
|
||||||
|
|
||||||
|
"ConfigBasic",
|
||||||
|
"ConfigBasicFactory",
|
||||||
|
"ConfigWorkSpace",
|
||||||
|
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,55 +1,180 @@
|
|||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import logging
|
||||||
|
|
||||||
import langchain
|
|
||||||
|
|
||||||
|
|
||||||
# 是否显示详细日志
|
|
||||||
log_verbose = False
|
|
||||||
langchain.verbose = False
|
|
||||||
|
|
||||||
# 通常情况下不需要更改以下内容
|
|
||||||
|
|
||||||
# chatchat 项目根目录
|
|
||||||
CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent)
|
|
||||||
|
|
||||||
# 用户数据根目录
|
|
||||||
DATA_PATH = os.path.join(CHATCHAT_ROOT, "data")
|
|
||||||
if not os.path.exists(DATA_PATH):
|
|
||||||
os.mkdir(DATA_PATH)
|
|
||||||
|
|
||||||
# 项目相关图片
|
|
||||||
IMG_DIR = os.path.join(CHATCHAT_ROOT, "img")
|
|
||||||
if not os.path.exists(IMG_DIR):
|
|
||||||
os.mkdir(IMG_DIR)
|
|
||||||
|
|
||||||
# nltk 模型存储路径
|
|
||||||
NLTK_DATA_PATH = os.path.join(DATA_PATH, "nltk_data")
|
|
||||||
import nltk
|
|
||||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
|
||||||
|
|
||||||
# 日志格式
|
|
||||||
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
logging.basicConfig(format=LOG_FORMAT)
|
|
||||||
|
|
||||||
|
|
||||||
# 日志存储路径
|
class ConfigBasic:
|
||||||
LOG_PATH = os.path.join(DATA_PATH, "logs")
|
log_verbose: bool
|
||||||
if not os.path.exists(LOG_PATH):
|
"""是否开启日志详细信息"""
|
||||||
os.mkdir(LOG_PATH)
|
CHATCHAT_ROOT: str
|
||||||
|
"""项目根目录"""
|
||||||
|
DATA_PATH: str
|
||||||
|
"""用户数据根目录"""
|
||||||
|
IMG_DIR: str
|
||||||
|
"""项目相关图片"""
|
||||||
|
NLTK_DATA_PATH: str
|
||||||
|
"""nltk 模型存储路径"""
|
||||||
|
LOG_FORMAT: str
|
||||||
|
"""日志格式"""
|
||||||
|
LOG_PATH: str
|
||||||
|
"""日志存储路径"""
|
||||||
|
MEDIA_PATH: str
|
||||||
|
"""模型生成内容(图片、视频、音频等)保存位置"""
|
||||||
|
BASE_TEMP_DIR: str
|
||||||
|
"""临时文件目录,主要用于文件对话"""
|
||||||
|
|
||||||
# 模型生成内容(图片、视频、音频等)保存位置
|
def __str__(self):
|
||||||
MEDIA_PATH = os.path.join(DATA_PATH, "media")
|
return f"ConfigBasic(log_verbose={self.log_verbose}, CHATCHAT_ROOT={self.CHATCHAT_ROOT}, DATA_PATH={self.DATA_PATH}, IMG_DIR={self.IMG_DIR}, NLTK_DATA_PATH={self.NLTK_DATA_PATH}, LOG_FORMAT={self.LOG_FORMAT}, LOG_PATH={self.LOG_PATH}, MEDIA_PATH={self.MEDIA_PATH}, BASE_TEMP_DIR={self.BASE_TEMP_DIR})"
|
||||||
if not os.path.exists(MEDIA_PATH):
|
|
||||||
os.mkdir(MEDIA_PATH)
|
|
||||||
os.mkdir(os.path.join(MEDIA_PATH, "image"))
|
|
||||||
os.mkdir(os.path.join(MEDIA_PATH, "audio"))
|
|
||||||
os.mkdir(os.path.join(MEDIA_PATH, "video"))
|
|
||||||
|
|
||||||
# 临时文件目录,主要用于文件对话
|
|
||||||
BASE_TEMP_DIR = os.path.join(DATA_PATH, "temp")
|
class ConfigBasicFactory:
|
||||||
if not os.path.exists(BASE_TEMP_DIR):
|
"""Basic config for ChatChat """
|
||||||
os.mkdir(BASE_TEMP_DIR)
|
|
||||||
|
def __init__(self):
|
||||||
|
# 日志格式
|
||||||
|
self.LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
|
||||||
|
logging.basicConfig(format=self.LOG_FORMAT)
|
||||||
|
self.LOG_VERBOSE = False
|
||||||
|
self.CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent)
|
||||||
|
# 用户数据根目录
|
||||||
|
self.DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data")
|
||||||
|
self._DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data")
|
||||||
|
if not os.path.exists(self._DATA_PATH):
|
||||||
|
os.makedirs(self._DATA_PATH, exist_ok=True)
|
||||||
|
|
||||||
|
self._init_data_dir()
|
||||||
|
|
||||||
|
# 项目相关图片
|
||||||
|
self.IMG_DIR = os.path.join(self.CHATCHAT_ROOT, "img")
|
||||||
|
if not os.path.exists(self.IMG_DIR):
|
||||||
|
os.makedirs(self.IMG_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
def log_verbose(self, verbose: bool):
|
||||||
|
self.LOG_VERBOSE = verbose
|
||||||
|
|
||||||
|
def data_path(self, path: str):
|
||||||
|
self.DATA_PATH = path
|
||||||
|
if not os.path.exists(self.DATA_PATH):
|
||||||
|
os.makedirs(self.DATA_PATH, exist_ok=True)
|
||||||
|
# 复制_DATA_PATH数据到DATA_PATH
|
||||||
|
if self._DATA_PATH != self.DATA_PATH:
|
||||||
|
os.system(f"cp -r {self._DATA_PATH}/* {self.DATA_PATH}")
|
||||||
|
|
||||||
|
self._init_data_dir()
|
||||||
|
|
||||||
|
def log_format(self, log_format: str):
|
||||||
|
self.LOG_FORMAT = log_format
|
||||||
|
logging.basicConfig(format=self.LOG_FORMAT)
|
||||||
|
|
||||||
|
def _init_data_dir(self):
|
||||||
|
logger.info(f"Init data dir: {self.DATA_PATH}")
|
||||||
|
# nltk 模型存储路径
|
||||||
|
self.NLTK_DATA_PATH = os.path.join(self.DATA_PATH, "nltk_data")
|
||||||
|
import nltk
|
||||||
|
nltk.data.path = [self.NLTK_DATA_PATH] + nltk.data.path
|
||||||
|
# 日志存储路径
|
||||||
|
self.LOG_PATH = os.path.join(self.DATA_PATH, "logs")
|
||||||
|
if not os.path.exists(self.LOG_PATH):
|
||||||
|
os.makedirs(self.LOG_PATH, exist_ok=True)
|
||||||
|
|
||||||
|
# 模型生成内容(图片、视频、音频等)保存位置
|
||||||
|
self.MEDIA_PATH = os.path.join(self.DATA_PATH, "media")
|
||||||
|
if not os.path.exists(self.MEDIA_PATH):
|
||||||
|
os.makedirs(self.MEDIA_PATH, exist_ok=True)
|
||||||
|
os.makedirs(os.path.join(self.MEDIA_PATH, "image"), exist_ok=True)
|
||||||
|
os.makedirs(os.path.join(self.MEDIA_PATH, "audio"), exist_ok=True)
|
||||||
|
os.makedirs(os.path.join(self.MEDIA_PATH, "video"), exist_ok=True)
|
||||||
|
|
||||||
|
# 临时文件目录,主要用于文件对话
|
||||||
|
self.BASE_TEMP_DIR = os.path.join(self.DATA_PATH, "temp")
|
||||||
|
if not os.path.exists(self.BASE_TEMP_DIR):
|
||||||
|
os.makedirs(self.BASE_TEMP_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
logger.info(f"Init data dir: {self.DATA_PATH} success.")
|
||||||
|
|
||||||
|
def get_config(self) -> ConfigBasic:
|
||||||
|
config = ConfigBasic()
|
||||||
|
config.log_verbose = self.LOG_VERBOSE
|
||||||
|
config.CHATCHAT_ROOT = self.CHATCHAT_ROOT
|
||||||
|
config.DATA_PATH = self.DATA_PATH
|
||||||
|
config.IMG_DIR = self.IMG_DIR
|
||||||
|
config.NLTK_DATA_PATH = self.NLTK_DATA_PATH
|
||||||
|
config.LOG_FORMAT = self.LOG_FORMAT
|
||||||
|
config.LOG_PATH = self.LOG_PATH
|
||||||
|
config.MEDIA_PATH = self.MEDIA_PATH
|
||||||
|
config.BASE_TEMP_DIR = self.BASE_TEMP_DIR
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigWorkSpace:
|
||||||
|
"""
|
||||||
|
工作空间的配置预设,提供ConfigBasic建造方法产生实例。
|
||||||
|
该类的实例对象用于存储工作空间的配置信息,如工作空间的路径等
|
||||||
|
工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中。
|
||||||
|
注意:不存在则读取默认
|
||||||
|
"""
|
||||||
|
_config_factory: ConfigBasicFactory = ConfigBasicFactory()
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.workspace = os.path.join(os.path.expanduser("~"), ".config", "chatchat/workspace")
|
||||||
|
if not os.path.exists(self.workspace):
|
||||||
|
os.makedirs(self.workspace, exist_ok=True)
|
||||||
|
self.workspace_config = os.path.join(self.workspace, "workspace_config.json")
|
||||||
|
# 初始化工作空间配置,转换成json格式,实现ConfigBasic的实例化
|
||||||
|
|
||||||
|
config_json = self._load_config()
|
||||||
|
|
||||||
|
if config_json:
|
||||||
|
|
||||||
|
_config_factory = ConfigBasicFactory()
|
||||||
|
if config_json.get("log_verbose"):
|
||||||
|
_config_factory.log_verbose(config_json.get("log_verbose"))
|
||||||
|
if config_json.get("DATA_PATH"):
|
||||||
|
_config_factory.data_path(config_json.get("DATA_PATH"))
|
||||||
|
if config_json.get("LOG_FORMAT"):
|
||||||
|
_config_factory.log_format(config_json.get("LOG_FORMAT"))
|
||||||
|
|
||||||
|
self._config_factory = _config_factory
|
||||||
|
|
||||||
|
def get_config(self) -> ConfigBasic:
|
||||||
|
return self._config_factory.get_config()
|
||||||
|
|
||||||
|
def set_log_verbose(self, verbose: bool):
|
||||||
|
self._config_factory.log_verbose(verbose)
|
||||||
|
self._store_config()
|
||||||
|
|
||||||
|
def set_data_path(self, path: str):
|
||||||
|
self._config_factory.data_path(path)
|
||||||
|
self._store_config()
|
||||||
|
|
||||||
|
def set_log_format(self, log_format: str):
|
||||||
|
self._config_factory.log_format(log_format)
|
||||||
|
self._store_config()
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
logger.info("Clear workspace config.")
|
||||||
|
os.remove(self.workspace_config)
|
||||||
|
|
||||||
|
def _load_config(self):
|
||||||
|
try:
|
||||||
|
with open(self.workspace_config, "r") as f:
|
||||||
|
return json.loads(f.read())
|
||||||
|
except FileNotFoundError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _store_config(self):
|
||||||
|
with open(self.workspace_config, "w") as f:
|
||||||
|
config = self._config_factory.get_config()
|
||||||
|
config_json = {
|
||||||
|
"log_verbose": config.log_verbose,
|
||||||
|
"CHATCHAT_ROOT": config.CHATCHAT_ROOT,
|
||||||
|
"DATA_PATH": config.DATA_PATH,
|
||||||
|
"LOG_FORMAT": config.LOG_FORMAT
|
||||||
|
}
|
||||||
|
f.write(json.dumps(config_json, indent=4, ensure_ascii=False))
|
||||||
|
|
||||||
|
|
||||||
|
config_workspace: ConfigWorkSpace = ConfigWorkSpace()
|
||||||
|
|||||||
@ -1,11 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# chatchat 项目根目录
|
import sys
|
||||||
CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent)
|
sys.path.append(str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from _basic_config import config_workspace
|
||||||
|
|
||||||
# 用户数据根目录
|
# 用户数据根目录
|
||||||
DATA_PATH = os.path.join(CHATCHAT_ROOT, "data")
|
DATA_PATH = config_workspace.get_config().DATA_PATH
|
||||||
|
|
||||||
# 默认使用的知识库
|
# 默认使用的知识库
|
||||||
DEFAULT_KNOWLEDGE_BASE = "samples"
|
DEFAULT_KNOWLEDGE_BASE = "samples"
|
||||||
|
|||||||
@ -118,6 +118,25 @@ MODEL_PLATFORMS = [
|
|||||||
"tts_models": [],
|
"tts_models": [],
|
||||||
},
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"platform_name": "xinference",
|
||||||
|
"platform_type": "xinference",
|
||||||
|
"api_base_url": "http://127.0.0.1:9997/v1",
|
||||||
|
"api_key": "EMPTY",
|
||||||
|
"api_concurrencies": 5,
|
||||||
|
"llm_models": [
|
||||||
|
"glm-4",
|
||||||
|
"qwen2-instruct",
|
||||||
|
"qwen1.5-chat",
|
||||||
|
],
|
||||||
|
"embed_models": [
|
||||||
|
"bge-large-zh-v1.5",
|
||||||
|
],
|
||||||
|
"image_models": [],
|
||||||
|
"reranking_models": [],
|
||||||
|
"speech2text_models": [],
|
||||||
|
"tts_models": [],
|
||||||
|
},
|
||||||
|
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -130,7 +149,7 @@ TOOL_CONFIG = {
|
|||||||
"search_local_knowledgebase": {
|
"search_local_knowledgebase": {
|
||||||
"use": False,
|
"use": False,
|
||||||
"top_k": 3,
|
"top_k": 3,
|
||||||
"score_threshold": 1,
|
"score_threshold": 1.0,
|
||||||
"conclude_prompt": {
|
"conclude_prompt": {
|
||||||
"with_result":
|
"with_result":
|
||||||
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题",'
|
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题",'
|
||||||
@ -208,5 +227,34 @@ TOOL_CONFIG = {
|
|||||||
"text2images": {
|
"text2images": {
|
||||||
"use": False,
|
"use": False,
|
||||||
},
|
},
|
||||||
|
# text2sql使用建议
|
||||||
|
# 1、因大模型生成的sql可能与预期有偏差,请务必在测试环境中进行充分测试、评估;
|
||||||
|
# 2、生产环境中,对于查询操作,由于不确定查询效率,推荐数据库采用主从数据库架构,让text2sql连接从数据库,防止可能的慢查询影响主业务;
|
||||||
|
# 3、对于写操作应保持谨慎,如不需要写操作,设置read_only为True,最好再从数据库层面收回数据库用户的写权限,防止用户通过自然语言对数据库进行修改操作;
|
||||||
|
# 4、text2sql与大模型在意图理解、sql转换等方面的能力有关,可切换不同大模型进行测试;
|
||||||
|
# 5、数据库表名、字段名应与其实际作用保持一致、容易理解,且应对数据库表名、字段进行详细的备注说明,帮助大模型更好理解数据库结构;
|
||||||
|
# 6、若现有数据库表名难于让大模型理解,可配置下面table_comments字段,补充说明某些表的作用。
|
||||||
|
"text2sql": {
|
||||||
|
"use": False,
|
||||||
|
# SQLAlchemy连接字符串,支持的数据库有:
|
||||||
|
# crate、duckdb、googlesql、mssql、mysql、mariadb、oracle、postgresql、sqlite、clickhouse、prestodb
|
||||||
|
# 不同的数据库请查询SQLAlchemy,修改sqlalchemy_connect_str,配置对应的数据库连接,如sqlite为sqlite:///数据库文件路径,下面示例为mysql
|
||||||
|
# 如提示缺少对应数据库的驱动,请自行通过poetry安装
|
||||||
|
"sqlalchemy_connect_str": "mysql+pymysql://用户名:密码@主机地址/数据库名称e",
|
||||||
|
# 务必评估是否需要开启read_only,开启后会对sql语句进行检查,请确认text2sql.py中的intercept_sql拦截器是否满足你使用的数据库只读要求
|
||||||
|
# 优先推荐从数据库层面对用户权限进行限制
|
||||||
|
"read_only": False,
|
||||||
|
#限定返回的行数
|
||||||
|
"top_k":50,
|
||||||
|
#是否返回中间步骤
|
||||||
|
"return_intermediate_steps": True,
|
||||||
|
#如果想指定特定表,请填写表名称,如["sys_user","sys_dept"],不填写走智能判断应该使用哪些表
|
||||||
|
"table_names":[],
|
||||||
|
#对表名进行额外说明,辅助大模型更好的判断应该使用哪些表,尤其是SQLDatabaseSequentialChain模式下,是根据表名做的预测,很容易误判。
|
||||||
|
"table_comments":{
|
||||||
|
# 如果出现大模型选错表的情况,可尝试根据实际情况填写表名和说明
|
||||||
|
# "tableA":"这是一个用户表,存储了用户的基本信息",
|
||||||
|
# "tanleB":"角色表",
|
||||||
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@ -20,16 +20,16 @@
|
|||||||
|
|
||||||
xinference:
|
xinference:
|
||||||
model_credential:
|
model_credential:
|
||||||
- model: 'chatglm3-6b'
|
- model: 'glm-4'
|
||||||
model_type: 'llm'
|
model_type: 'llm'
|
||||||
model_credentials:
|
model_credentials:
|
||||||
server_url: 'http://127.0.0.1:9997/'
|
server_url: 'http://127.0.0.1:9997/'
|
||||||
model_uid: 'chatglm3-6b'
|
model_uid: 'glm-4'
|
||||||
- model: 'Qwen1.5-14B-Chat'
|
- model: 'qwen1.5-chat'
|
||||||
model_type: 'llm'
|
model_type: 'llm'
|
||||||
model_credentials:
|
model_credentials:
|
||||||
server_url: 'http://127.0.0.1:9997/'
|
server_url: 'http://127.0.0.1:9997/'
|
||||||
model_uid: 'Qwen1.5-14B-Chat'
|
model_uid: 'qwen1.5-chat'
|
||||||
- model: 'bge-large-zh-v1.5'
|
- model: 'bge-large-zh-v1.5'
|
||||||
model_type: 'embeddings'
|
model_type: 'embeddings'
|
||||||
model_credentials:
|
model_credentials:
|
||||||
|
|||||||
@ -1,13 +1,12 @@
|
|||||||
# Description: 初始化数据库,包括创建表、导入数据、更新向量空间等操作
|
# Description: 初始化数据库,包括创建表、导入数据、更新向量空间等操作
|
||||||
|
from datetime import datetime
|
||||||
|
import multiprocessing as mp
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from chatchat.server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
|
from chatchat.server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
|
||||||
folder2db, prune_db_docs, prune_folder_files)
|
folder2db, prune_db_docs, prune_folder_files)
|
||||||
from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS
|
from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS, logger
|
||||||
import multiprocessing as mp
|
|
||||||
import logging
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
|
|
||||||
def run_init_model_provider(
|
def run_init_model_provider(
|
||||||
@ -34,7 +33,7 @@ def run_init_model_provider(
|
|||||||
provider_port=provider_port)
|
provider_port=provider_port)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def main():
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="please specify only one operate method once time.")
|
parser = argparse.ArgumentParser(description="please specify only one operate method once time.")
|
||||||
@ -186,3 +185,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
for p in processes.values():
|
for p in processes.values():
|
||||||
logger.info("Process status: %s", p)
|
logger.info("Process status: %s", p)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|||||||
@ -10,3 +10,4 @@ from .text2image import text2images
|
|||||||
|
|
||||||
from .vqa_processor import vqa_processor
|
from .vqa_processor import vqa_processor
|
||||||
from .aqa_processor import aqa_processor
|
from .aqa_processor import aqa_processor
|
||||||
|
from .text2sql import text2sql
|
||||||
@ -1,13 +1,14 @@
|
|||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
from chatchat.server.utils import get_tool_config
|
from chatchat.server.utils import get_tool_config
|
||||||
from chatchat.server.pydantic_v1 import Field
|
from chatchat.server.pydantic_v1 import Field
|
||||||
from .tools_registry import regist_tool, BaseToolOutput
|
from chatchat.server.agent.tools_factory.tools_registry import regist_tool, BaseToolOutput
|
||||||
from chatchat.server.knowledge_base.kb_api import list_kbs
|
from chatchat.server.knowledge_base.kb_api import list_kbs
|
||||||
from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId
|
from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId
|
||||||
from chatchat.configs import KB_INFO
|
from chatchat.configs import KB_INFO
|
||||||
|
|
||||||
|
|
||||||
template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool. The 'database' should be one of the above [{key}]."
|
template = ("Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on "
|
||||||
|
"this knowledge use this tool. The 'database' should be one of the above [{key}].")
|
||||||
KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
|
KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
|
||||||
template_knowledge = template.format(KB_info=KB_info_str, key="samples")
|
template_knowledge = template.format(KB_info=KB_info_str, key="samples")
|
||||||
|
|
||||||
@ -49,7 +50,7 @@ def search_local_knowledgebase(
|
|||||||
database: str = Field(description="Database for Knowledge Search", choices=[kb.kb_name for kb in list_kbs().data]),
|
database: str = Field(description="Database for Knowledge Search", choices=[kb.kb_name for kb in list_kbs().data]),
|
||||||
query: str = Field(description="Query for Knowledge Search"),
|
query: str = Field(description="Query for Knowledge Search"),
|
||||||
):
|
):
|
||||||
''''''
|
""""""
|
||||||
tool_config = get_tool_config("search_local_knowledgebase")
|
tool_config = get_tool_config("search_local_knowledgebase")
|
||||||
ret = search_knowledgebase(query=query, database=database, config=tool_config)
|
ret = search_knowledgebase(query=query, database=database, config=tool_config)
|
||||||
return KBToolOutput(ret, database=database)
|
return KBToolOutput(ret, database=database)
|
||||||
|
|||||||
@ -0,0 +1,103 @@
|
|||||||
|
from langchain.utilities import SQLDatabase
|
||||||
|
from langchain_experimental.sql import SQLDatabaseChain,SQLDatabaseSequentialChain
|
||||||
|
from chatchat.server.utils import get_tool_config
|
||||||
|
from chatchat.server.pydantic_v1 import Field
|
||||||
|
from .tools_registry import regist_tool, BaseToolOutput
|
||||||
|
from sqlalchemy.exc import OperationalError
|
||||||
|
from sqlalchemy import event
|
||||||
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
|
from langchain.chains import LLMChain
|
||||||
|
|
||||||
|
READ_ONLY_PROMPT_TEMPLATE="""You are a MySQL expert. The database is currently in read-only mode.
|
||||||
|
Given an input question, determine if the related SQL can be executed in read-only mode.
|
||||||
|
If the SQL can be executed normally, return Answer:'SQL can be executed normally'.
|
||||||
|
If the SQL cannot be executed normally, return Answer: 'SQL cannot be executed normally'.
|
||||||
|
Use the following format:
|
||||||
|
|
||||||
|
Answer: Final answer here
|
||||||
|
|
||||||
|
Question: {query}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# SQLAlchemy "before_cursor_execute" hook that enforces read-only mode.
# Extend the keyword tuple below to match the write-operation keywords of
# the database dialect you are using.
def intercept_sql(conn, cursor, statement, parameters, context, executemany):
    """Reject any SQL statement that begins with a write-operation keyword.

    Raises:
        OperationalError: if the statement would modify the database.
    """
    write_keywords = ("insert", "update", "delete", "create", "drop", "alter", "truncate", "rename")
    # str.startswith accepts a tuple, so a single call covers every keyword.
    normalized = statement.strip().lower()
    if normalized.startswith(write_keywords):
        raise OperationalError("Database is read-only. Write operations are not allowed.", params=None, orig=None)
|
||||||
|
|
||||||
|
def query_database(query: str,
                   config: dict) -> str:
    """Answer a natural-language question by generating and executing SQL.

    Args:
        query: the user's natural-language question.
        config: tool configuration; reads keys "top_k",
            "return_intermediate_steps", "sqlalchemy_connect_str",
            "read_only", "table_names", "table_comments".

    Returns:
        A text block with the query result (and, when intermediate steps are
        available, the executed SQL), or a refusal message when the database
        is read-only and the request needs a write.
    """
    top_k = config["top_k"]
    return_intermediate_steps = config["return_intermediate_steps"]
    sqlalchemy_connect_str = config["sqlalchemy_connect_str"]
    read_only = config["read_only"]
    db = SQLDatabase.from_uri(sqlalchemy_connect_str)

    # Imported here (not at module top) so the current value of the
    # module-level global set by the chat route is picked up on each call.
    # NOTE(review): this copies the value at import time of this statement;
    # confirm chat_routes always sets it before any tool call.
    from chatchat.server.api_server.chat_routes import global_model_name
    from chatchat.server.utils import get_ChatOpenAI
    llm = get_ChatOpenAI(
        model_name=global_model_name,
        temperature=0,
        streaming=True,
        local_wrap=True,
        verbose=True
    )
    table_names=config["table_names"]
    table_comments=config["table_comments"]
    result = None

    # If the model mispredicts which tables to use, provide extra table notes
    # to help it decide — especially in SQLDatabaseSequentialChain mode, where
    # table selection is predicted from table names alone and easily wrong.
    # LangChain's input parameters are fixed, so the only way to pass the
    # extra notes is by appending them to the query text.
    if table_comments:
        TABLE_COMMNET_PROMPT="\n\nI will provide some special notes for a few tables:\n\n"
        table_comments_str="\n".join([f"{k}:{v}" for k,v in table_comments.items()])
        query=query+TABLE_COMMNET_PROMPT+table_comments_str+"\n\n"

    if read_only:
        # In read-only mode, first let the LLM judge whether the request can
        # be satisfied without writes, so we can return a friendly message
        # instead of a later execution error.
        READ_ONLY_PROMPT = PromptTemplate(
            input_variables=["query"],
            template=READ_ONLY_PROMPT_TEMPLATE,
        )
        read_only_chain = LLMChain(
            prompt=READ_ONLY_PROMPT,
            llm=llm,
        )
        read_only_result = read_only_chain.invoke(query)
        if "SQL cannot be executed normally" in read_only_result["text"]:
            return "当前数据库为只读状态,无法满足您的需求!"

        # The LLM's judgement is not guaranteed; also reject writes at the
        # driver level via a cursor-execute interceptor.
        event.listen(db._engine, "before_cursor_execute", intercept_sql)

    # If table_names is given, go straight to SQLDatabaseChain with only those
    # table schemas. Otherwise prefer SQLDatabaseSequentialChain, which first
    # predicts the relevant tables — passing every table schema to the model
    # could blow the token limit and wastes resources.
    if len(table_names) > 0:
        db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True,top_k=top_k,return_intermediate_steps=return_intermediate_steps)
        result = db_chain.invoke({"query":query,"table_names_to_use":table_names})
    else:
        # Predict which tables are needed first, then hand the question and
        # the predicted tables to the model.
        db_chain = SQLDatabaseSequentialChain.from_llm(llm, db, verbose=True,top_k=top_k,return_intermediate_steps=return_intermediate_steps)
        result = db_chain.invoke(query)

    context = f"""查询结果:{result['result']}\n\n"""

    intermediate_steps=result["intermediate_steps"]
    # If intermediate steps exist and there are more than two, keep only the
    # last two: earlier steps contain example data that is easy to misread.
    if intermediate_steps:
        if len(intermediate_steps)>2:
            sql_detail=intermediate_steps[-2:-1][0]["input"]
            # Extract the text between "SQLQuery:" and "Answer:" — the SQL
            # that was actually executed.
            sql_detail=sql_detail[sql_detail.find("SQLQuery:")+9:sql_detail.find("Answer:")]
            context = context+"执行的sql:'"+sql_detail+"'\n\n"
    return context
|
||||||
|
|
||||||
|
|
||||||
|
@regist_tool(title="Text2Sql")
def text2sql(query: str = Field(description="No need for SQL statements,just input the natural language that you want to chat with database")):
    """Use this tool to chat with database,Input natural language, then it will convert it into SQL and execute it in the database, then return the execution result."""
    # Connection string, top_k, read_only, etc. all come from the tool registry.
    return BaseToolOutput(query_database(query=query, config=get_tool_config("text2sql")))
|
||||||
@ -28,6 +28,8 @@ chat_router.post("/file_chat",
|
|||||||
summary="文件对话"
|
summary="文件对话"
|
||||||
)(file_chat)
|
)(file_chat)
|
||||||
|
|
||||||
|
#定义全局model信息,用于给Text2Sql中的get_ChatOpenAI提供model_name
|
||||||
|
global_model_name=None
|
||||||
|
|
||||||
@chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口")
|
@chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口")
|
||||||
async def chat_completions(
|
async def chat_completions(
|
||||||
@ -51,6 +53,8 @@ async def chat_completions(
|
|||||||
for key in list(extra):
|
for key in list(extra):
|
||||||
delattr(body, key)
|
delattr(body, key)
|
||||||
|
|
||||||
|
global global_model_name
|
||||||
|
global_model_name=body.model
|
||||||
# check tools & tool_choice in request body
|
# check tools & tool_choice in request body
|
||||||
if isinstance(body.tool_choice, str):
|
if isinstance(body.tool_choice, str):
|
||||||
if t := get_tool(body.tool_choice):
|
if t := get_tool(body.tool_choice):
|
||||||
|
|||||||
@ -63,10 +63,10 @@ def upload_temp_docs(
|
|||||||
chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
|
chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
|
||||||
zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
|
zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
|
||||||
) -> BaseResponse:
|
) -> BaseResponse:
|
||||||
'''
|
"""
|
||||||
将文件保存到临时目录,并进行向量化。
|
将文件保存到临时目录,并进行向量化。
|
||||||
返回临时目录名称作为ID,同时也是临时向量库的ID。
|
返回临时目录名称作为ID,同时也是临时向量库的ID。
|
||||||
'''
|
"""
|
||||||
if prev_id is not None:
|
if prev_id is not None:
|
||||||
memo_faiss_pool.pop(prev_id)
|
memo_faiss_pool.pop(prev_id)
|
||||||
|
|
||||||
@ -134,7 +134,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
|
|||||||
docs = [x[0] for x in docs]
|
docs = [x[0] for x in docs]
|
||||||
|
|
||||||
context = "\n".join([doc.page_content for doc in docs])
|
context = "\n".join([doc.page_content for doc in docs])
|
||||||
if len(docs) == 0: ## 如果没有找到相关文档,使用Empty模板
|
if len(docs) == 0: # 如果没有找到相关文档,使用Empty模板
|
||||||
prompt_template = get_prompt_template("knowledge_base_chat", "empty")
|
prompt_template = get_prompt_template("knowledge_base_chat", "empty")
|
||||||
else:
|
else:
|
||||||
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
|
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
|
from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
|
||||||
from chatchat.server.document_loaders.ocr import get_ocr
|
from chatchat.server.file_rag.document_loaders.ocr import get_ocr
|
||||||
|
|
||||||
|
|
||||||
class RapidOCRLoader(UnstructuredFileLoader):
|
class RapidOCRLoader(UnstructuredFileLoader):
|
||||||
@ -4,7 +4,7 @@ import cv2
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from chatchat.configs import PDF_OCR_THRESHOLD
|
from chatchat.configs import PDF_OCR_THRESHOLD
|
||||||
from chatchat.server.document_loaders.ocr import get_ocr
|
from chatchat.server.file_rag.document_loaders.ocr import get_ocr
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
|
|
||||||
@ -0,0 +1,3 @@
|
|||||||
|
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
|
||||||
|
from chatchat.server.file_rag.retrievers.vectorstore import VectorstoreRetrieverService
|
||||||
|
from chatchat.server.file_rag.retrievers.ensemble import EnsembleRetrieverService
|
||||||
@ -0,0 +1,24 @@
|
|||||||
|
from langchain.vectorstores import VectorStore
|
||||||
|
from abc import ABCMeta, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRetrieverService(metaclass=ABCMeta):
    """Abstract interface for retriever services in the file-RAG pipeline.

    Concrete subclasses (vectorstore / ensemble retrievers) implement
    construction from a vector store and relevance retrieval.
    """

    def __init__(self, **kwargs):
        # All initialization is delegated to the subclass hook.
        self.do_init(**kwargs)

    @abstractmethod
    def do_init(self, **kwargs):
        """Subclass initialization hook; receives the constructor kwargs."""
        pass

    # Declared @staticmethod to match how every subclass implements it
    # (the method never received self in the original either).
    @staticmethod
    @abstractmethod
    def from_vectorstore(
            vectorstore: VectorStore,
            top_k: int,
            # Fixed: the original annotation ``int or float`` evaluates to
            # just ``int`` at runtime; ``float`` also accepts ints (PEP 484).
            score_threshold: float,
    ):
        """Build a retriever service instance on top of *vectorstore*."""
        pass

    @abstractmethod
    def get_relevant_documents(self, query: str):
        """Return the documents relevant to *query*."""
        pass
||||||
@ -0,0 +1,47 @@
|
|||||||
|
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
|
||||||
|
from langchain.vectorstores import VectorStore
|
||||||
|
from langchain_core.retrievers import BaseRetriever
|
||||||
|
from langchain_community.retrievers import BM25Retriever
|
||||||
|
from langchain.retrievers import EnsembleRetriever
|
||||||
|
|
||||||
|
|
||||||
|
class EnsembleRetrieverService(BaseRetrieverService):
    """Retriever fusing BM25 keyword search with vector similarity search.

    BM25 and the vectorstore retriever are combined with equal weights
    (0.5 / 0.5) via LangChain's EnsembleRetriever.
    """

    def do_init(
            self,
            retriever: BaseRetriever = None,
            top_k: int = 5,
    ):
        # `vs` is unused here; kept for interface parity with siblings.
        self.vs = None
        self.top_k = top_k
        self.retriever = retriever

    @staticmethod
    def from_vectorstore(
            vectorstore: VectorStore,
            top_k: int,
            # Fixed: ``int or float`` evaluates to just ``int`` at runtime.
            score_threshold: float,
    ):
        """Build an ensemble (BM25 + similarity-threshold) retriever.

        Args:
            vectorstore: source vector store; its docstore supplies the
                documents for the BM25 index.
            top_k: number of documents each sub-retriever should return.
            score_threshold: minimum similarity score for the vector leg.
        """
        faiss_retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "score_threshold": score_threshold,
                "k": top_k
            }
        )
        # TODO: replace with a word-segmentation implementation that does not
        # depend on torch.
        from cutword.cutword import Cutter
        cutter = Cutter()
        # NOTE(review): `_dict` is a private attribute of the FAISS in-memory
        # docstore — assumes a FAISS-backed vectorstore; confirm for other
        # vector store types.
        docs = list(vectorstore.docstore._dict.values())
        bm25_retriever = BM25Retriever.from_documents(
            docs,
            preprocess_func=cutter.cutword
        )
        bm25_retriever.k = top_k
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
        )
        return EnsembleRetrieverService(retriever=ensemble_retriever)

    def get_relevant_documents(self, query: str):
        """Return at most ``top_k`` documents relevant to *query*."""
        return self.retriever.get_relevant_documents(query)[:self.top_k]
||||||
@ -0,0 +1,33 @@
|
|||||||
|
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
|
||||||
|
from langchain.vectorstores import VectorStore
|
||||||
|
from langchain_core.retrievers import BaseRetriever
|
||||||
|
|
||||||
|
|
||||||
|
class VectorstoreRetrieverService(BaseRetrieverService):
    """Plain vector-similarity retriever with a score threshold."""

    def do_init(
            self,
            retriever: BaseRetriever = None,
            top_k: int = 5,
    ):
        # `vs` is unused here; kept for interface parity with siblings.
        self.vs = None
        self.top_k = top_k
        self.retriever = retriever

    @staticmethod
    def from_vectorstore(
            vectorstore: VectorStore,
            top_k: int,
            # Fixed: ``int or float`` evaluates to just ``int`` at runtime.
            score_threshold: float,
    ):
        """Wrap *vectorstore* in a similarity-score-threshold retriever.

        Args:
            vectorstore: the backing vector store.
            top_k: maximum number of documents to return.
            score_threshold: minimum similarity score for a hit.
        """
        retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "score_threshold": score_threshold,
                "k": top_k
            }
        )
        return VectorstoreRetrieverService(retriever=retriever)

    def get_relevant_documents(self, query: str):
        """Return at most ``top_k`` documents relevant to *query*."""
        return self.retriever.get_relevant_documents(query)[:self.top_k]
||||||
13
libs/chatchat-server/chatchat/server/file_rag/utils.py
Normal file
13
libs/chatchat-server/chatchat/server/file_rag/utils.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from chatchat.server.file_rag.retrievers import (
|
||||||
|
BaseRetrieverService,
|
||||||
|
VectorstoreRetrieverService,
|
||||||
|
EnsembleRetrieverService,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Registry mapping retriever type name -> retriever service class.
# NOTE(review): "Retrivals" is a typo for "Retrievals", but it is part of the
# module's public surface, so the name is kept for backward compatibility.
Retrivals = {
    "vectorstore": VectorstoreRetrieverService,
    "ensemble": EnsembleRetrieverService,
}


def get_Retriever(type: str = "vectorstore") -> "type[BaseRetrieverService]":
    """Return the retriever service *class* registered under ``type``.

    Fixed: the original annotation claimed an instance was returned, but the
    function returns the class itself (callers construct it afterwards).

    Args:
        type: registry key ("vectorstore" or "ensemble"). Shadows the builtin
            ``type``; kept unchanged because callers may pass it by keyword.

    Raises:
        KeyError: if ``type`` is not a registered retriever kind.
    """
    return Retrivals[type]
||||||
@ -14,6 +14,7 @@ def list_kbs():
|
|||||||
|
|
||||||
def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||||
vector_store_type: str = Body("faiss"),
|
vector_store_type: str = Body("faiss"),
|
||||||
|
kb_info: str = Body("", description="知识库内容简介,用于Agent选择知识库。"),
|
||||||
embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
|
embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
|
||||||
) -> BaseResponse:
|
) -> BaseResponse:
|
||||||
# Create selected knowledge base
|
# Create selected knowledge base
|
||||||
@ -26,7 +27,7 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
|||||||
if kb is not None:
|
if kb is not None:
|
||||||
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
|
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
|
||||||
|
|
||||||
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
|
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model, kb_info=kb_info)
|
||||||
try:
|
try:
|
||||||
kb.create_kb()
|
kb.create_kb()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -95,6 +95,7 @@ class KBFaissPool(_FaissPool):
|
|||||||
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
||||||
) -> ThreadSafeFaiss:
|
) -> ThreadSafeFaiss:
|
||||||
self.atomic.acquire()
|
self.atomic.acquire()
|
||||||
|
locked = True
|
||||||
vector_name = vector_name or embed_model
|
vector_name = vector_name or embed_model
|
||||||
cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
|
cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
|
||||||
try:
|
try:
|
||||||
@ -103,12 +104,14 @@ class KBFaissPool(_FaissPool):
|
|||||||
self.set((kb_name, vector_name), item)
|
self.set((kb_name, vector_name), item)
|
||||||
with item.acquire(msg="初始化"):
|
with item.acquire(msg="初始化"):
|
||||||
self.atomic.release()
|
self.atomic.release()
|
||||||
|
locked = False
|
||||||
logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
|
logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
|
||||||
vs_path = get_vs_path(kb_name, vector_name)
|
vs_path = get_vs_path(kb_name, vector_name)
|
||||||
|
|
||||||
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
|
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
|
||||||
embeddings = get_Embeddings(embed_model=embed_model)
|
embeddings = get_Embeddings(embed_model=embed_model)
|
||||||
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
|
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True,
|
||||||
|
allow_dangerous_deserialization=True)
|
||||||
elif create:
|
elif create:
|
||||||
# create an empty vector store
|
# create an empty vector store
|
||||||
if not os.path.exists(vs_path):
|
if not os.path.exists(vs_path):
|
||||||
@ -121,8 +124,10 @@ class KBFaissPool(_FaissPool):
|
|||||||
item.finish_loading()
|
item.finish_loading()
|
||||||
else:
|
else:
|
||||||
self.atomic.release()
|
self.atomic.release()
|
||||||
|
locked = False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.atomic.release()
|
if locked: # we don't know exception raised before or after atomic.release
|
||||||
|
self.atomic.release()
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
raise RuntimeError(f"向量库 {kb_name} 加载失败。")
|
raise RuntimeError(f"向量库 {kb_name} 加载失败。")
|
||||||
return self.get((kb_name, vector_name))
|
return self.get((kb_name, vector_name))
|
||||||
|
|||||||
@ -1,21 +1,23 @@
|
|||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import urllib
|
import urllib
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
from fastapi import File, Form, Body, Query, UploadFile
|
from fastapi import File, Form, Body, Query, UploadFile
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from sse_starlette import EventSourceResponse
|
||||||
|
|
||||||
from chatchat.configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL,
|
from chatchat.configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL,
|
||||||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||||
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
|
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
|
||||||
logger, log_verbose, )
|
logger, log_verbose, )
|
||||||
from chatchat.server.utils import BaseResponse, ListResponse, run_in_thread_pool
|
from chatchat.server.db.repository.knowledge_file_repository import get_file_detail
|
||||||
from chatchat.server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
|
from chatchat.server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
|
||||||
files2docs_in_thread, KnowledgeFile)
|
files2docs_in_thread, KnowledgeFile)
|
||||||
from fastapi.responses import FileResponse
|
from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory, get_kb_file_details
|
||||||
from sse_starlette import EventSourceResponse
|
|
||||||
import json
|
|
||||||
from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory
|
|
||||||
from chatchat.server.db.repository.knowledge_file_repository import get_file_detail
|
|
||||||
from langchain.docstore.document import Document
|
|
||||||
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||||
from typing import List, Dict
|
from chatchat.server.utils import BaseResponse, ListResponse, run_in_thread_pool, check_embed_model
|
||||||
|
|
||||||
|
|
||||||
def search_docs(
|
def search_docs(
|
||||||
@ -35,7 +37,8 @@ def search_docs(
|
|||||||
if kb is not None:
|
if kb is not None:
|
||||||
if query:
|
if query:
|
||||||
docs = kb.search_docs(query, top_k, score_threshold)
|
docs = kb.search_docs(query, top_k, score_threshold)
|
||||||
data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
|
# data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
|
||||||
|
data = [DocumentWithVSId(**x.dict(), id=x.metadata.get("id")) for x in docs]
|
||||||
elif file_name or metadata:
|
elif file_name or metadata:
|
||||||
data = kb.list_docs(file_name=file_name, metadata=metadata)
|
data = kb.list_docs(file_name=file_name, metadata=metadata)
|
||||||
for d in data:
|
for d in data:
|
||||||
@ -55,8 +58,8 @@ def list_files(
|
|||||||
if kb is None:
|
if kb is None:
|
||||||
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
|
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
|
||||||
else:
|
else:
|
||||||
all_doc_names = kb.list_files()
|
all_docs = get_kb_file_details(knowledge_base_name)
|
||||||
return ListResponse(data=all_doc_names)
|
return ListResponse(data=all_docs)
|
||||||
|
|
||||||
|
|
||||||
def _save_files_in_thread(files: List[UploadFile],
|
def _save_files_in_thread(files: List[UploadFile],
|
||||||
@ -352,38 +355,42 @@ def recreate_vector_store(
|
|||||||
if not kb.exists() and not allow_empty_kb:
|
if not kb.exists() and not allow_empty_kb:
|
||||||
yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
|
yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
|
||||||
else:
|
else:
|
||||||
if kb.exists():
|
error_msg = f"could not recreate vector store because failed to access embed model."
|
||||||
kb.clear_vs()
|
if not kb.check_embed_model(error_msg):
|
||||||
kb.create_kb()
|
yield {"code": 404, "msg": error_msg}
|
||||||
files = list_files_from_folder(knowledge_base_name)
|
else:
|
||||||
kb_files = [(file, knowledge_base_name) for file in files]
|
if kb.exists():
|
||||||
i = 0
|
kb.clear_vs()
|
||||||
for status, result in files2docs_in_thread(kb_files,
|
kb.create_kb()
|
||||||
chunk_size=chunk_size,
|
files = list_files_from_folder(knowledge_base_name)
|
||||||
chunk_overlap=chunk_overlap,
|
kb_files = [(file, knowledge_base_name) for file in files]
|
||||||
zh_title_enhance=zh_title_enhance):
|
i = 0
|
||||||
if status:
|
for status, result in files2docs_in_thread(kb_files,
|
||||||
kb_name, file_name, docs = result
|
chunk_size=chunk_size,
|
||||||
kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
|
chunk_overlap=chunk_overlap,
|
||||||
kb_file.splited_docs = docs
|
zh_title_enhance=zh_title_enhance):
|
||||||
yield json.dumps({
|
if status:
|
||||||
"code": 200,
|
kb_name, file_name, docs = result
|
||||||
"msg": f"({i + 1} / {len(files)}): {file_name}",
|
kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
|
||||||
"total": len(files),
|
kb_file.splited_docs = docs
|
||||||
"finished": i + 1,
|
yield json.dumps({
|
||||||
"doc": file_name,
|
"code": 200,
|
||||||
}, ensure_ascii=False)
|
"msg": f"({i + 1} / {len(files)}): {file_name}",
|
||||||
kb.add_doc(kb_file, not_refresh_vs_cache=True)
|
"total": len(files),
|
||||||
else:
|
"finished": i + 1,
|
||||||
kb_name, file_name, error = result
|
"doc": file_name,
|
||||||
msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
|
}, ensure_ascii=False)
|
||||||
logger.error(msg)
|
kb.add_doc(kb_file, not_refresh_vs_cache=True)
|
||||||
yield json.dumps({
|
else:
|
||||||
"code": 500,
|
kb_name, file_name, error = result
|
||||||
"msg": msg,
|
msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
|
||||||
})
|
logger.error(msg)
|
||||||
i += 1
|
yield json.dumps({
|
||||||
if not not_refresh_vs_cache:
|
"code": 500,
|
||||||
kb.save_vector_store()
|
"msg": msg,
|
||||||
|
})
|
||||||
|
i += 1
|
||||||
|
if not not_refresh_vs_cache:
|
||||||
|
kb.save_vector_store()
|
||||||
|
|
||||||
return EventSourceResponse(output())
|
return EventSourceResponse(output())
|
||||||
|
|||||||
@ -5,6 +5,11 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
|
|
||||||
|
from typing import List, Union, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
from chatchat.configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||||
|
DEFAULT_EMBEDDING_MODEL, KB_INFO, logger)
|
||||||
|
from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema
|
||||||
from chatchat.server.db.repository.knowledge_base_repository import (
|
from chatchat.server.db.repository.knowledge_base_repository import (
|
||||||
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
|
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
|
||||||
load_kb_from_db, get_kb_detail,
|
load_kb_from_db, get_kb_detail,
|
||||||
@ -14,18 +19,12 @@ from chatchat.server.db.repository.knowledge_file_repository import (
|
|||||||
count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
|
count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
|
||||||
list_docs_from_db,
|
list_docs_from_db,
|
||||||
)
|
)
|
||||||
|
|
||||||
from chatchat.configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
|
||||||
DEFAULT_EMBEDDING_MODEL, KB_INFO)
|
|
||||||
from chatchat.server.knowledge_base.utils import (
|
from chatchat.server.knowledge_base.utils import (
|
||||||
get_kb_path, get_doc_path, KnowledgeFile,
|
get_kb_path, get_doc_path, KnowledgeFile,
|
||||||
list_kbs_from_folder, list_files_from_folder,
|
list_kbs_from_folder, list_files_from_folder,
|
||||||
)
|
)
|
||||||
|
|
||||||
from typing import List, Union, Dict, Optional, Tuple
|
|
||||||
|
|
||||||
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||||
from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema
|
from chatchat.server.utils import check_embed_model as _check_embed_model
|
||||||
|
|
||||||
class SupportedVSType:
|
class SupportedVSType:
|
||||||
FAISS = 'faiss'
|
FAISS = 'faiss'
|
||||||
@ -41,10 +40,11 @@ class KBService(ABC):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
knowledge_base_name: str,
|
knowledge_base_name: str,
|
||||||
|
kb_info: str = None,
|
||||||
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
||||||
):
|
):
|
||||||
self.kb_name = knowledge_base_name
|
self.kb_name = knowledge_base_name
|
||||||
self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
|
self.kb_info = kb_info or KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
|
||||||
self.embed_model = embed_model
|
self.embed_model = embed_model
|
||||||
self.kb_path = get_kb_path(self.kb_name)
|
self.kb_path = get_kb_path(self.kb_name)
|
||||||
self.doc_path = get_doc_path(self.kb_name)
|
self.doc_path = get_doc_path(self.kb_name)
|
||||||
@ -59,6 +59,13 @@ class KBService(ABC):
|
|||||||
'''
|
'''
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def check_embed_model(self, error_msg: str) -> bool:
|
||||||
|
if not _check_embed_model(self.embed_model):
|
||||||
|
logger.error(error_msg, exc_info=True)
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
def create_kb(self):
|
def create_kb(self):
|
||||||
"""
|
"""
|
||||||
创建知识库
|
创建知识库
|
||||||
@ -93,6 +100,9 @@ class KBService(ABC):
|
|||||||
向知识库添加文件
|
向知识库添加文件
|
||||||
如果指定了docs,则不再将文本向量化,并将数据库对应条目标为custom_docs=True
|
如果指定了docs,则不再将文本向量化,并将数据库对应条目标为custom_docs=True
|
||||||
"""
|
"""
|
||||||
|
if not self.check_embed_model(f"could not add docs because failed to access embed model."):
|
||||||
|
return False
|
||||||
|
|
||||||
if docs:
|
if docs:
|
||||||
custom_docs = True
|
custom_docs = True
|
||||||
else:
|
else:
|
||||||
@ -143,6 +153,9 @@ class KBService(ABC):
|
|||||||
使用content中的文件更新向量库
|
使用content中的文件更新向量库
|
||||||
如果指定了docs,则使用自定义docs,并将数据库对应条目标为custom_docs=True
|
如果指定了docs,则使用自定义docs,并将数据库对应条目标为custom_docs=True
|
||||||
"""
|
"""
|
||||||
|
if not self.check_embed_model(f"could not update docs because failed to access embed model."):
|
||||||
|
return False
|
||||||
|
|
||||||
if os.path.exists(kb_file.filepath):
|
if os.path.exists(kb_file.filepath):
|
||||||
self.delete_doc(kb_file, **kwargs)
|
self.delete_doc(kb_file, **kwargs)
|
||||||
return self.add_doc(kb_file, docs=docs, **kwargs)
|
return self.add_doc(kb_file, docs=docs, **kwargs)
|
||||||
@ -162,6 +175,8 @@ class KBService(ABC):
|
|||||||
top_k: int = VECTOR_SEARCH_TOP_K,
|
top_k: int = VECTOR_SEARCH_TOP_K,
|
||||||
score_threshold: float = SCORE_THRESHOLD,
|
score_threshold: float = SCORE_THRESHOLD,
|
||||||
) ->List[Document]:
|
) ->List[Document]:
|
||||||
|
if not self.check_embed_model(f"could not search docs because failed to access embed model."):
|
||||||
|
return []
|
||||||
docs = self.do_search(query, top_k, score_threshold)
|
docs = self.do_search(query, top_k, score_threshold)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
@ -176,6 +191,9 @@ class KBService(ABC):
|
|||||||
传入参数为: {doc_id: Document, ...}
|
传入参数为: {doc_id: Document, ...}
|
||||||
如果对应 doc_id 的值为 None,或其 page_content 为空,则删除该文档
|
如果对应 doc_id 的值为 None,或其 page_content 为空,则删除该文档
|
||||||
'''
|
'''
|
||||||
|
if not self.check_embed_model(f"could not update docs because failed to access embed model."):
|
||||||
|
return False
|
||||||
|
|
||||||
self.del_doc_by_ids(list(docs.keys()))
|
self.del_doc_by_ids(list(docs.keys()))
|
||||||
docs = []
|
docs = []
|
||||||
ids = []
|
ids = []
|
||||||
@ -282,31 +300,32 @@ class KBServiceFactory:
|
|||||||
def get_service(kb_name: str,
|
def get_service(kb_name: str,
|
||||||
vector_store_type: Union[str, SupportedVSType],
|
vector_store_type: Union[str, SupportedVSType],
|
||||||
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
||||||
|
kb_info: str = None,
|
||||||
) -> KBService:
|
) -> KBService:
|
||||||
if isinstance(vector_store_type, str):
|
if isinstance(vector_store_type, str):
|
||||||
vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
|
vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
|
||||||
|
params = {"knowledge_base_name": kb_name, "embed_model": embed_model, "kb_info": kb_info}
|
||||||
if SupportedVSType.FAISS == vector_store_type:
|
if SupportedVSType.FAISS == vector_store_type:
|
||||||
from chatchat.server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
|
from chatchat.server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
|
||||||
return FaissKBService(kb_name, embed_model=embed_model)
|
return FaissKBService(**params)
|
||||||
elif SupportedVSType.PG == vector_store_type:
|
elif SupportedVSType.PG == vector_store_type:
|
||||||
from chatchat.server.knowledge_base.kb_service.pg_kb_service import PGKBService
|
from chatchat.server.knowledge_base.kb_service.pg_kb_service import PGKBService
|
||||||
return PGKBService(kb_name, embed_model=embed_model)
|
return PGKBService(**params)
|
||||||
elif SupportedVSType.MILVUS == vector_store_type:
|
elif SupportedVSType.MILVUS == vector_store_type:
|
||||||
from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||||
return MilvusKBService(kb_name, embed_model=embed_model)
|
return MilvusKBService(**params)
|
||||||
elif SupportedVSType.ZILLIZ == vector_store_type:
|
elif SupportedVSType.ZILLIZ == vector_store_type:
|
||||||
from chatchat.server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
|
from chatchat.server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
|
||||||
return ZillizKBService(kb_name, embed_model=embed_model)
|
return ZillizKBService(**params)
|
||||||
elif SupportedVSType.DEFAULT == vector_store_type:
|
elif SupportedVSType.DEFAULT == vector_store_type:
|
||||||
from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||||
return MilvusKBService(kb_name,
|
return MilvusKBService(**params) # other milvus parameters are set in model_config.kbs_config
|
||||||
embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
|
|
||||||
elif SupportedVSType.ES == vector_store_type:
|
elif SupportedVSType.ES == vector_store_type:
|
||||||
from chatchat.server.knowledge_base.kb_service.es_kb_service import ESKBService
|
from chatchat.server.knowledge_base.kb_service.es_kb_service import ESKBService
|
||||||
return ESKBService(kb_name, embed_model=embed_model)
|
return ESKBService(**params)
|
||||||
elif SupportedVSType.CHROMADB == vector_store_type:
|
elif SupportedVSType.CHROMADB == vector_store_type:
|
||||||
from chatchat.server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
|
from chatchat.server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
|
||||||
return ChromaKBService(kb_name, embed_model=embed_model)
|
return ChromaKBService(**params)
|
||||||
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
|
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
|
||||||
from chatchat.server.knowledge_base.kb_service.default_kb_service import DefaultKBService
|
from chatchat.server.knowledge_base.kb_service.default_kb_service import DefaultKBService
|
||||||
return DefaultKBService(kb_name)
|
return DefaultKBService(kb_name)
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from chatchat.configs import SCORE_THRESHOLD
|
|||||||
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||||
from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
|
from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
|
||||||
from chatchat.server.utils import get_Embeddings
|
from chatchat.server.utils import get_Embeddings
|
||||||
|
from chatchat.server.file_rag.utils import get_Retriever
|
||||||
|
|
||||||
|
|
||||||
def _get_result_to_documents(get_result: GetResult) -> List[Document]:
|
def _get_result_to_documents(get_result: GetResult) -> List[Document]:
|
||||||
@ -75,10 +76,13 @@ class ChromaKBService(KBService):
|
|||||||
|
|
||||||
def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[
|
def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[
|
||||||
Tuple[Document, float]]:
|
Tuple[Document, float]]:
|
||||||
embed_func = get_Embeddings(self.embed_model)
|
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||||
embeddings = embed_func.embed_query(query)
|
self.collection,
|
||||||
query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k)
|
top_k=top_k,
|
||||||
return _results_to_docs_and_scores(query_result)
|
score_threshold=score_threshold,
|
||||||
|
)
|
||||||
|
docs = retriever.get_relevant_documents(query)
|
||||||
|
return docs
|
||||||
|
|
||||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||||
doc_infos = []
|
doc_infos = []
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from chatchat.server.knowledge_base.utils import KnowledgeFile
|
|||||||
from chatchat.server.utils import get_Embeddings
|
from chatchat.server.utils import get_Embeddings
|
||||||
from elasticsearch import Elasticsearch, BadRequestError
|
from elasticsearch import Elasticsearch, BadRequestError
|
||||||
from chatchat.configs import kbs_config, KB_ROOT_PATH
|
from chatchat.configs import kbs_config, KB_ROOT_PATH
|
||||||
|
from chatchat.server.file_rag.utils import get_Retriever
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@ -107,8 +108,12 @@ class ESKBService(KBService):
|
|||||||
|
|
||||||
def do_search(self, query:str, top_k: int, score_threshold: float):
|
def do_search(self, query:str, top_k: int, score_threshold: float):
|
||||||
# 文本相似性检索
|
# 文本相似性检索
|
||||||
docs = self.db.similarity_search_with_score(query=query,
|
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||||
k=top_k)
|
self.db,
|
||||||
|
top_k=top_k,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
)
|
||||||
|
docs = retriever.get_relevant_documents(query)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
||||||
|
|||||||
@ -5,9 +5,9 @@ from chatchat.configs import SCORE_THRESHOLD
|
|||||||
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||||
from chatchat.server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
|
from chatchat.server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
|
||||||
from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
|
from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
|
||||||
from chatchat.server.utils import get_Embeddings
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from typing import List, Dict, Optional, Tuple
|
from typing import List, Dict, Tuple
|
||||||
|
from chatchat.server.file_rag.utils import get_Retriever
|
||||||
|
|
||||||
|
|
||||||
class FaissKBService(KBService):
|
class FaissKBService(KBService):
|
||||||
@ -62,10 +62,13 @@ class FaissKBService(KBService):
|
|||||||
top_k: int,
|
top_k: int,
|
||||||
score_threshold: float = SCORE_THRESHOLD,
|
score_threshold: float = SCORE_THRESHOLD,
|
||||||
) -> List[Tuple[Document, float]]:
|
) -> List[Tuple[Document, float]]:
|
||||||
embed_func = get_Embeddings(self.embed_model)
|
|
||||||
embeddings = embed_func.embed_query(query)
|
|
||||||
with self.load_vector_store().acquire() as vs:
|
with self.load_vector_store().acquire() as vs:
|
||||||
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
|
retriever = get_Retriever("ensemble").from_vectorstore(
|
||||||
|
vs,
|
||||||
|
top_k=top_k,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
)
|
||||||
|
docs = retriever.get_relevant_documents(query)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
def do_add_doc(self,
|
def do_add_doc(self,
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from chatchat.server.db.repository import list_file_num_docs_id_by_kb_name_and_f
|
|||||||
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType, \
|
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType, \
|
||||||
score_threshold_process
|
score_threshold_process
|
||||||
from chatchat.server.knowledge_base.utils import KnowledgeFile
|
from chatchat.server.knowledge_base.utils import KnowledgeFile
|
||||||
from chatchat.server.utils import get_Embeddings
|
from chatchat.server.file_rag.utils import get_Retriever
|
||||||
|
|
||||||
|
|
||||||
class MilvusKBService(KBService):
|
class MilvusKBService(KBService):
|
||||||
@ -67,10 +67,16 @@ class MilvusKBService(KBService):
|
|||||||
|
|
||||||
def do_search(self, query: str, top_k: int, score_threshold: float):
|
def do_search(self, query: str, top_k: int, score_threshold: float):
|
||||||
self._load_milvus()
|
self._load_milvus()
|
||||||
embed_func = get_Embeddings(self.embed_model)
|
# embed_func = get_Embeddings(self.embed_model)
|
||||||
embeddings = embed_func.embed_query(query)
|
# embeddings = embed_func.embed_query(query)
|
||||||
docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
|
# docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
|
||||||
return score_threshold_process(score_threshold, top_k, docs)
|
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||||
|
self.milvus,
|
||||||
|
top_k=top_k,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
)
|
||||||
|
docs = retriever.get_relevant_documents(query)
|
||||||
|
return docs
|
||||||
|
|
||||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
|
|||||||
@ -15,6 +15,7 @@ import shutil
|
|||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from sqlalchemy.engine.base import Engine
|
from sqlalchemy.engine.base import Engine
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from chatchat.server.file_rag.utils import get_Retriever
|
||||||
|
|
||||||
|
|
||||||
class PGKBService(KBService):
|
class PGKBService(KBService):
|
||||||
@ -60,10 +61,13 @@ class PGKBService(KBService):
|
|||||||
shutil.rmtree(self.kb_path)
|
shutil.rmtree(self.kb_path)
|
||||||
|
|
||||||
def do_search(self, query: str, top_k: int, score_threshold: float):
|
def do_search(self, query: str, top_k: int, score_threshold: float):
|
||||||
embed_func = get_Embeddings(self.embed_model)
|
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||||
embeddings = embed_func.embed_query(query)
|
self.pg_vector,
|
||||||
docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k)
|
top_k=top_k,
|
||||||
return score_threshold_process(score_threshold, top_k, docs)
|
score_threshold=score_threshold,
|
||||||
|
)
|
||||||
|
docs = retriever.get_relevant_documents(query)
|
||||||
|
return docs
|
||||||
|
|
||||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||||
ids = self.pg_vector.add_documents(docs)
|
ids = self.pg_vector.add_documents(docs)
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
from typing import List, Dict, Optional
|
from typing import List, Dict
|
||||||
from langchain.embeddings.base import Embeddings
|
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
from langchain.vectorstores import Zilliz
|
from langchain.vectorstores import Zilliz
|
||||||
from chatchat.configs import kbs_config
|
from chatchat.configs import kbs_config
|
||||||
@ -7,6 +6,7 @@ from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedV
|
|||||||
score_threshold_process
|
score_threshold_process
|
||||||
from chatchat.server.knowledge_base.utils import KnowledgeFile
|
from chatchat.server.knowledge_base.utils import KnowledgeFile
|
||||||
from chatchat.server.utils import get_Embeddings
|
from chatchat.server.utils import get_Embeddings
|
||||||
|
from chatchat.server.file_rag.utils import get_Retriever
|
||||||
|
|
||||||
|
|
||||||
class ZillizKBService(KBService):
|
class ZillizKBService(KBService):
|
||||||
@ -60,10 +60,13 @@ class ZillizKBService(KBService):
|
|||||||
|
|
||||||
def do_search(self, query: str, top_k: int, score_threshold: float):
|
def do_search(self, query: str, top_k: int, score_threshold: float):
|
||||||
self._load_zilliz()
|
self._load_zilliz()
|
||||||
embed_func = get_Embeddings(self.embed_model)
|
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||||
embeddings = embed_func.embed_query(query)
|
self.zilliz,
|
||||||
docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k)
|
top_k=top_k,
|
||||||
return score_threshold_process(score_threshold, top_k, docs)
|
score_threshold=score_threshold,
|
||||||
|
)
|
||||||
|
docs = retriever.get_relevant_documents(query)
|
||||||
|
return docs
|
||||||
|
|
||||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
|
|||||||
@ -42,55 +42,59 @@ def recreate_summary_vector_store(
|
|||||||
if not kb.exists() and not allow_empty_kb:
|
if not kb.exists() and not allow_empty_kb:
|
||||||
yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
|
yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
|
||||||
else:
|
else:
|
||||||
# 重新创建知识库
|
error_msg = f"could not recreate summary vector store because failed to access embed model."
|
||||||
kb_summary = KBSummaryService(knowledge_base_name, embed_model)
|
if not kb.check_embed_model(error_msg):
|
||||||
kb_summary.drop_kb_summary()
|
yield {"code": 404, "msg": error_msg}
|
||||||
kb_summary.create_kb_summary()
|
else:
|
||||||
|
# 重新创建知识库
|
||||||
|
kb_summary = KBSummaryService(knowledge_base_name, embed_model)
|
||||||
|
kb_summary.drop_kb_summary()
|
||||||
|
kb_summary.create_kb_summary()
|
||||||
|
|
||||||
llm = get_ChatOpenAI(
|
llm = get_ChatOpenAI(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
local_wrap=True,
|
local_wrap=True,
|
||||||
)
|
)
|
||||||
reduce_llm = get_ChatOpenAI(
|
reduce_llm = get_ChatOpenAI(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
local_wrap=True,
|
local_wrap=True,
|
||||||
)
|
)
|
||||||
# 文本摘要适配器
|
# 文本摘要适配器
|
||||||
summary = SummaryAdapter.form_summary(llm=llm,
|
summary = SummaryAdapter.form_summary(llm=llm,
|
||||||
reduce_llm=reduce_llm,
|
reduce_llm=reduce_llm,
|
||||||
overlap_size=OVERLAP_SIZE)
|
overlap_size=OVERLAP_SIZE)
|
||||||
files = list_files_from_folder(knowledge_base_name)
|
files = list_files_from_folder(knowledge_base_name)
|
||||||
|
|
||||||
i = 0
|
i = 0
|
||||||
for i, file_name in enumerate(files):
|
for i, file_name in enumerate(files):
|
||||||
|
|
||||||
doc_infos = kb.list_docs(file_name=file_name)
|
doc_infos = kb.list_docs(file_name=file_name)
|
||||||
docs = summary.summarize(file_description=file_description,
|
docs = summary.summarize(file_description=file_description,
|
||||||
docs=doc_infos)
|
docs=doc_infos)
|
||||||
|
|
||||||
status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
|
status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
|
||||||
if status_kb_summary:
|
if status_kb_summary:
|
||||||
logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成")
|
logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成")
|
||||||
yield json.dumps({
|
yield json.dumps({
|
||||||
"code": 200,
|
"code": 200,
|
||||||
"msg": f"({i + 1} / {len(files)}): {file_name}",
|
"msg": f"({i + 1} / {len(files)}): {file_name}",
|
||||||
"total": len(files),
|
"total": len(files),
|
||||||
"finished": i + 1,
|
"finished": i + 1,
|
||||||
"doc": file_name,
|
"doc": file_name,
|
||||||
}, ensure_ascii=False)
|
}, ensure_ascii=False)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
|
msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
|
||||||
logger.error(msg)
|
logger.error(msg)
|
||||||
yield json.dumps({
|
yield json.dumps({
|
||||||
"code": 500,
|
"code": 500,
|
||||||
"msg": msg,
|
"msg": msg,
|
||||||
})
|
})
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
return EventSourceResponse(output())
|
return EventSourceResponse(output())
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from chatchat.configs import (
|
from chatchat.configs import (
|
||||||
DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
|
DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
|
||||||
CHUNK_SIZE, OVERLAP_SIZE
|
CHUNK_SIZE, OVERLAP_SIZE, logger, log_verbose
|
||||||
)
|
)
|
||||||
from chatchat.server.knowledge_base.utils import (
|
from chatchat.server.knowledge_base.utils import (
|
||||||
get_file_path, list_kbs_from_folder,
|
get_file_path, list_kbs_from_folder,
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from chatchat.configs import (
|
|||||||
TEXT_SPLITTER_NAME,
|
TEXT_SPLITTER_NAME,
|
||||||
)
|
)
|
||||||
import importlib
|
import importlib
|
||||||
from chatchat.server.text_splitter import zh_title_enhance as func_zh_title_enhance
|
from chatchat.server.file_rag.text_splitter import zh_title_enhance as func_zh_title_enhance
|
||||||
import langchain_community.document_loaders
|
import langchain_community.document_loaders
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter
|
from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter
|
||||||
|
|||||||
@ -244,6 +244,16 @@ def get_Embeddings(
|
|||||||
logger.error(f"failed to create Embeddings for model: {embed_model}.", exc_info=True)
|
logger.error(f"failed to create Embeddings for model: {embed_model}.", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
def check_embed_model(embed_model: str=DEFAULT_EMBEDDING_MODEL) -> bool:
|
||||||
|
embeddings = get_Embeddings(embed_model=embed_model)
|
||||||
|
try:
|
||||||
|
embeddings.embed_query("this is a test")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"failed to access embed model '{embed_model}': {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_OpenAIClient(
|
def get_OpenAIClient(
|
||||||
platform_name: str = None,
|
platform_name: str = None,
|
||||||
model_name: str = None,
|
model_name: str = None,
|
||||||
@ -696,7 +706,7 @@ def get_temp_dir(id: str = None) -> Tuple[str, str]:
|
|||||||
'''
|
'''
|
||||||
创建一个临时目录,返回(路径,文件夹名称)
|
创建一个临时目录,返回(路径,文件夹名称)
|
||||||
'''
|
'''
|
||||||
from chatchat.configs.basic_config import BASE_TEMP_DIR
|
from chatchat.configs import BASE_TEMP_DIR
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
if id is not None: # 如果指定的临时目录已存在,直接返回
|
if id is not None: # 如果指定的临时目录已存在,直接返回
|
||||||
|
|||||||
@ -24,28 +24,28 @@ chat_box = ChatBox(
|
|||||||
|
|
||||||
|
|
||||||
def save_session():
|
def save_session():
|
||||||
'''save session state to chat context'''
|
"""save session state to chat context"""
|
||||||
chat_box.context_from_session(exclude=["selected_page", "prompt", "cur_conv_name"])
|
chat_box.context_from_session(exclude=["selected_page", "prompt", "cur_conv_name"])
|
||||||
|
|
||||||
|
|
||||||
def restore_session():
|
def restore_session():
|
||||||
'''restore sesstion state from chat context'''
|
"""restore sesstion state from chat context"""
|
||||||
chat_box.context_to_session(exclude=["selected_page", "prompt", "cur_conv_name"])
|
chat_box.context_to_session(exclude=["selected_page", "prompt", "cur_conv_name"])
|
||||||
|
|
||||||
|
|
||||||
def rerun():
|
def rerun():
|
||||||
'''
|
"""
|
||||||
save chat context before rerun
|
save chat context before rerun
|
||||||
'''
|
"""
|
||||||
save_session()
|
save_session()
|
||||||
st.rerun()
|
st.rerun()
|
||||||
|
|
||||||
|
|
||||||
def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
|
def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
|
||||||
'''
|
"""
|
||||||
返回消息历史。
|
返回消息历史。
|
||||||
content_in_expander控制是否返回expander元素中的内容,一般导出的时候可以选上,传入LLM的history不需要
|
content_in_expander控制是否返回expander元素中的内容,一般导出的时候可以选上,传入LLM的history不需要
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def filter(msg):
|
def filter(msg):
|
||||||
content = [x for x in msg["elements"] if x._output_method in ["markdown", "text"]]
|
content = [x for x in msg["elements"] if x._output_method in ["markdown", "text"]]
|
||||||
@ -66,10 +66,10 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) ->
|
|||||||
|
|
||||||
@st.cache_data
|
@st.cache_data
|
||||||
def upload_temp_docs(files, _api: ApiRequest) -> str:
|
def upload_temp_docs(files, _api: ApiRequest) -> str:
|
||||||
'''
|
"""
|
||||||
将文件上传到临时目录,用于文件对话
|
将文件上传到临时目录,用于文件对话
|
||||||
返回临时向量库ID
|
返回临时向量库ID
|
||||||
'''
|
"""
|
||||||
return _api.upload_temp_docs(files).get("data", {}).get("id")
|
return _api.upload_temp_docs(files).get("data", {}).get("id")
|
||||||
|
|
||||||
|
|
||||||
@ -157,11 +157,13 @@ def dialogue_page(
|
|||||||
tools = list_tools(api)
|
tools = list_tools(api)
|
||||||
tool_names = ["None"] + list(tools)
|
tool_names = ["None"] + list(tools)
|
||||||
if use_agent:
|
if use_agent:
|
||||||
# selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", check_all=True, key="selected_tools")
|
# selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具",
|
||||||
|
# check_all=True, key="selected_tools")
|
||||||
selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"],
|
selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"],
|
||||||
key="selected_tools")
|
key="selected_tools")
|
||||||
else:
|
else:
|
||||||
# selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", key="selected_tool")
|
# selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具",
|
||||||
|
# key="selected_tool")
|
||||||
selected_tool = st.selectbox("选择工具", tool_names,
|
selected_tool = st.selectbox("选择工具", tool_names,
|
||||||
format_func=lambda x: tools.get(x, {"title": "None"})["title"],
|
format_func=lambda x: tools.get(x, {"title": "None"})["title"],
|
||||||
key="selected_tool")
|
key="selected_tool")
|
||||||
@ -338,7 +340,7 @@ def dialogue_page(
|
|||||||
elif d.status == AgentStatus.agent_finish:
|
elif d.status == AgentStatus.agent_finish:
|
||||||
text = d.choices[0].delta.content or ""
|
text = d.choices[0].delta.content or ""
|
||||||
chat_box.update_msg(text.replace("\n", "\n\n"))
|
chat_box.update_msg(text.replace("\n", "\n\n"))
|
||||||
elif d.status == None: # not agent chat
|
elif d.status is None: # not agent chat
|
||||||
if getattr(d, "is_ref", False):
|
if getattr(d, "is_ref", False):
|
||||||
chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete",
|
chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete",
|
||||||
title="参考资料"))
|
title="参考资料"))
|
||||||
|
|||||||
@ -10,16 +10,18 @@ packages = [
|
|||||||
|
|
||||||
[tool.poetry.scripts]
|
[tool.poetry.scripts]
|
||||||
chatchat = 'chatchat.startup:main'
|
chatchat = 'chatchat.startup:main'
|
||||||
|
chatchat-kb = 'chatchat.init_database:main'
|
||||||
|
chatchat-config = 'chatchat.config_work_space:main'
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = ">=3.8.1,<3.12,!=3.9.7"
|
python = ">=3.8.1,<3.12,!=3.9.7"
|
||||||
model-providers = "^0.3.0"
|
model-providers = "^0.3.0"
|
||||||
langchain = "0.1.5"
|
langchain = "0.1.17"
|
||||||
langchainhub = "0.1.14"
|
langchainhub = "0.1.14"
|
||||||
langchain-community = "0.0.17"
|
langchain-community = "0.0.36"
|
||||||
langchain-openai = "0.0.5"
|
langchain-openai = "0.0.5"
|
||||||
langchain-experimental = "0.0.50"
|
langchain-experimental = "0.0.58"
|
||||||
fastapi = "0.109.2"
|
fastapi = "~0.109.2"
|
||||||
sse_starlette = "~1.8.2"
|
sse_starlette = "~1.8.2"
|
||||||
nltk = "~3.8.1"
|
nltk = "~3.8.1"
|
||||||
uvicorn = ">=0.27.0.post1"
|
uvicorn = ">=0.27.0.post1"
|
||||||
@ -27,6 +29,8 @@ unstructured = "~0.11.0"
|
|||||||
python-magic-bin = {version = "*", platform = "win32"}
|
python-magic-bin = {version = "*", platform = "win32"}
|
||||||
SQLAlchemy = "~2.0.25"
|
SQLAlchemy = "~2.0.25"
|
||||||
faiss-cpu = "~1.7.4"
|
faiss-cpu = "~1.7.4"
|
||||||
|
cutword = "0.1.0"
|
||||||
|
rank_bm25 = "0.2.2"
|
||||||
# accelerate = "~0.24.1"
|
# accelerate = "~0.24.1"
|
||||||
# spacy = "~3.7.2"
|
# spacy = "~3.7.2"
|
||||||
PyMuPDF = "~1.23.16"
|
PyMuPDF = "~1.23.16"
|
||||||
@ -34,28 +38,29 @@ rapidocr_onnxruntime = "~1.3.8"
|
|||||||
requests = "~2.31.0"
|
requests = "~2.31.0"
|
||||||
pathlib = "~1.0.1"
|
pathlib = "~1.0.1"
|
||||||
pytest = "~7.4.3"
|
pytest = "~7.4.3"
|
||||||
pyjwt = "2.8.0"
|
pyjwt = "~2.8.0"
|
||||||
elasticsearch = "*"
|
elasticsearch = "*"
|
||||||
numexpr = ">=1" #test
|
numexpr = ">=1" #test
|
||||||
strsimpy = ">=0.2.1"
|
strsimpy = ">=0.2.1"
|
||||||
markdownify = ">=0.11.6"
|
markdownify = ">=0.11.6"
|
||||||
tqdm = ">=4.66.1"
|
tqdm = ">=4.66.1"
|
||||||
websockets = ">=12.0"
|
websockets = ">=12.0"
|
||||||
numpy = "1.24.4"
|
numpy = "~1.24.4"
|
||||||
pandas = "~1" # test
|
pandas = "~1" # test
|
||||||
pydantic = "2.6.4"
|
pydantic = "~2.6.4"
|
||||||
httpx = {version = ">=0.25.2", extras = ["brotli", "http2", "socks"]}
|
httpx = {version = ">=0.25.2", extras = ["brotli", "http2", "socks"]}
|
||||||
python-multipart = "0.0.9"
|
python-multipart = "0.0.9"
|
||||||
# webui
|
# webui
|
||||||
streamlit = "1.34.0"
|
streamlit = "1.34.0"
|
||||||
streamlit-option-menu = "0.3.12"
|
streamlit-option-menu = "0.3.12"
|
||||||
streamlit-antd-components = "0.3.1"
|
streamlit-antd-components = "0.3.1"
|
||||||
streamlit-chatbox = "1.1.12"
|
streamlit-chatbox = "1.1.12.post2"
|
||||||
streamlit-modal = "0.1.0"
|
streamlit-modal = "0.1.0"
|
||||||
streamlit-aggrid = "0.3.4.post3"
|
streamlit-aggrid = "0.3.4.post3"
|
||||||
streamlit-extras = "0.4.2"
|
streamlit-extras = "0.4.2"
|
||||||
xinference_client = { version = "^0.11.1", optional = true }
|
xinference_client = { version = "^0.11.1", optional = true }
|
||||||
zhipuai = { version = "^2.1.0", optional = true }
|
zhipuai = { version = "^2.1.0", optional = true }
|
||||||
|
pymysql = "^1.1.0"
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
xinference = ["xinference_client"]
|
xinference = ["xinference_client"]
|
||||||
@ -254,7 +259,7 @@ ignore-words-list = 'momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy
|
|||||||
|
|
||||||
[tool.poetry.plugins.dotenv]
|
[tool.poetry.plugins.dotenv]
|
||||||
ignore = "false"
|
ignore = "false"
|
||||||
location = ".env"
|
dotenv = "dotenv:plugin"
|
||||||
|
|
||||||
|
|
||||||
# https://python-poetry.org/docs/repositories/
|
# https://python-poetry.org/docs/repositories/
|
||||||
|
|||||||
90
libs/chatchat-server/tests/conftest.py
Normal file
90
libs/chatchat-server/tests/conftest.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
"""Configuration for unit tests."""
|
||||||
|
import logging
|
||||||
|
from importlib import util
|
||||||
|
from typing import Dict, List, Sequence
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pytest import Config, Function, Parser
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_addoption(parser: Parser) -> None:
|
||||||
|
"""Add custom command line options to pytest."""
|
||||||
|
parser.addoption(
|
||||||
|
"--only-extended",
|
||||||
|
action="store_true",
|
||||||
|
help="Only run extended tests. Does not allow skipping any extended tests.",
|
||||||
|
)
|
||||||
|
parser.addoption(
|
||||||
|
"--only-core",
|
||||||
|
action="store_true",
|
||||||
|
help="Only run core tests. Never runs any extended tests.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
|
||||||
|
"""Add implementations for handling custom markers.
|
||||||
|
|
||||||
|
At the moment, this adds support for a custom `requires` marker.
|
||||||
|
|
||||||
|
The `requires` marker is used to denote tests that require one or more packages
|
||||||
|
to be installed to run. If the package is not installed, the test is skipped.
|
||||||
|
|
||||||
|
The `requires` marker syntax is:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
@pytest.mark.requires("package1", "package2")
|
||||||
|
def test_something():
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
# Mapping from the name of a package to whether it is installed or not.
|
||||||
|
# Used to avoid repeated calls to `util.find_spec`
|
||||||
|
required_pkgs_info: Dict[str, bool] = {}
|
||||||
|
|
||||||
|
only_extended = config.getoption("--only-extended") or False
|
||||||
|
only_core = config.getoption("--only-core") or False
|
||||||
|
|
||||||
|
if only_extended and only_core:
|
||||||
|
raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
|
||||||
|
|
||||||
|
for item in items:
|
||||||
|
requires_marker = item.get_closest_marker("requires")
|
||||||
|
if requires_marker is not None:
|
||||||
|
if only_core:
|
||||||
|
item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Iterate through the list of required packages
|
||||||
|
required_pkgs = requires_marker.args
|
||||||
|
for pkg in required_pkgs:
|
||||||
|
# If we haven't yet checked whether the pkg is installed
|
||||||
|
# let's check it and store the result.
|
||||||
|
if pkg not in required_pkgs_info:
|
||||||
|
try:
|
||||||
|
installed = util.find_spec(pkg) is not None
|
||||||
|
except Exception:
|
||||||
|
installed = False
|
||||||
|
required_pkgs_info[pkg] = installed
|
||||||
|
|
||||||
|
if not required_pkgs_info[pkg]:
|
||||||
|
if only_extended:
|
||||||
|
pytest.fail(
|
||||||
|
f"Package `{pkg}` is not installed but is required for "
|
||||||
|
f"extended tests. Please install the given package and "
|
||||||
|
f"try again.",
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# If the package is not installed, we immediately break
|
||||||
|
# and mark the test as skipped.
|
||||||
|
item.add_marker(
|
||||||
|
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
|
||||||
|
)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if only_extended:
|
||||||
|
item.add_marker(
|
||||||
|
pytest.mark.skip(reason="Skipping not an extended test.")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
56
libs/chatchat-server/tests/unit_tests/config/test_config.py
Normal file
56
libs/chatchat-server/tests/unit_tests/config/test_config.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from chatchat.configs import ConfigBasicFactory, ConfigBasic, ConfigWorkSpace
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_factory_def():
|
||||||
|
test_config_factory = ConfigBasicFactory()
|
||||||
|
config: ConfigBasic = test_config_factory.get_config()
|
||||||
|
assert config is not None
|
||||||
|
assert config.log_verbose is False
|
||||||
|
assert config.CHATCHAT_ROOT is not None
|
||||||
|
assert config.DATA_PATH is not None
|
||||||
|
assert config.IMG_DIR is not None
|
||||||
|
assert config.NLTK_DATA_PATH is not None
|
||||||
|
assert config.LOG_FORMAT is not None
|
||||||
|
assert config.LOG_PATH is not None
|
||||||
|
assert config.MEDIA_PATH is not None
|
||||||
|
|
||||||
|
assert os.path.exists(os.path.join(config.MEDIA_PATH, "image"))
|
||||||
|
assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio"))
|
||||||
|
assert os.path.exists(os.path.join(config.MEDIA_PATH, "video"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_workspace():
|
||||||
|
config_workspace = ConfigWorkSpace()
|
||||||
|
assert config_workspace.get_config() is not None
|
||||||
|
base_root = os.path.join(Path(__file__).absolute().parent, "chatchat")
|
||||||
|
config_workspace.set_data_path(os.path.join(base_root, "data"))
|
||||||
|
config_workspace.set_log_verbose(True)
|
||||||
|
config_workspace.set_log_format(" %(message)s")
|
||||||
|
|
||||||
|
config: ConfigBasic = config_workspace.get_config()
|
||||||
|
assert config.log_verbose is True
|
||||||
|
assert config.DATA_PATH == os.path.join(base_root, "data")
|
||||||
|
assert config.IMG_DIR is not None
|
||||||
|
assert config.NLTK_DATA_PATH == os.path.join(base_root, "data", "nltk_data")
|
||||||
|
assert config.LOG_FORMAT == " %(message)s"
|
||||||
|
assert config.LOG_PATH == os.path.join(base_root, "data", "logs")
|
||||||
|
assert config.MEDIA_PATH == os.path.join(base_root, "data", "media")
|
||||||
|
|
||||||
|
assert os.path.exists(os.path.join(config.MEDIA_PATH, "image"))
|
||||||
|
assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio"))
|
||||||
|
assert os.path.exists(os.path.join(config.MEDIA_PATH, "video"))
|
||||||
|
config_workspace.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def test_workspace_default():
|
||||||
|
from chatchat.configs import (log_verbose, DATA_PATH, IMG_DIR, NLTK_DATA_PATH, LOG_FORMAT, LOG_PATH, MEDIA_PATH)
|
||||||
|
assert log_verbose is False
|
||||||
|
assert DATA_PATH is not None
|
||||||
|
assert IMG_DIR is not None
|
||||||
|
assert NLTK_DATA_PATH is not None
|
||||||
|
assert LOG_FORMAT is not None
|
||||||
|
assert LOG_PATH is not None
|
||||||
|
assert MEDIA_PATH is not None
|
||||||
1
libs/model-providers/.env
Normal file
1
libs/model-providers/.env
Normal file
@ -0,0 +1 @@
|
|||||||
|
PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring
|
||||||
@ -1,26 +1,26 @@
|
|||||||
import redis
|
# import redis
|
||||||
from redis.connection import Connection, SSLConnection
|
# from redis.connection import Connection, SSLConnection
|
||||||
|
#
|
||||||
redis_client = redis.Redis()
|
# redis_client = redis.Redis()
|
||||||
|
#
|
||||||
|
#
|
||||||
def init_app(app):
|
# def init_app(app):
|
||||||
connection_class = Connection
|
# connection_class = Connection
|
||||||
if app.config.get("REDIS_USE_SSL", False):
|
# if app.config.get("REDIS_USE_SSL", False):
|
||||||
connection_class = SSLConnection
|
# connection_class = SSLConnection
|
||||||
|
#
|
||||||
redis_client.connection_pool = redis.ConnectionPool(
|
# redis_client.connection_pool = redis.ConnectionPool(
|
||||||
**{
|
# **{
|
||||||
"host": app.config.get("REDIS_HOST", "localhost"),
|
# "host": app.config.get("REDIS_HOST", "localhost"),
|
||||||
"port": app.config.get("REDIS_PORT", 6379),
|
# "port": app.config.get("REDIS_PORT", 6379),
|
||||||
"username": app.config.get("REDIS_USERNAME", None),
|
# "username": app.config.get("REDIS_USERNAME", None),
|
||||||
"password": app.config.get("REDIS_PASSWORD", None),
|
# "password": app.config.get("REDIS_PASSWORD", None),
|
||||||
"db": app.config.get("REDIS_DB", 0),
|
# "db": app.config.get("REDIS_DB", 0),
|
||||||
"encoding": "utf-8",
|
# "encoding": "utf-8",
|
||||||
"encoding_errors": "strict",
|
# "encoding_errors": "strict",
|
||||||
"decode_responses": False,
|
# "decode_responses": False,
|
||||||
},
|
# },
|
||||||
connection_class=connection_class,
|
# connection_class=connection_class,
|
||||||
)
|
# )
|
||||||
|
#
|
||||||
app.extensions["redis"] = redis_client
|
# app.extensions["redis"] = redis_client
|
||||||
|
|||||||
@ -7,20 +7,19 @@ readme = "README.md"
|
|||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = ">=3.8.1,<3.12,!=3.9.7"
|
python = ">=3.8.1,<3.12,!=3.9.7"
|
||||||
transformers = "4.31.0"
|
transformers = "~4.31.0"
|
||||||
fastapi = "^0.109.2"
|
fastapi = "~0.109.2"
|
||||||
uvicorn = ">=0.27.0.post1"
|
uvicorn = ">=0.27.0.post1"
|
||||||
sse-starlette = "^1.8.2"
|
sse-starlette = "~1.8.2"
|
||||||
pyyaml = "6.0.1"
|
pyyaml = "~6.0.1"
|
||||||
pydantic = "2.6.4"
|
pydantic ="~2.6.4"
|
||||||
redis = "4.5.4"
|
|
||||||
# config manage
|
# config manage
|
||||||
omegaconf = "2.0.6"
|
omegaconf = "~2.0.6"
|
||||||
# modle_runtime
|
# modle_runtime
|
||||||
openai = "1.13.3"
|
openai = "~1.13.3"
|
||||||
tiktoken = "0.5.2"
|
tiktoken = "~0.5.2"
|
||||||
pydub = "0.25.1"
|
pydub = "~0.25.1"
|
||||||
boto3 = "1.28.17"
|
boto3 = "~1.28.17"
|
||||||
|
|
||||||
[tool.poetry.group.test.dependencies]
|
[tool.poetry.group.test.dependencies]
|
||||||
# The only dependencies that should be added are
|
# The only dependencies that should be added are
|
||||||
@ -207,7 +206,7 @@ ignore-words-list = 'momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy
|
|||||||
|
|
||||||
[tool.poetry.plugins.dotenv]
|
[tool.poetry.plugins.dotenv]
|
||||||
ignore = "false"
|
ignore = "false"
|
||||||
location = ".env"
|
dotenv = "dotenv:plugin"
|
||||||
|
|
||||||
# https://python-poetry.org/docs/repositories/
|
# https://python-poetry.org/docs/repositories/
|
||||||
[[tool.poetry.source]]
|
[[tool.poetry.source]]
|
||||||
|
|||||||
@ -25,7 +25,19 @@ langchain-chatchat = { path = "libs/chatchat-server", develop = true }
|
|||||||
ipykernel = "^6.29.2"
|
ipykernel = "^6.29.2"
|
||||||
|
|
||||||
[tool.poetry.group.test.dependencies]
|
[tool.poetry.group.test.dependencies]
|
||||||
|
pytest = "^7.3.0"
|
||||||
|
pytest-cov = "^4.0.0"
|
||||||
|
pytest-dotenv = "^0.5.2"
|
||||||
|
duckdb-engine = "^0.9.2"
|
||||||
|
pytest-watcher = "^0.2.6"
|
||||||
|
freezegun = "^1.2.2"
|
||||||
|
responses = "^0.22.0"
|
||||||
|
pytest-asyncio = "^0.23.2"
|
||||||
|
lark = "^1.1.5"
|
||||||
|
pytest-mock = "^3.10.0"
|
||||||
|
pytest-socket = "^0.6.0"
|
||||||
|
syrupy = "^4.0.2"
|
||||||
|
requests-mock = "^1.11.0"
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
extend-include = ["*.ipynb"]
|
extend-include = ["*.ipynb"]
|
||||||
@ -48,7 +60,7 @@ extend-exclude = [
|
|||||||
|
|
||||||
[tool.poetry.plugins.dotenv]
|
[tool.poetry.plugins.dotenv]
|
||||||
ignore = "false"
|
ignore = "false"
|
||||||
location = ".env"
|
dotenv = "dotenv:plugin"
|
||||||
|
|
||||||
|
|
||||||
# https://python-poetry.org/docs/repositories/
|
# https://python-poetry.org/docs/repositories/
|
||||||
|
|||||||
@ -133,6 +133,8 @@ model_names = list(regs[model_type].keys())
|
|||||||
model_name = cols[1].selectbox("模型名称:", model_names)
|
model_name = cols[1].selectbox("模型名称:", model_names)
|
||||||
|
|
||||||
cur_reg = regs[model_type][model_name]["reg"]
|
cur_reg = regs[model_type][model_name]["reg"]
|
||||||
|
model_format = None
|
||||||
|
model_quant = None
|
||||||
|
|
||||||
if model_type == "LLM":
|
if model_type == "LLM":
|
||||||
cur_family = xf_llm.LLMFamilyV1.parse_obj(cur_reg)
|
cur_family = xf_llm.LLMFamilyV1.parse_obj(cur_reg)
|
||||||
@ -217,7 +219,7 @@ cols = st.columns(3)
|
|||||||
|
|
||||||
if cols[0].button("设置模型缓存"):
|
if cols[0].button("设置模型缓存"):
|
||||||
if os.path.isabs(local_path) and os.path.isdir(local_path):
|
if os.path.isabs(local_path) and os.path.isdir(local_path):
|
||||||
cur_spec.model_uri = local_path
|
cur_spec.__dict__["model_uri"] = local_path # embedding spec has no attribute model_uri
|
||||||
if os.path.isdir(cache_dir):
|
if os.path.isdir(cache_dir):
|
||||||
os.rmdir(cache_dir)
|
os.rmdir(cache_dir)
|
||||||
if model_type == "LLM":
|
if model_type == "LLM":
|
||||||
@ -250,10 +252,10 @@ if cols[2].button("注册为自定义模型"):
|
|||||||
if model_type == "LLM":
|
if model_type == "LLM":
|
||||||
cur_family.model_name = f"{cur_family.model_name}{model_name_suffix}"
|
cur_family.model_name = f"{cur_family.model_name}{model_name_suffix}"
|
||||||
cur_family.model_family = "other"
|
cur_family.model_family = "other"
|
||||||
model_definition = cur_family.json(indent=2, ensure_ascii=False)
|
model_definition = cur_family.model_dump_json(indent=2, ensure_ascii=False)
|
||||||
else:
|
else:
|
||||||
cur_spec.model_name = f"{cur_spec.model_name}{model_name_suffix}"
|
cur_spec.model_name = f"{cur_spec.model_name}{model_name_suffix}"
|
||||||
model_definition = cur_spec.json(indent=2, ensure_ascii=False)
|
model_definition = cur_spec.model_dump_json(indent=2, ensure_ascii=False)
|
||||||
client.register_model(
|
client.register_model(
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
model=model_definition,
|
model=model_definition,
|
||||||
@ -262,4 +264,3 @@ if cols[2].button("注册为自定义模型"):
|
|||||||
st.rerun()
|
st.rerun()
|
||||||
else:
|
else:
|
||||||
st.error("必须输入存在的绝对路径")
|
st.error("必须输入存在的绝对路径")
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user