mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 13:23:16 +08:00
Merge branch 'chatchat-space:dev' into dev
This commit is contained in:
commit
d189bd182f
@ -15,9 +15,11 @@
|
||||
|
||||
Install Poetry: [documentation on how to install it.](https://python-poetry.org/docs/#installing-with-pipx)
|
||||
|
||||
> 友情提示 不想安装pipx可以用pip安装poetry,(Tips:如果你没有其它poetry的项目
|
||||
> 注意: 如果您使用 Conda 或 Pyenv 作为您的环境/包管理器,在安装Poetry之后,
|
||||
> 使用如下命令使 Poetry 使用 virtualenv python environment (`poetry config virtualenvs.prefer-active-python true`)
|
||||
|
||||
|
||||
#### 本地开发环境安装
|
||||
|
||||
- 选择主项目目录
|
||||
|
||||
1
libs/chatchat-server/.env
Normal file
1
libs/chatchat-server/.env
Normal file
@ -0,0 +1 @@
|
||||
PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring
|
||||
@ -29,6 +29,11 @@ vim model_providers.yaml
|
||||
>
|
||||
> 详细配置请参考[README.md](../model-providers/README.md)
|
||||
|
||||
- 初始化知识库
|
||||
```shell
|
||||
chatchat-kb -r
|
||||
```
|
||||
|
||||
- 启动服务
|
||||
```shell
|
||||
chatchat -a
|
||||
|
||||
47
libs/chatchat-server/chatchat/config_work_space.py
Normal file
47
libs/chatchat-server/chatchat/config_work_space.py
Normal file
@ -0,0 +1,47 @@
|
||||
from chatchat.configs import config_workspace as workspace
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="指令` chatchat-config` 工作空间配置")
|
||||
# 只能选择true或false
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--verbose",
|
||||
choices=["true", "false"],
|
||||
help="是否开启详细日志"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--data",
|
||||
help="数据存放路径"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--format",
|
||||
help="日志格式"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clear",
|
||||
action="store_true",
|
||||
help="清除配置"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.verbose:
|
||||
if args.verbose.lower() == "true":
|
||||
workspace.set_log_verbose(True)
|
||||
else:
|
||||
workspace.set_log_verbose(False)
|
||||
if args.data:
|
||||
workspace.set_data_path(args.data)
|
||||
if args.format:
|
||||
workspace.set_log_format(args.format)
|
||||
if args.clear:
|
||||
workspace.clear()
|
||||
print(workspace.get_config())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,8 +1,10 @@
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger()
|
||||
@ -38,7 +40,6 @@ def _import_config_mod_load(import_config_mod: str) -> Dict:
|
||||
)
|
||||
user_import = False
|
||||
if user_import:
|
||||
|
||||
# Dynamic loading {config}.py file
|
||||
py_path = os.path.join(user_config_path, import_config_mod + ".py")
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
@ -69,7 +70,7 @@ def _import_config_mod_load(import_config_mod: str) -> Dict:
|
||||
)
|
||||
raise RuntimeError(f"Failed to load user config from {user_config_path}")
|
||||
# 当前文件路径
|
||||
py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py")
|
||||
py_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), import_config_mod + ".py")
|
||||
|
||||
spec = importlib.util.spec_from_file_location(f"*",
|
||||
py_path)
|
||||
@ -95,75 +96,108 @@ CONFIG_IMPORTS = {
|
||||
}
|
||||
|
||||
|
||||
def _import_ConfigBasic() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
ConfigBasic = load_mod(basic_config_load.get("module"), "ConfigBasic")
|
||||
|
||||
return ConfigBasic
|
||||
|
||||
|
||||
def _import_ConfigBasicFactory() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
ConfigBasicFactory = load_mod(basic_config_load.get("module"), "ConfigBasicFactory")
|
||||
|
||||
return ConfigBasicFactory
|
||||
|
||||
|
||||
def _import_ConfigWorkSpace() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
ConfigWorkSpace = load_mod(basic_config_load.get("module"), "ConfigWorkSpace")
|
||||
|
||||
return ConfigWorkSpace
|
||||
|
||||
|
||||
def _import_config_workspace() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||
return config_workspace
|
||||
|
||||
def _import_log_verbose() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
log_verbose = load_mod(basic_config_load.get("module"), "log_verbose")
|
||||
|
||||
return log_verbose
|
||||
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||
return config_workspace.get_config().log_verbose
|
||||
|
||||
|
||||
def _import_chatchat_root() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
CHATCHAT_ROOT = load_mod(basic_config_load.get("module"), "CHATCHAT_ROOT")
|
||||
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||
|
||||
return CHATCHAT_ROOT
|
||||
return config_workspace.get_config().CHATCHAT_ROOT
|
||||
|
||||
|
||||
def _import_data_path() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
DATA_PATH = load_mod(basic_config_load.get("module"), "DATA_PATH")
|
||||
|
||||
return DATA_PATH
|
||||
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||
return config_workspace.get_config().DATA_PATH
|
||||
|
||||
|
||||
def _import_img_dir() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
IMG_DIR = load_mod(basic_config_load.get("module"), "IMG_DIR")
|
||||
|
||||
return IMG_DIR
|
||||
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||
|
||||
return config_workspace.get_config().IMG_DIR
|
||||
|
||||
|
||||
def _import_nltk_data_path() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
NLTK_DATA_PATH = load_mod(basic_config_load.get("module"), "NLTK_DATA_PATH")
|
||||
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||
|
||||
return NLTK_DATA_PATH
|
||||
return config_workspace.get_config().NLTK_DATA_PATH
|
||||
|
||||
|
||||
def _import_log_format() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
LOG_FORMAT = load_mod(basic_config_load.get("module"), "LOG_FORMAT")
|
||||
|
||||
return LOG_FORMAT
|
||||
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||
|
||||
return config_workspace.get_config().LOG_FORMAT
|
||||
|
||||
|
||||
def _import_log_path() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
LOG_PATH = load_mod(basic_config_load.get("module"), "LOG_PATH")
|
||||
|
||||
return LOG_PATH
|
||||
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||
|
||||
return config_workspace.get_config().LOG_PATH
|
||||
|
||||
|
||||
def _import_media_path() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
MEDIA_PATH = load_mod(basic_config_load.get("module"), "MEDIA_PATH")
|
||||
|
||||
return MEDIA_PATH
|
||||
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||
|
||||
return config_workspace.get_config().MEDIA_PATH
|
||||
|
||||
|
||||
def _import_base_temp_dir() -> Any:
|
||||
basic_config_load = CONFIG_IMPORTS.get("_basic_config.py")
|
||||
load_mod = basic_config_load.get("load_mod")
|
||||
BASE_TEMP_DIR = load_mod(basic_config_load.get("module"), "BASE_TEMP_DIR")
|
||||
|
||||
return BASE_TEMP_DIR
|
||||
config_workspace = load_mod(basic_config_load.get("module"), "config_workspace")
|
||||
return config_workspace.get_config().BASE_TEMP_DIR
|
||||
|
||||
|
||||
def _import_default_knowledge_base() -> Any:
|
||||
@ -285,6 +319,7 @@ def _import_db_root_path() -> Any:
|
||||
|
||||
return DB_ROOT_PATH
|
||||
|
||||
|
||||
def _import_sqlalchemy_database_uri() -> Any:
|
||||
kb_config_load = CONFIG_IMPORTS.get("_kb_config.py")
|
||||
load_mod = kb_config_load.get("load_mod")
|
||||
@ -478,7 +513,15 @@ def _import_api_server() -> Any:
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name == "log_verbose":
|
||||
if name == "ConfigBasic":
|
||||
return _import_ConfigBasic()
|
||||
elif name == "ConfigBasicFactory":
|
||||
return _import_ConfigBasicFactory()
|
||||
elif name == "ConfigWorkSpace":
|
||||
return _import_ConfigWorkSpace()
|
||||
elif name == "config_workspace":
|
||||
return _import_config_workspace()
|
||||
elif name == "log_verbose":
|
||||
return _import_log_verbose()
|
||||
elif name == "CHATCHAT_ROOT":
|
||||
return _import_chatchat_root()
|
||||
@ -578,6 +621,7 @@ VERSION = "v0.3.0-preview"
|
||||
|
||||
__all__ = [
|
||||
"VERSION",
|
||||
"config_workspace",
|
||||
"log_verbose",
|
||||
"CHATCHAT_ROOT",
|
||||
"DATA_PATH",
|
||||
@ -626,5 +670,8 @@ __all__ = [
|
||||
"WEBUI_SERVER",
|
||||
"API_SERVER",
|
||||
|
||||
"ConfigBasic",
|
||||
"ConfigBasicFactory",
|
||||
"ConfigWorkSpace",
|
||||
|
||||
]
|
||||
|
||||
@ -1,55 +1,180 @@
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
import langchain
|
||||
|
||||
|
||||
# 是否显示详细日志
|
||||
log_verbose = False
|
||||
langchain.verbose = False
|
||||
|
||||
# 通常情况下不需要更改以下内容
|
||||
|
||||
# chatchat 项目根目录
|
||||
CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent)
|
||||
|
||||
# 用户数据根目录
|
||||
DATA_PATH = os.path.join(CHATCHAT_ROOT, "data")
|
||||
if not os.path.exists(DATA_PATH):
|
||||
os.mkdir(DATA_PATH)
|
||||
|
||||
# 项目相关图片
|
||||
IMG_DIR = os.path.join(CHATCHAT_ROOT, "img")
|
||||
if not os.path.exists(IMG_DIR):
|
||||
os.mkdir(IMG_DIR)
|
||||
|
||||
# nltk 模型存储路径
|
||||
NLTK_DATA_PATH = os.path.join(DATA_PATH, "nltk_data")
|
||||
import nltk
|
||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||
|
||||
# 日志格式
|
||||
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.INFO)
|
||||
logging.basicConfig(format=LOG_FORMAT)
|
||||
|
||||
|
||||
# 日志存储路径
|
||||
LOG_PATH = os.path.join(DATA_PATH, "logs")
|
||||
if not os.path.exists(LOG_PATH):
|
||||
os.mkdir(LOG_PATH)
|
||||
class ConfigBasic:
|
||||
log_verbose: bool
|
||||
"""是否开启日志详细信息"""
|
||||
CHATCHAT_ROOT: str
|
||||
"""项目根目录"""
|
||||
DATA_PATH: str
|
||||
"""用户数据根目录"""
|
||||
IMG_DIR: str
|
||||
"""项目相关图片"""
|
||||
NLTK_DATA_PATH: str
|
||||
"""nltk 模型存储路径"""
|
||||
LOG_FORMAT: str
|
||||
"""日志格式"""
|
||||
LOG_PATH: str
|
||||
"""日志存储路径"""
|
||||
MEDIA_PATH: str
|
||||
"""模型生成内容(图片、视频、音频等)保存位置"""
|
||||
BASE_TEMP_DIR: str
|
||||
"""临时文件目录,主要用于文件对话"""
|
||||
|
||||
# 模型生成内容(图片、视频、音频等)保存位置
|
||||
MEDIA_PATH = os.path.join(DATA_PATH, "media")
|
||||
if not os.path.exists(MEDIA_PATH):
|
||||
os.mkdir(MEDIA_PATH)
|
||||
os.mkdir(os.path.join(MEDIA_PATH, "image"))
|
||||
os.mkdir(os.path.join(MEDIA_PATH, "audio"))
|
||||
os.mkdir(os.path.join(MEDIA_PATH, "video"))
|
||||
def __str__(self):
|
||||
return f"ConfigBasic(log_verbose={self.log_verbose}, CHATCHAT_ROOT={self.CHATCHAT_ROOT}, DATA_PATH={self.DATA_PATH}, IMG_DIR={self.IMG_DIR}, NLTK_DATA_PATH={self.NLTK_DATA_PATH}, LOG_FORMAT={self.LOG_FORMAT}, LOG_PATH={self.LOG_PATH}, MEDIA_PATH={self.MEDIA_PATH}, BASE_TEMP_DIR={self.BASE_TEMP_DIR})"
|
||||
|
||||
# 临时文件目录,主要用于文件对话
|
||||
BASE_TEMP_DIR = os.path.join(DATA_PATH, "temp")
|
||||
if not os.path.exists(BASE_TEMP_DIR):
|
||||
os.mkdir(BASE_TEMP_DIR)
|
||||
|
||||
class ConfigBasicFactory:
|
||||
"""Basic config for ChatChat """
|
||||
|
||||
def __init__(self):
|
||||
# 日志格式
|
||||
self.LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
|
||||
logging.basicConfig(format=self.LOG_FORMAT)
|
||||
self.LOG_VERBOSE = False
|
||||
self.CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent)
|
||||
# 用户数据根目录
|
||||
self.DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data")
|
||||
self._DATA_PATH = os.path.join(self.CHATCHAT_ROOT, "data")
|
||||
if not os.path.exists(self._DATA_PATH):
|
||||
os.makedirs(self._DATA_PATH, exist_ok=True)
|
||||
|
||||
self._init_data_dir()
|
||||
|
||||
# 项目相关图片
|
||||
self.IMG_DIR = os.path.join(self.CHATCHAT_ROOT, "img")
|
||||
if not os.path.exists(self.IMG_DIR):
|
||||
os.makedirs(self.IMG_DIR, exist_ok=True)
|
||||
|
||||
def log_verbose(self, verbose: bool):
|
||||
self.LOG_VERBOSE = verbose
|
||||
|
||||
def data_path(self, path: str):
|
||||
self.DATA_PATH = path
|
||||
if not os.path.exists(self.DATA_PATH):
|
||||
os.makedirs(self.DATA_PATH, exist_ok=True)
|
||||
# 复制_DATA_PATH数据到DATA_PATH
|
||||
if self._DATA_PATH != self.DATA_PATH:
|
||||
os.system(f"cp -r {self._DATA_PATH}/* {self.DATA_PATH}")
|
||||
|
||||
self._init_data_dir()
|
||||
|
||||
def log_format(self, log_format: str):
|
||||
self.LOG_FORMAT = log_format
|
||||
logging.basicConfig(format=self.LOG_FORMAT)
|
||||
|
||||
def _init_data_dir(self):
|
||||
logger.info(f"Init data dir: {self.DATA_PATH}")
|
||||
# nltk 模型存储路径
|
||||
self.NLTK_DATA_PATH = os.path.join(self.DATA_PATH, "nltk_data")
|
||||
import nltk
|
||||
nltk.data.path = [self.NLTK_DATA_PATH] + nltk.data.path
|
||||
# 日志存储路径
|
||||
self.LOG_PATH = os.path.join(self.DATA_PATH, "logs")
|
||||
if not os.path.exists(self.LOG_PATH):
|
||||
os.makedirs(self.LOG_PATH, exist_ok=True)
|
||||
|
||||
# 模型生成内容(图片、视频、音频等)保存位置
|
||||
self.MEDIA_PATH = os.path.join(self.DATA_PATH, "media")
|
||||
if not os.path.exists(self.MEDIA_PATH):
|
||||
os.makedirs(self.MEDIA_PATH, exist_ok=True)
|
||||
os.makedirs(os.path.join(self.MEDIA_PATH, "image"), exist_ok=True)
|
||||
os.makedirs(os.path.join(self.MEDIA_PATH, "audio"), exist_ok=True)
|
||||
os.makedirs(os.path.join(self.MEDIA_PATH, "video"), exist_ok=True)
|
||||
|
||||
# 临时文件目录,主要用于文件对话
|
||||
self.BASE_TEMP_DIR = os.path.join(self.DATA_PATH, "temp")
|
||||
if not os.path.exists(self.BASE_TEMP_DIR):
|
||||
os.makedirs(self.BASE_TEMP_DIR, exist_ok=True)
|
||||
|
||||
logger.info(f"Init data dir: {self.DATA_PATH} success.")
|
||||
|
||||
def get_config(self) -> ConfigBasic:
|
||||
config = ConfigBasic()
|
||||
config.log_verbose = self.LOG_VERBOSE
|
||||
config.CHATCHAT_ROOT = self.CHATCHAT_ROOT
|
||||
config.DATA_PATH = self.DATA_PATH
|
||||
config.IMG_DIR = self.IMG_DIR
|
||||
config.NLTK_DATA_PATH = self.NLTK_DATA_PATH
|
||||
config.LOG_FORMAT = self.LOG_FORMAT
|
||||
config.LOG_PATH = self.LOG_PATH
|
||||
config.MEDIA_PATH = self.MEDIA_PATH
|
||||
config.BASE_TEMP_DIR = self.BASE_TEMP_DIR
|
||||
return config
|
||||
|
||||
|
||||
class ConfigWorkSpace:
|
||||
"""
|
||||
工作空间的配置预设,提供ConfigBasic建造方法产生实例。
|
||||
该类的实例对象用于存储工作空间的配置信息,如工作空间的路径等
|
||||
工作空间的配置信息存储在用户的家目录下的.config/chatchat/workspace/workspace_config.json文件中。
|
||||
注意:不存在则读取默认
|
||||
"""
|
||||
_config_factory: ConfigBasicFactory = ConfigBasicFactory()
|
||||
|
||||
def __init__(self):
|
||||
self.workspace = os.path.join(os.path.expanduser("~"), ".config", "chatchat/workspace")
|
||||
if not os.path.exists(self.workspace):
|
||||
os.makedirs(self.workspace, exist_ok=True)
|
||||
self.workspace_config = os.path.join(self.workspace, "workspace_config.json")
|
||||
# 初始化工作空间配置,转换成json格式,实现ConfigBasic的实例化
|
||||
|
||||
config_json = self._load_config()
|
||||
|
||||
if config_json:
|
||||
|
||||
_config_factory = ConfigBasicFactory()
|
||||
if config_json.get("log_verbose"):
|
||||
_config_factory.log_verbose(config_json.get("log_verbose"))
|
||||
if config_json.get("DATA_PATH"):
|
||||
_config_factory.data_path(config_json.get("DATA_PATH"))
|
||||
if config_json.get("LOG_FORMAT"):
|
||||
_config_factory.log_format(config_json.get("LOG_FORMAT"))
|
||||
|
||||
self._config_factory = _config_factory
|
||||
|
||||
def get_config(self) -> ConfigBasic:
|
||||
return self._config_factory.get_config()
|
||||
|
||||
def set_log_verbose(self, verbose: bool):
|
||||
self._config_factory.log_verbose(verbose)
|
||||
self._store_config()
|
||||
|
||||
def set_data_path(self, path: str):
|
||||
self._config_factory.data_path(path)
|
||||
self._store_config()
|
||||
|
||||
def set_log_format(self, log_format: str):
|
||||
self._config_factory.log_format(log_format)
|
||||
self._store_config()
|
||||
|
||||
def clear(self):
|
||||
logger.info("Clear workspace config.")
|
||||
os.remove(self.workspace_config)
|
||||
|
||||
def _load_config(self):
|
||||
try:
|
||||
with open(self.workspace_config, "r") as f:
|
||||
return json.loads(f.read())
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
def _store_config(self):
|
||||
with open(self.workspace_config, "w") as f:
|
||||
config = self._config_factory.get_config()
|
||||
config_json = {
|
||||
"log_verbose": config.log_verbose,
|
||||
"CHATCHAT_ROOT": config.CHATCHAT_ROOT,
|
||||
"DATA_PATH": config.DATA_PATH,
|
||||
"LOG_FORMAT": config.LOG_FORMAT
|
||||
}
|
||||
f.write(json.dumps(config_json, indent=4, ensure_ascii=False))
|
||||
|
||||
|
||||
config_workspace: ConfigWorkSpace = ConfigWorkSpace()
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# chatchat 项目根目录
|
||||
CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent)
|
||||
import sys
|
||||
sys.path.append(str(Path(__file__).parent))
|
||||
|
||||
from _basic_config import config_workspace
|
||||
|
||||
# 用户数据根目录
|
||||
DATA_PATH = os.path.join(CHATCHAT_ROOT, "data")
|
||||
DATA_PATH = config_workspace.get_config().DATA_PATH
|
||||
|
||||
# 默认使用的知识库
|
||||
DEFAULT_KNOWLEDGE_BASE = "samples"
|
||||
|
||||
@ -118,6 +118,25 @@ MODEL_PLATFORMS = [
|
||||
"tts_models": [],
|
||||
},
|
||||
|
||||
{
|
||||
"platform_name": "xinference",
|
||||
"platform_type": "xinference",
|
||||
"api_base_url": "http://127.0.0.1:9997/v1",
|
||||
"api_key": "EMPTY",
|
||||
"api_concurrencies": 5,
|
||||
"llm_models": [
|
||||
"glm-4",
|
||||
"qwen2-instruct",
|
||||
"qwen1.5-chat",
|
||||
],
|
||||
"embed_models": [
|
||||
"bge-large-zh-v1.5",
|
||||
],
|
||||
"image_models": [],
|
||||
"reranking_models": [],
|
||||
"speech2text_models": [],
|
||||
"tts_models": [],
|
||||
},
|
||||
|
||||
]
|
||||
|
||||
@ -130,7 +149,7 @@ TOOL_CONFIG = {
|
||||
"search_local_knowledgebase": {
|
||||
"use": False,
|
||||
"top_k": 3,
|
||||
"score_threshold": 1,
|
||||
"score_threshold": 1.0,
|
||||
"conclude_prompt": {
|
||||
"with_result":
|
||||
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题",'
|
||||
@ -208,5 +227,34 @@ TOOL_CONFIG = {
|
||||
"text2images": {
|
||||
"use": False,
|
||||
},
|
||||
|
||||
# text2sql使用建议
|
||||
# 1、因大模型生成的sql可能与预期有偏差,请务必在测试环境中进行充分测试、评估;
|
||||
# 2、生产环境中,对于查询操作,由于不确定查询效率,推荐数据库采用主从数据库架构,让text2sql连接从数据库,防止可能的慢查询影响主业务;
|
||||
# 3、对于写操作应保持谨慎,如不需要写操作,设置read_only为True,最好再从数据库层面收回数据库用户的写权限,防止用户通过自然语言对数据库进行修改操作;
|
||||
# 4、text2sql与大模型在意图理解、sql转换等方面的能力有关,可切换不同大模型进行测试;
|
||||
# 5、数据库表名、字段名应与其实际作用保持一致、容易理解,且应对数据库表名、字段进行详细的备注说明,帮助大模型更好理解数据库结构;
|
||||
# 6、若现有数据库表名难于让大模型理解,可配置下面table_comments字段,补充说明某些表的作用。
|
||||
"text2sql": {
|
||||
"use": False,
|
||||
# SQLAlchemy连接字符串,支持的数据库有:
|
||||
# crate、duckdb、googlesql、mssql、mysql、mariadb、oracle、postgresql、sqlite、clickhouse、prestodb
|
||||
# 不同的数据库请查询SQLAlchemy,修改sqlalchemy_connect_str,配置对应的数据库连接,如sqlite为sqlite:///数据库文件路径,下面示例为mysql
|
||||
# 如提示缺少对应数据库的驱动,请自行通过poetry安装
|
||||
"sqlalchemy_connect_str": "mysql+pymysql://用户名:密码@主机地址/数据库名称e",
|
||||
# 务必评估是否需要开启read_only,开启后会对sql语句进行检查,请确认text2sql.py中的intercept_sql拦截器是否满足你使用的数据库只读要求
|
||||
# 优先推荐从数据库层面对用户权限进行限制
|
||||
"read_only": False,
|
||||
#限定返回的行数
|
||||
"top_k":50,
|
||||
#是否返回中间步骤
|
||||
"return_intermediate_steps": True,
|
||||
#如果想指定特定表,请填写表名称,如["sys_user","sys_dept"],不填写走智能判断应该使用哪些表
|
||||
"table_names":[],
|
||||
#对表名进行额外说明,辅助大模型更好的判断应该使用哪些表,尤其是SQLDatabaseSequentialChain模式下,是根据表名做的预测,很容易误判。
|
||||
"table_comments":{
|
||||
# 如果出现大模型选错表的情况,可尝试根据实际情况填写表名和说明
|
||||
# "tableA":"这是一个用户表,存储了用户的基本信息",
|
||||
# "tanleB":"角色表",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@ -20,16 +20,16 @@
|
||||
|
||||
xinference:
|
||||
model_credential:
|
||||
- model: 'chatglm3-6b'
|
||||
- model: 'glm-4'
|
||||
model_type: 'llm'
|
||||
model_credentials:
|
||||
server_url: 'http://127.0.0.1:9997/'
|
||||
model_uid: 'chatglm3-6b'
|
||||
- model: 'Qwen1.5-14B-Chat'
|
||||
model_uid: 'glm-4'
|
||||
- model: 'qwen1.5-chat'
|
||||
model_type: 'llm'
|
||||
model_credentials:
|
||||
server_url: 'http://127.0.0.1:9997/'
|
||||
model_uid: 'Qwen1.5-14B-Chat'
|
||||
model_uid: 'qwen1.5-chat'
|
||||
- model: 'bge-large-zh-v1.5'
|
||||
model_type: 'embeddings'
|
||||
model_credentials:
|
||||
@ -46,4 +46,4 @@ xinference:
|
||||
# model_type: 'llm'
|
||||
# model_credentials:
|
||||
# base_url: 'http://172.21.192.1:11434'
|
||||
# mode: 'completion'
|
||||
# mode: 'completion'
|
||||
|
||||
@ -1,13 +1,12 @@
|
||||
# Description: 初始化数据库,包括创建表、导入数据、更新向量空间等操作
|
||||
from datetime import datetime
|
||||
import multiprocessing as mp
|
||||
from typing import Dict
|
||||
|
||||
from chatchat.server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
|
||||
folder2db, prune_db_docs, prune_folder_files)
|
||||
from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS
|
||||
import multiprocessing as mp
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS, logger
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def run_init_model_provider(
|
||||
@ -34,7 +33,7 @@ def run_init_model_provider(
|
||||
provider_port=provider_port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="please specify only one operate method once time.")
|
||||
@ -186,3 +185,7 @@ if __name__ == "__main__":
|
||||
|
||||
for p in processes.values():
|
||||
logger.info("Process status: %s", p)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -10,3 +10,4 @@ from .text2image import text2images
|
||||
|
||||
from .vqa_processor import vqa_processor
|
||||
from .aqa_processor import aqa_processor
|
||||
from .text2sql import text2sql
|
||||
@ -1,13 +1,14 @@
|
||||
from urllib.parse import urlencode
|
||||
from chatchat.server.utils import get_tool_config
|
||||
from chatchat.server.pydantic_v1 import Field
|
||||
from .tools_registry import regist_tool, BaseToolOutput
|
||||
from chatchat.server.agent.tools_factory.tools_registry import regist_tool, BaseToolOutput
|
||||
from chatchat.server.knowledge_base.kb_api import list_kbs
|
||||
from chatchat.server.knowledge_base.kb_doc_api import search_docs, DocumentWithVSId
|
||||
from chatchat.configs import KB_INFO
|
||||
|
||||
|
||||
template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool. The 'database' should be one of the above [{key}]."
|
||||
template = ("Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on "
|
||||
"this knowledge use this tool. The 'database' should be one of the above [{key}].")
|
||||
KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()])
|
||||
template_knowledge = template.format(KB_info=KB_info_str, key="samples")
|
||||
|
||||
@ -49,7 +50,7 @@ def search_local_knowledgebase(
|
||||
database: str = Field(description="Database for Knowledge Search", choices=[kb.kb_name for kb in list_kbs().data]),
|
||||
query: str = Field(description="Query for Knowledge Search"),
|
||||
):
|
||||
''''''
|
||||
""""""
|
||||
tool_config = get_tool_config("search_local_knowledgebase")
|
||||
ret = search_knowledgebase(query=query, database=database, config=tool_config)
|
||||
return KBToolOutput(ret, database=database)
|
||||
|
||||
@ -0,0 +1,103 @@
|
||||
from langchain.utilities import SQLDatabase
|
||||
from langchain_experimental.sql import SQLDatabaseChain,SQLDatabaseSequentialChain
|
||||
from chatchat.server.utils import get_tool_config
|
||||
from chatchat.server.pydantic_v1 import Field
|
||||
from .tools_registry import regist_tool, BaseToolOutput
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy import event
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain.chains import LLMChain
|
||||
|
||||
READ_ONLY_PROMPT_TEMPLATE="""You are a MySQL expert. The database is currently in read-only mode.
|
||||
Given an input question, determine if the related SQL can be executed in read-only mode.
|
||||
If the SQL can be executed normally, return Answer:'SQL can be executed normally'.
|
||||
If the SQL cannot be executed normally, return Answer: 'SQL cannot be executed normally'.
|
||||
Use the following format:
|
||||
|
||||
Answer: Final answer here
|
||||
|
||||
Question: {query}
|
||||
"""
|
||||
|
||||
# 定义一个拦截器函数来检查SQL语句,以支持read-only,可修改下面的write_operations,以匹配你使用的数据库写操作关键字
|
||||
def intercept_sql(conn, cursor, statement, parameters, context, executemany):
|
||||
# List of SQL keywords that indicate a write operation
|
||||
write_operations = ("insert", "update", "delete", "create", "drop", "alter", "truncate", "rename")
|
||||
# Check if the statement starts with any of the write operation keywords
|
||||
if any(statement.strip().lower().startswith(op) for op in write_operations):
|
||||
raise OperationalError("Database is read-only. Write operations are not allowed.", params=None, orig=None)
|
||||
|
||||
def query_database(query: str,
|
||||
config: dict):
|
||||
top_k = config["top_k"]
|
||||
return_intermediate_steps = config["return_intermediate_steps"]
|
||||
sqlalchemy_connect_str = config["sqlalchemy_connect_str"]
|
||||
read_only = config["read_only"]
|
||||
db = SQLDatabase.from_uri(sqlalchemy_connect_str)
|
||||
|
||||
from chatchat.server.api_server.chat_routes import global_model_name
|
||||
from chatchat.server.utils import get_ChatOpenAI
|
||||
llm = get_ChatOpenAI(
|
||||
model_name=global_model_name,
|
||||
temperature=0,
|
||||
streaming=True,
|
||||
local_wrap=True,
|
||||
verbose=True
|
||||
)
|
||||
table_names=config["table_names"]
|
||||
table_comments=config["table_comments"]
|
||||
result = None
|
||||
|
||||
#如果发现大模型判断用什么表出现问题,尝试给langchain提供额外的表说明,辅助大模型更好的判断应该使用哪些表,尤其是SQLDatabaseSequentialChain模式下,是根据表名做的预测,很容易误判
|
||||
#由于langchain固定了输入参数,所以只能通过query传递额外的表说明
|
||||
if table_comments:
|
||||
TABLE_COMMNET_PROMPT="\n\nI will provide some special notes for a few tables:\n\n"
|
||||
table_comments_str="\n".join([f"{k}:{v}" for k,v in table_comments.items()])
|
||||
query=query+TABLE_COMMNET_PROMPT+table_comments_str+"\n\n"
|
||||
|
||||
if read_only:
|
||||
# 在read_only下,先让大模型判断只读模式是否能满足需求,避免后续执行过程报错,返回友好提示。
|
||||
READ_ONLY_PROMPT = PromptTemplate(
|
||||
input_variables=["query"],
|
||||
template=READ_ONLY_PROMPT_TEMPLATE,
|
||||
)
|
||||
read_only_chain = LLMChain(
|
||||
prompt=READ_ONLY_PROMPT,
|
||||
llm=llm,
|
||||
)
|
||||
read_only_result = read_only_chain.invoke(query)
|
||||
if "SQL cannot be executed normally" in read_only_result["text"]:
|
||||
return "当前数据库为只读状态,无法满足您的需求!"
|
||||
|
||||
# 当然大模型不能保证完全判断准确,为防止大模型判断有误,再从拦截器层面拒绝写操作
|
||||
event.listen(db._engine, "before_cursor_execute", intercept_sql)
|
||||
|
||||
#如果不指定table_names,优先走SQLDatabaseSequentialChain,这个链会先预测需要哪些表,然后再将相关表输入SQLDatabaseChain
|
||||
#这是因为如果不指定table_names,直接走SQLDatabaseChain,Langchain会将全量表结构传递给大模型,可能会因token太长从而引发错误,也浪费资源
|
||||
#如果指定了table_names,直接走SQLDatabaseChain,将特定表结构传递给大模型进行判断
|
||||
if len(table_names) > 0:
|
||||
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True,top_k=top_k,return_intermediate_steps=return_intermediate_steps)
|
||||
result = db_chain.invoke({"query":query,"table_names_to_use":table_names})
|
||||
else:
|
||||
#先预测会使用哪些表,然后再将问题和预测的表给大模型
|
||||
db_chain = SQLDatabaseSequentialChain.from_llm(llm, db, verbose=True,top_k=top_k,return_intermediate_steps=return_intermediate_steps)
|
||||
result = db_chain.invoke(query)
|
||||
|
||||
context = f"""查询结果:{result['result']}\n\n"""
|
||||
|
||||
intermediate_steps=result["intermediate_steps"]
|
||||
#如果存在intermediate_steps,且这个数组的长度大于2,则保留最后两个元素,因为前面几个步骤存在示例数据,容易引起误解
|
||||
if intermediate_steps:
|
||||
if len(intermediate_steps)>2:
|
||||
sql_detail=intermediate_steps[-2:-1][0]["input"]
|
||||
# sql_detail截取从SQLQuery到Answer:之间的内容
|
||||
sql_detail=sql_detail[sql_detail.find("SQLQuery:")+9:sql_detail.find("Answer:")]
|
||||
context = context+"执行的sql:'"+sql_detail+"'\n\n"
|
||||
return context
|
||||
|
||||
|
||||
@regist_tool(title="Text2Sql")
|
||||
def text2sql(query: str = Field(description="No need for SQL statements,just input the natural language that you want to chat with database")):
|
||||
'''Use this tool to chat with database,Input natural language, then it will convert it into SQL and execute it in the database, then return the execution result.'''
|
||||
tool_config = get_tool_config("text2sql")
|
||||
return BaseToolOutput(query_database(query=query, config=tool_config))
|
||||
@ -28,6 +28,8 @@ chat_router.post("/file_chat",
|
||||
summary="文件对话"
|
||||
)(file_chat)
|
||||
|
||||
#定义全局model信息,用于给Text2Sql中的get_ChatOpenAI提供model_name
|
||||
global_model_name=None
|
||||
|
||||
@chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口")
|
||||
async def chat_completions(
|
||||
@ -51,6 +53,8 @@ async def chat_completions(
|
||||
for key in list(extra):
|
||||
delattr(body, key)
|
||||
|
||||
global global_model_name
|
||||
global_model_name=body.model
|
||||
# check tools & tool_choice in request body
|
||||
if isinstance(body.tool_choice, str):
|
||||
if t := get_tool(body.tool_choice):
|
||||
|
||||
@ -63,10 +63,10 @@ def upload_temp_docs(
|
||||
chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
|
||||
zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
"""
|
||||
将文件保存到临时目录,并进行向量化。
|
||||
返回临时目录名称作为ID,同时也是临时向量库的ID。
|
||||
'''
|
||||
"""
|
||||
if prev_id is not None:
|
||||
memo_faiss_pool.pop(prev_id)
|
||||
|
||||
@ -134,7 +134,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
|
||||
docs = [x[0] for x in docs]
|
||||
|
||||
context = "\n".join([doc.page_content for doc in docs])
|
||||
if len(docs) == 0: ## 如果没有找到相关文档,使用Empty模板
|
||||
if len(docs) == 0: # 如果没有找到相关文档,使用Empty模板
|
||||
prompt_template = get_prompt_template("knowledge_base_chat", "empty")
|
||||
else:
|
||||
prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import List
|
||||
from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
|
||||
from chatchat.server.document_loaders.ocr import get_ocr
|
||||
from chatchat.server.file_rag.document_loaders.ocr import get_ocr
|
||||
|
||||
|
||||
class RapidOCRLoader(UnstructuredFileLoader):
|
||||
@ -4,7 +4,7 @@ import cv2
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from chatchat.configs import PDF_OCR_THRESHOLD
|
||||
from chatchat.server.document_loaders.ocr import get_ocr
|
||||
from chatchat.server.file_rag.document_loaders.ocr import get_ocr
|
||||
import tqdm
|
||||
|
||||
|
||||
@ -0,0 +1,3 @@
|
||||
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
|
||||
from chatchat.server.file_rag.retrievers.vectorstore import VectorstoreRetrieverService
|
||||
from chatchat.server.file_rag.retrievers.ensemble import EnsembleRetrieverService
|
||||
@ -0,0 +1,24 @@
|
||||
from langchain.vectorstores import VectorStore
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
|
||||
class BaseRetrieverService(metaclass=ABCMeta):
|
||||
def __init__(self, **kwargs):
|
||||
self.do_init(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def do_init(self, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def from_vectorstore(
|
||||
vectorstore: VectorStore,
|
||||
top_k: int,
|
||||
score_threshold: int or float,
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_relevant_documents(self, query: str):
|
||||
pass
|
||||
@ -0,0 +1,47 @@
|
||||
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
|
||||
from langchain.vectorstores import VectorStore
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_community.retrievers import BM25Retriever
|
||||
from langchain.retrievers import EnsembleRetriever
|
||||
|
||||
|
||||
class EnsembleRetrieverService(BaseRetrieverService):
|
||||
def do_init(
|
||||
self,
|
||||
retriever: BaseRetriever = None,
|
||||
top_k: int = 5,
|
||||
):
|
||||
self.vs = None
|
||||
self.top_k = top_k
|
||||
self.retriever = retriever
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_vectorstore(
|
||||
vectorstore: VectorStore,
|
||||
top_k: int,
|
||||
score_threshold: int or float,
|
||||
):
|
||||
faiss_retriever = vectorstore.as_retriever(
|
||||
search_type="similarity_score_threshold",
|
||||
search_kwargs={
|
||||
"score_threshold": score_threshold,
|
||||
"k": top_k
|
||||
}
|
||||
)
|
||||
# TODO: 换个不用torch的实现方式
|
||||
from cutword.cutword import Cutter
|
||||
cutter = Cutter()
|
||||
docs = list(vectorstore.docstore._dict.values())
|
||||
bm25_retriever = BM25Retriever.from_documents(
|
||||
docs,
|
||||
preprocess_func=cutter.cutword
|
||||
)
|
||||
bm25_retriever.k = top_k
|
||||
ensemble_retriever = EnsembleRetriever(
|
||||
retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
|
||||
)
|
||||
return EnsembleRetrieverService(retriever=ensemble_retriever)
|
||||
|
||||
def get_relevant_documents(self, query: str):
|
||||
return self.retriever.get_relevant_documents(query)[:self.top_k]
|
||||
@ -0,0 +1,33 @@
|
||||
from chatchat.server.file_rag.retrievers.base import BaseRetrieverService
|
||||
from langchain.vectorstores import VectorStore
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
class VectorstoreRetrieverService(BaseRetrieverService):
|
||||
def do_init(
|
||||
self,
|
||||
retriever: BaseRetriever = None,
|
||||
top_k: int = 5,
|
||||
):
|
||||
self.vs = None
|
||||
self.top_k = top_k
|
||||
self.retriever = retriever
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_vectorstore(
|
||||
vectorstore: VectorStore,
|
||||
top_k: int,
|
||||
score_threshold: int or float,
|
||||
):
|
||||
retriever = vectorstore.as_retriever(
|
||||
search_type="similarity_score_threshold",
|
||||
search_kwargs={
|
||||
"score_threshold": score_threshold,
|
||||
"k": top_k
|
||||
}
|
||||
)
|
||||
return VectorstoreRetrieverService(retriever=retriever)
|
||||
|
||||
def get_relevant_documents(self, query: str):
|
||||
return self.retriever.get_relevant_documents(query)[:self.top_k]
|
||||
13
libs/chatchat-server/chatchat/server/file_rag/utils.py
Normal file
13
libs/chatchat-server/chatchat/server/file_rag/utils.py
Normal file
@ -0,0 +1,13 @@
|
||||
from chatchat.server.file_rag.retrievers import (
|
||||
BaseRetrieverService,
|
||||
VectorstoreRetrieverService,
|
||||
EnsembleRetrieverService,
|
||||
)
|
||||
|
||||
Retrivals = {
|
||||
"vectorstore": VectorstoreRetrieverService,
|
||||
"ensemble": EnsembleRetrieverService,
|
||||
}
|
||||
|
||||
def get_Retriever(type: str = "vectorstore") -> BaseRetrieverService:
|
||||
return Retrivals[type]
|
||||
@ -14,6 +14,7 @@ def list_kbs():
|
||||
|
||||
def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
vector_store_type: str = Body("faiss"),
|
||||
kb_info: str = Body("", description="知识库内容简介,用于Agent选择知识库。"),
|
||||
embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
|
||||
) -> BaseResponse:
|
||||
# Create selected knowledge base
|
||||
@ -26,7 +27,7 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
if kb is not None:
|
||||
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
|
||||
|
||||
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
|
||||
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model, kb_info=kb_info)
|
||||
try:
|
||||
kb.create_kb()
|
||||
except Exception as e:
|
||||
|
||||
@ -62,7 +62,7 @@ class CachePool:
|
||||
self._cache_num = cache_num
|
||||
self._cache = OrderedDict()
|
||||
self.atomic = threading.RLock()
|
||||
|
||||
|
||||
def keys(self) -> List[str]:
|
||||
return list(self._cache.keys())
|
||||
|
||||
|
||||
@ -95,6 +95,7 @@ class KBFaissPool(_FaissPool):
|
||||
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
||||
) -> ThreadSafeFaiss:
|
||||
self.atomic.acquire()
|
||||
locked = True
|
||||
vector_name = vector_name or embed_model
|
||||
cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
|
||||
try:
|
||||
@ -103,12 +104,14 @@ class KBFaissPool(_FaissPool):
|
||||
self.set((kb_name, vector_name), item)
|
||||
with item.acquire(msg="初始化"):
|
||||
self.atomic.release()
|
||||
locked = False
|
||||
logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
|
||||
vs_path = get_vs_path(kb_name, vector_name)
|
||||
|
||||
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
|
||||
embeddings = get_Embeddings(embed_model=embed_model)
|
||||
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
|
||||
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True,
|
||||
allow_dangerous_deserialization=True)
|
||||
elif create:
|
||||
# create an empty vector store
|
||||
if not os.path.exists(vs_path):
|
||||
@ -121,8 +124,10 @@ class KBFaissPool(_FaissPool):
|
||||
item.finish_loading()
|
||||
else:
|
||||
self.atomic.release()
|
||||
locked = False
|
||||
except Exception as e:
|
||||
self.atomic.release()
|
||||
if locked: # we don't know exception raised before or after atomic.release
|
||||
self.atomic.release()
|
||||
logger.error(e)
|
||||
raise RuntimeError(f"向量库 {kb_name} 加载失败。")
|
||||
return self.get((kb_name, vector_name))
|
||||
|
||||
@ -1,21 +1,23 @@
|
||||
import json
|
||||
import os
|
||||
import urllib
|
||||
from typing import List, Dict
|
||||
|
||||
from fastapi import File, Form, Body, Query, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from langchain.docstore.document import Document
|
||||
from sse_starlette import EventSourceResponse
|
||||
|
||||
from chatchat.configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL,
|
||||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
|
||||
logger, log_verbose, )
|
||||
from chatchat.server.utils import BaseResponse, ListResponse, run_in_thread_pool
|
||||
from chatchat.server.db.repository.knowledge_file_repository import get_file_detail
|
||||
from chatchat.server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
|
||||
files2docs_in_thread, KnowledgeFile)
|
||||
from fastapi.responses import FileResponse
|
||||
from sse_starlette import EventSourceResponse
|
||||
import json
|
||||
from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from chatchat.server.db.repository.knowledge_file_repository import get_file_detail
|
||||
from langchain.docstore.document import Document
|
||||
from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory, get_kb_file_details
|
||||
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||
from typing import List, Dict
|
||||
from chatchat.server.utils import BaseResponse, ListResponse, run_in_thread_pool, check_embed_model
|
||||
|
||||
|
||||
def search_docs(
|
||||
@ -35,7 +37,8 @@ def search_docs(
|
||||
if kb is not None:
|
||||
if query:
|
||||
docs = kb.search_docs(query, top_k, score_threshold)
|
||||
data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
|
||||
# data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
|
||||
data = [DocumentWithVSId(**x.dict(), id=x.metadata.get("id")) for x in docs]
|
||||
elif file_name or metadata:
|
||||
data = kb.list_docs(file_name=file_name, metadata=metadata)
|
||||
for d in data:
|
||||
@ -55,8 +58,8 @@ def list_files(
|
||||
if kb is None:
|
||||
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
|
||||
else:
|
||||
all_doc_names = kb.list_files()
|
||||
return ListResponse(data=all_doc_names)
|
||||
all_docs = get_kb_file_details(knowledge_base_name)
|
||||
return ListResponse(data=all_docs)
|
||||
|
||||
|
||||
def _save_files_in_thread(files: List[UploadFile],
|
||||
@ -352,38 +355,42 @@ def recreate_vector_store(
|
||||
if not kb.exists() and not allow_empty_kb:
|
||||
yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
|
||||
else:
|
||||
if kb.exists():
|
||||
kb.clear_vs()
|
||||
kb.create_kb()
|
||||
files = list_files_from_folder(knowledge_base_name)
|
||||
kb_files = [(file, knowledge_base_name) for file in files]
|
||||
i = 0
|
||||
for status, result in files2docs_in_thread(kb_files,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
zh_title_enhance=zh_title_enhance):
|
||||
if status:
|
||||
kb_name, file_name, docs = result
|
||||
kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
|
||||
kb_file.splited_docs = docs
|
||||
yield json.dumps({
|
||||
"code": 200,
|
||||
"msg": f"({i + 1} / {len(files)}): {file_name}",
|
||||
"total": len(files),
|
||||
"finished": i + 1,
|
||||
"doc": file_name,
|
||||
}, ensure_ascii=False)
|
||||
kb.add_doc(kb_file, not_refresh_vs_cache=True)
|
||||
else:
|
||||
kb_name, file_name, error = result
|
||||
msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
|
||||
logger.error(msg)
|
||||
yield json.dumps({
|
||||
"code": 500,
|
||||
"msg": msg,
|
||||
})
|
||||
i += 1
|
||||
if not not_refresh_vs_cache:
|
||||
kb.save_vector_store()
|
||||
error_msg = f"could not recreate vector store because failed to access embed model."
|
||||
if not kb.check_embed_model(error_msg):
|
||||
yield {"code": 404, "msg": error_msg}
|
||||
else:
|
||||
if kb.exists():
|
||||
kb.clear_vs()
|
||||
kb.create_kb()
|
||||
files = list_files_from_folder(knowledge_base_name)
|
||||
kb_files = [(file, knowledge_base_name) for file in files]
|
||||
i = 0
|
||||
for status, result in files2docs_in_thread(kb_files,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
zh_title_enhance=zh_title_enhance):
|
||||
if status:
|
||||
kb_name, file_name, docs = result
|
||||
kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
|
||||
kb_file.splited_docs = docs
|
||||
yield json.dumps({
|
||||
"code": 200,
|
||||
"msg": f"({i + 1} / {len(files)}): {file_name}",
|
||||
"total": len(files),
|
||||
"finished": i + 1,
|
||||
"doc": file_name,
|
||||
}, ensure_ascii=False)
|
||||
kb.add_doc(kb_file, not_refresh_vs_cache=True)
|
||||
else:
|
||||
kb_name, file_name, error = result
|
||||
msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
|
||||
logger.error(msg)
|
||||
yield json.dumps({
|
||||
"code": 500,
|
||||
"msg": msg,
|
||||
})
|
||||
i += 1
|
||||
if not not_refresh_vs_cache:
|
||||
kb.save_vector_store()
|
||||
|
||||
return EventSourceResponse(output())
|
||||
|
||||
@ -5,6 +5,11 @@ import os
|
||||
from pathlib import Path
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
from typing import List, Union, Dict, Optional, Tuple
|
||||
|
||||
from chatchat.configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||
DEFAULT_EMBEDDING_MODEL, KB_INFO, logger)
|
||||
from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema
|
||||
from chatchat.server.db.repository.knowledge_base_repository import (
|
||||
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
|
||||
load_kb_from_db, get_kb_detail,
|
||||
@ -14,18 +19,12 @@ from chatchat.server.db.repository.knowledge_file_repository import (
|
||||
count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
|
||||
list_docs_from_db,
|
||||
)
|
||||
|
||||
from chatchat.configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||
DEFAULT_EMBEDDING_MODEL, KB_INFO)
|
||||
from chatchat.server.knowledge_base.utils import (
|
||||
get_kb_path, get_doc_path, KnowledgeFile,
|
||||
list_kbs_from_folder, list_files_from_folder,
|
||||
)
|
||||
|
||||
from typing import List, Union, Dict, Optional, Tuple
|
||||
|
||||
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||
from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema
|
||||
from chatchat.server.utils import check_embed_model as _check_embed_model
|
||||
|
||||
class SupportedVSType:
|
||||
FAISS = 'faiss'
|
||||
@ -41,10 +40,11 @@ class KBService(ABC):
|
||||
|
||||
def __init__(self,
|
||||
knowledge_base_name: str,
|
||||
kb_info: str = None,
|
||||
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
||||
):
|
||||
self.kb_name = knowledge_base_name
|
||||
self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
|
||||
self.kb_info = kb_info or KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
|
||||
self.embed_model = embed_model
|
||||
self.kb_path = get_kb_path(self.kb_name)
|
||||
self.doc_path = get_doc_path(self.kb_name)
|
||||
@ -59,6 +59,13 @@ class KBService(ABC):
|
||||
'''
|
||||
pass
|
||||
|
||||
def check_embed_model(self, error_msg: str) -> bool:
|
||||
if not _check_embed_model(self.embed_model):
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def create_kb(self):
|
||||
"""
|
||||
创建知识库
|
||||
@ -93,6 +100,9 @@ class KBService(ABC):
|
||||
向知识库添加文件
|
||||
如果指定了docs,则不再将文本向量化,并将数据库对应条目标为custom_docs=True
|
||||
"""
|
||||
if not self.check_embed_model(f"could not add docs because failed to access embed model."):
|
||||
return False
|
||||
|
||||
if docs:
|
||||
custom_docs = True
|
||||
else:
|
||||
@ -143,6 +153,9 @@ class KBService(ABC):
|
||||
使用content中的文件更新向量库
|
||||
如果指定了docs,则使用自定义docs,并将数据库对应条目标为custom_docs=True
|
||||
"""
|
||||
if not self.check_embed_model(f"could not update docs because failed to access embed model."):
|
||||
return False
|
||||
|
||||
if os.path.exists(kb_file.filepath):
|
||||
self.delete_doc(kb_file, **kwargs)
|
||||
return self.add_doc(kb_file, docs=docs, **kwargs)
|
||||
@ -162,6 +175,8 @@ class KBService(ABC):
|
||||
top_k: int = VECTOR_SEARCH_TOP_K,
|
||||
score_threshold: float = SCORE_THRESHOLD,
|
||||
) ->List[Document]:
|
||||
if not self.check_embed_model(f"could not search docs because failed to access embed model."):
|
||||
return []
|
||||
docs = self.do_search(query, top_k, score_threshold)
|
||||
return docs
|
||||
|
||||
@ -176,6 +191,9 @@ class KBService(ABC):
|
||||
传入参数为: {doc_id: Document, ...}
|
||||
如果对应 doc_id 的值为 None,或其 page_content 为空,则删除该文档
|
||||
'''
|
||||
if not self.check_embed_model(f"could not update docs because failed to access embed model."):
|
||||
return False
|
||||
|
||||
self.del_doc_by_ids(list(docs.keys()))
|
||||
docs = []
|
||||
ids = []
|
||||
@ -282,31 +300,32 @@ class KBServiceFactory:
|
||||
def get_service(kb_name: str,
|
||||
vector_store_type: Union[str, SupportedVSType],
|
||||
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
||||
kb_info: str = None,
|
||||
) -> KBService:
|
||||
if isinstance(vector_store_type, str):
|
||||
vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
|
||||
params = {"knowledge_base_name": kb_name, "embed_model": embed_model, "kb_info": kb_info}
|
||||
if SupportedVSType.FAISS == vector_store_type:
|
||||
from chatchat.server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
|
||||
return FaissKBService(kb_name, embed_model=embed_model)
|
||||
return FaissKBService(**params)
|
||||
elif SupportedVSType.PG == vector_store_type:
|
||||
from chatchat.server.knowledge_base.kb_service.pg_kb_service import PGKBService
|
||||
return PGKBService(kb_name, embed_model=embed_model)
|
||||
return PGKBService(**params)
|
||||
elif SupportedVSType.MILVUS == vector_store_type:
|
||||
from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||
return MilvusKBService(kb_name, embed_model=embed_model)
|
||||
return MilvusKBService(**params)
|
||||
elif SupportedVSType.ZILLIZ == vector_store_type:
|
||||
from chatchat.server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
|
||||
return ZillizKBService(kb_name, embed_model=embed_model)
|
||||
return ZillizKBService(**params)
|
||||
elif SupportedVSType.DEFAULT == vector_store_type:
|
||||
from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||
return MilvusKBService(kb_name,
|
||||
embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
|
||||
return MilvusKBService(**params) # other milvus parameters are set in model_config.kbs_config
|
||||
elif SupportedVSType.ES == vector_store_type:
|
||||
from chatchat.server.knowledge_base.kb_service.es_kb_service import ESKBService
|
||||
return ESKBService(kb_name, embed_model=embed_model)
|
||||
return ESKBService(**params)
|
||||
elif SupportedVSType.CHROMADB == vector_store_type:
|
||||
from chatchat.server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
|
||||
return ChromaKBService(kb_name, embed_model=embed_model)
|
||||
return ChromaKBService(**params)
|
||||
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
|
||||
from chatchat.server.knowledge_base.kb_service.default_kb_service import DefaultKBService
|
||||
return DefaultKBService(kb_name)
|
||||
|
||||
@ -9,6 +9,7 @@ from chatchat.configs import SCORE_THRESHOLD
|
||||
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||
from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
|
||||
from chatchat.server.utils import get_Embeddings
|
||||
from chatchat.server.file_rag.utils import get_Retriever
|
||||
|
||||
|
||||
def _get_result_to_documents(get_result: GetResult) -> List[Document]:
|
||||
@ -75,10 +76,13 @@ class ChromaKBService(KBService):
|
||||
|
||||
def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD) -> List[
|
||||
Tuple[Document, float]]:
|
||||
embed_func = get_Embeddings(self.embed_model)
|
||||
embeddings = embed_func.embed_query(query)
|
||||
query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k)
|
||||
return _results_to_docs_and_scores(query_result)
|
||||
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||
self.collection,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
docs = retriever.get_relevant_documents(query)
|
||||
return docs
|
||||
|
||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||
doc_infos = []
|
||||
|
||||
@ -8,6 +8,7 @@ from chatchat.server.knowledge_base.utils import KnowledgeFile
|
||||
from chatchat.server.utils import get_Embeddings
|
||||
from elasticsearch import Elasticsearch, BadRequestError
|
||||
from chatchat.configs import kbs_config, KB_ROOT_PATH
|
||||
from chatchat.server.file_rag.utils import get_Retriever
|
||||
|
||||
import logging
|
||||
|
||||
@ -107,8 +108,12 @@ class ESKBService(KBService):
|
||||
|
||||
def do_search(self, query:str, top_k: int, score_threshold: float):
|
||||
# 文本相似性检索
|
||||
docs = self.db.similarity_search_with_score(query=query,
|
||||
k=top_k)
|
||||
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||
self.db,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
docs = retriever.get_relevant_documents(query)
|
||||
return docs
|
||||
|
||||
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
|
||||
|
||||
@ -5,9 +5,9 @@ from chatchat.configs import SCORE_THRESHOLD
|
||||
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||
from chatchat.server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
|
||||
from chatchat.server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
|
||||
from chatchat.server.utils import get_Embeddings
|
||||
from langchain.docstore.document import Document
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from typing import List, Dict, Tuple
|
||||
from chatchat.server.file_rag.utils import get_Retriever
|
||||
|
||||
|
||||
class FaissKBService(KBService):
|
||||
@ -62,10 +62,13 @@ class FaissKBService(KBService):
|
||||
top_k: int,
|
||||
score_threshold: float = SCORE_THRESHOLD,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
embed_func = get_Embeddings(self.embed_model)
|
||||
embeddings = embed_func.embed_query(query)
|
||||
with self.load_vector_store().acquire() as vs:
|
||||
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
|
||||
retriever = get_Retriever("ensemble").from_vectorstore(
|
||||
vs,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
docs = retriever.get_relevant_documents(query)
|
||||
return docs
|
||||
|
||||
def do_add_doc(self,
|
||||
|
||||
@ -10,7 +10,7 @@ from chatchat.server.db.repository import list_file_num_docs_id_by_kb_name_and_f
|
||||
from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedVSType, \
|
||||
score_threshold_process
|
||||
from chatchat.server.knowledge_base.utils import KnowledgeFile
|
||||
from chatchat.server.utils import get_Embeddings
|
||||
from chatchat.server.file_rag.utils import get_Retriever
|
||||
|
||||
|
||||
class MilvusKBService(KBService):
|
||||
@ -67,10 +67,16 @@ class MilvusKBService(KBService):
|
||||
|
||||
def do_search(self, query: str, top_k: int, score_threshold: float):
|
||||
self._load_milvus()
|
||||
embed_func = get_Embeddings(self.embed_model)
|
||||
embeddings = embed_func.embed_query(query)
|
||||
docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
|
||||
return score_threshold_process(score_threshold, top_k, docs)
|
||||
# embed_func = get_Embeddings(self.embed_model)
|
||||
# embeddings = embed_func.embed_query(query)
|
||||
# docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
|
||||
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||
self.milvus,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
docs = retriever.get_relevant_documents(query)
|
||||
return docs
|
||||
|
||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||
for doc in docs:
|
||||
|
||||
@ -15,6 +15,7 @@ import shutil
|
||||
import sqlalchemy
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.orm import Session
|
||||
from chatchat.server.file_rag.utils import get_Retriever
|
||||
|
||||
|
||||
class PGKBService(KBService):
|
||||
@ -60,10 +61,13 @@ class PGKBService(KBService):
|
||||
shutil.rmtree(self.kb_path)
|
||||
|
||||
def do_search(self, query: str, top_k: int, score_threshold: float):
|
||||
embed_func = get_Embeddings(self.embed_model)
|
||||
embeddings = embed_func.embed_query(query)
|
||||
docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k)
|
||||
return score_threshold_process(score_threshold, top_k, docs)
|
||||
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||
self.pg_vector,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
docs = retriever.get_relevant_documents(query)
|
||||
return docs
|
||||
|
||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||
ids = self.pg_vector.add_documents(docs)
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import List, Dict, Optional
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from typing import List, Dict
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores import Zilliz
|
||||
from chatchat.configs import kbs_config
|
||||
@ -7,6 +6,7 @@ from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedV
|
||||
score_threshold_process
|
||||
from chatchat.server.knowledge_base.utils import KnowledgeFile
|
||||
from chatchat.server.utils import get_Embeddings
|
||||
from chatchat.server.file_rag.utils import get_Retriever
|
||||
|
||||
|
||||
class ZillizKBService(KBService):
|
||||
@ -60,10 +60,13 @@ class ZillizKBService(KBService):
|
||||
|
||||
def do_search(self, query: str, top_k: int, score_threshold: float):
|
||||
self._load_zilliz()
|
||||
embed_func = get_Embeddings(self.embed_model)
|
||||
embeddings = embed_func.embed_query(query)
|
||||
docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k)
|
||||
return score_threshold_process(score_threshold, top_k, docs)
|
||||
retriever = get_Retriever("vectorstore").from_vectorstore(
|
||||
self.zilliz,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
docs = retriever.get_relevant_documents(query)
|
||||
return docs
|
||||
|
||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||
for doc in docs:
|
||||
|
||||
@ -42,55 +42,59 @@ def recreate_summary_vector_store(
|
||||
if not kb.exists() and not allow_empty_kb:
|
||||
yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
|
||||
else:
|
||||
# 重新创建知识库
|
||||
kb_summary = KBSummaryService(knowledge_base_name, embed_model)
|
||||
kb_summary.drop_kb_summary()
|
||||
kb_summary.create_kb_summary()
|
||||
error_msg = f"could not recreate summary vector store because failed to access embed model."
|
||||
if not kb.check_embed_model(error_msg):
|
||||
yield {"code": 404, "msg": error_msg}
|
||||
else:
|
||||
# 重新创建知识库
|
||||
kb_summary = KBSummaryService(knowledge_base_name, embed_model)
|
||||
kb_summary.drop_kb_summary()
|
||||
kb_summary.create_kb_summary()
|
||||
|
||||
llm = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
local_wrap=True,
|
||||
)
|
||||
reduce_llm = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
local_wrap=True,
|
||||
)
|
||||
# 文本摘要适配器
|
||||
summary = SummaryAdapter.form_summary(llm=llm,
|
||||
reduce_llm=reduce_llm,
|
||||
overlap_size=OVERLAP_SIZE)
|
||||
files = list_files_from_folder(knowledge_base_name)
|
||||
llm = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
local_wrap=True,
|
||||
)
|
||||
reduce_llm = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
local_wrap=True,
|
||||
)
|
||||
# 文本摘要适配器
|
||||
summary = SummaryAdapter.form_summary(llm=llm,
|
||||
reduce_llm=reduce_llm,
|
||||
overlap_size=OVERLAP_SIZE)
|
||||
files = list_files_from_folder(knowledge_base_name)
|
||||
|
||||
i = 0
|
||||
for i, file_name in enumerate(files):
|
||||
i = 0
|
||||
for i, file_name in enumerate(files):
|
||||
|
||||
doc_infos = kb.list_docs(file_name=file_name)
|
||||
docs = summary.summarize(file_description=file_description,
|
||||
docs=doc_infos)
|
||||
doc_infos = kb.list_docs(file_name=file_name)
|
||||
docs = summary.summarize(file_description=file_description,
|
||||
docs=doc_infos)
|
||||
|
||||
status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
|
||||
if status_kb_summary:
|
||||
logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成")
|
||||
yield json.dumps({
|
||||
"code": 200,
|
||||
"msg": f"({i + 1} / {len(files)}): {file_name}",
|
||||
"total": len(files),
|
||||
"finished": i + 1,
|
||||
"doc": file_name,
|
||||
}, ensure_ascii=False)
|
||||
else:
|
||||
status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
|
||||
if status_kb_summary:
|
||||
logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成")
|
||||
yield json.dumps({
|
||||
"code": 200,
|
||||
"msg": f"({i + 1} / {len(files)}): {file_name}",
|
||||
"total": len(files),
|
||||
"finished": i + 1,
|
||||
"doc": file_name,
|
||||
}, ensure_ascii=False)
|
||||
else:
|
||||
|
||||
msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
|
||||
logger.error(msg)
|
||||
yield json.dumps({
|
||||
"code": 500,
|
||||
"msg": msg,
|
||||
})
|
||||
i += 1
|
||||
msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
|
||||
logger.error(msg)
|
||||
yield json.dumps({
|
||||
"code": 500,
|
||||
"msg": msg,
|
||||
})
|
||||
i += 1
|
||||
|
||||
return EventSourceResponse(output())
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from chatchat.configs import (
|
||||
DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
|
||||
CHUNK_SIZE, OVERLAP_SIZE
|
||||
CHUNK_SIZE, OVERLAP_SIZE, logger, log_verbose
|
||||
)
|
||||
from chatchat.server.knowledge_base.utils import (
|
||||
get_file_path, list_kbs_from_folder,
|
||||
|
||||
@ -10,7 +10,7 @@ from chatchat.configs import (
|
||||
TEXT_SPLITTER_NAME,
|
||||
)
|
||||
import importlib
|
||||
from chatchat.server.text_splitter import zh_title_enhance as func_zh_title_enhance
|
||||
from chatchat.server.file_rag.text_splitter import zh_title_enhance as func_zh_title_enhance
|
||||
import langchain_community.document_loaders
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.text_splitter import TextSplitter, MarkdownHeaderTextSplitter
|
||||
|
||||
@ -244,6 +244,16 @@ def get_Embeddings(
|
||||
logger.error(f"failed to create Embeddings for model: {embed_model}.", exc_info=True)
|
||||
|
||||
|
||||
def check_embed_model(embed_model: str=DEFAULT_EMBEDDING_MODEL) -> bool:
|
||||
embeddings = get_Embeddings(embed_model=embed_model)
|
||||
try:
|
||||
embeddings.embed_query("this is a test")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"failed to access embed model '{embed_model}': {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
def get_OpenAIClient(
|
||||
platform_name: str = None,
|
||||
model_name: str = None,
|
||||
@ -696,7 +706,7 @@ def get_temp_dir(id: str = None) -> Tuple[str, str]:
|
||||
'''
|
||||
创建一个临时目录,返回(路径,文件夹名称)
|
||||
'''
|
||||
from chatchat.configs.basic_config import BASE_TEMP_DIR
|
||||
from chatchat.configs import BASE_TEMP_DIR
|
||||
import uuid
|
||||
|
||||
if id is not None: # 如果指定的临时目录已存在,直接返回
|
||||
|
||||
@ -24,28 +24,28 @@ chat_box = ChatBox(
|
||||
|
||||
|
||||
def save_session():
|
||||
'''save session state to chat context'''
|
||||
"""save session state to chat context"""
|
||||
chat_box.context_from_session(exclude=["selected_page", "prompt", "cur_conv_name"])
|
||||
|
||||
|
||||
def restore_session():
|
||||
'''restore sesstion state from chat context'''
|
||||
"""restore sesstion state from chat context"""
|
||||
chat_box.context_to_session(exclude=["selected_page", "prompt", "cur_conv_name"])
|
||||
|
||||
|
||||
def rerun():
|
||||
'''
|
||||
"""
|
||||
save chat context before rerun
|
||||
'''
|
||||
"""
|
||||
save_session()
|
||||
st.rerun()
|
||||
|
||||
|
||||
def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
|
||||
'''
|
||||
"""
|
||||
返回消息历史。
|
||||
content_in_expander控制是否返回expander元素中的内容,一般导出的时候可以选上,传入LLM的history不需要
|
||||
'''
|
||||
"""
|
||||
|
||||
def filter(msg):
|
||||
content = [x for x in msg["elements"] if x._output_method in ["markdown", "text"]]
|
||||
@ -66,10 +66,10 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) ->
|
||||
|
||||
@st.cache_data
|
||||
def upload_temp_docs(files, _api: ApiRequest) -> str:
|
||||
'''
|
||||
"""
|
||||
将文件上传到临时目录,用于文件对话
|
||||
返回临时向量库ID
|
||||
'''
|
||||
"""
|
||||
return _api.upload_temp_docs(files).get("data", {}).get("id")
|
||||
|
||||
|
||||
@ -157,11 +157,13 @@ def dialogue_page(
|
||||
tools = list_tools(api)
|
||||
tool_names = ["None"] + list(tools)
|
||||
if use_agent:
|
||||
# selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", check_all=True, key="selected_tools")
|
||||
# selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具",
|
||||
# check_all=True, key="selected_tools")
|
||||
selected_tools = st.multiselect("选择工具", list(tools), format_func=lambda x: tools[x]["title"],
|
||||
key="selected_tools")
|
||||
else:
|
||||
# selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具", key="selected_tool")
|
||||
# selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具",
|
||||
# key="selected_tool")
|
||||
selected_tool = st.selectbox("选择工具", tool_names,
|
||||
format_func=lambda x: tools.get(x, {"title": "None"})["title"],
|
||||
key="selected_tool")
|
||||
@ -338,7 +340,7 @@ def dialogue_page(
|
||||
elif d.status == AgentStatus.agent_finish:
|
||||
text = d.choices[0].delta.content or ""
|
||||
chat_box.update_msg(text.replace("\n", "\n\n"))
|
||||
elif d.status == None: # not agent chat
|
||||
elif d.status is None: # not agent chat
|
||||
if getattr(d, "is_ref", False):
|
||||
chat_box.insert_msg(Markdown(d.choices[0].delta.content or "", in_expander=True, state="complete",
|
||||
title="参考资料"))
|
||||
|
||||
@ -10,16 +10,18 @@ packages = [
|
||||
|
||||
[tool.poetry.scripts]
|
||||
chatchat = 'chatchat.startup:main'
|
||||
chatchat-kb = 'chatchat.init_database:main'
|
||||
chatchat-config = 'chatchat.config_work_space:main'
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<3.12,!=3.9.7"
|
||||
model-providers = "^0.3.0"
|
||||
langchain = "0.1.5"
|
||||
langchain = "0.1.17"
|
||||
langchainhub = "0.1.14"
|
||||
langchain-community = "0.0.17"
|
||||
langchain-community = "0.0.36"
|
||||
langchain-openai = "0.0.5"
|
||||
langchain-experimental = "0.0.50"
|
||||
fastapi = "0.109.2"
|
||||
langchain-experimental = "0.0.58"
|
||||
fastapi = "~0.109.2"
|
||||
sse_starlette = "~1.8.2"
|
||||
nltk = "~3.8.1"
|
||||
uvicorn = ">=0.27.0.post1"
|
||||
@ -27,6 +29,8 @@ unstructured = "~0.11.0"
|
||||
python-magic-bin = {version = "*", platform = "win32"}
|
||||
SQLAlchemy = "~2.0.25"
|
||||
faiss-cpu = "~1.7.4"
|
||||
cutword = "0.1.0"
|
||||
rank_bm25 = "0.2.2"
|
||||
# accelerate = "~0.24.1"
|
||||
# spacy = "~3.7.2"
|
||||
PyMuPDF = "~1.23.16"
|
||||
@ -34,28 +38,29 @@ rapidocr_onnxruntime = "~1.3.8"
|
||||
requests = "~2.31.0"
|
||||
pathlib = "~1.0.1"
|
||||
pytest = "~7.4.3"
|
||||
pyjwt = "2.8.0"
|
||||
pyjwt = "~2.8.0"
|
||||
elasticsearch = "*"
|
||||
numexpr = ">=1" #test
|
||||
strsimpy = ">=0.2.1"
|
||||
markdownify = ">=0.11.6"
|
||||
tqdm = ">=4.66.1"
|
||||
websockets = ">=12.0"
|
||||
numpy = "1.24.4"
|
||||
numpy = "~1.24.4"
|
||||
pandas = "~1" # test
|
||||
pydantic = "2.6.4"
|
||||
pydantic = "~2.6.4"
|
||||
httpx = {version = ">=0.25.2", extras = ["brotli", "http2", "socks"]}
|
||||
python-multipart = "0.0.9"
|
||||
# webui
|
||||
streamlit = "1.34.0"
|
||||
streamlit-option-menu = "0.3.12"
|
||||
streamlit-antd-components = "0.3.1"
|
||||
streamlit-chatbox = "1.1.12"
|
||||
streamlit-chatbox = "1.1.12.post2"
|
||||
streamlit-modal = "0.1.0"
|
||||
streamlit-aggrid = "0.3.4.post3"
|
||||
streamlit-extras = "0.4.2"
|
||||
xinference_client = { version = "^0.11.1", optional = true }
|
||||
zhipuai = { version = "^2.1.0", optional = true }
|
||||
pymysql = "^1.1.0"
|
||||
|
||||
[tool.poetry.extras]
|
||||
xinference = ["xinference_client"]
|
||||
@ -254,7 +259,7 @@ ignore-words-list = 'momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy
|
||||
|
||||
[tool.poetry.plugins.dotenv]
|
||||
ignore = "false"
|
||||
location = ".env"
|
||||
dotenv = "dotenv:plugin"
|
||||
|
||||
|
||||
# https://python-poetry.org/docs/repositories/
|
||||
|
||||
90
libs/chatchat-server/tests/conftest.py
Normal file
90
libs/chatchat-server/tests/conftest.py
Normal file
@ -0,0 +1,90 @@
|
||||
"""Configuration for unit tests."""
|
||||
import logging
|
||||
from importlib import util
|
||||
from typing import Dict, List, Sequence
|
||||
|
||||
import pytest
|
||||
from pytest import Config, Function, Parser
|
||||
|
||||
|
||||
def pytest_addoption(parser: Parser) -> None:
|
||||
"""Add custom command line options to pytest."""
|
||||
parser.addoption(
|
||||
"--only-extended",
|
||||
action="store_true",
|
||||
help="Only run extended tests. Does not allow skipping any extended tests.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--only-core",
|
||||
action="store_true",
|
||||
help="Only run core tests. Never runs any extended tests.",
|
||||
)
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
|
||||
"""Add implementations for handling custom markers.
|
||||
|
||||
At the moment, this adds support for a custom `requires` marker.
|
||||
|
||||
The `requires` marker is used to denote tests that require one or more packages
|
||||
to be installed to run. If the package is not installed, the test is skipped.
|
||||
|
||||
The `requires` marker syntax is:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@pytest.mark.requires("package1", "package2")
|
||||
def test_something():
|
||||
...
|
||||
"""
|
||||
# Mapping from the name of a package to whether it is installed or not.
|
||||
# Used to avoid repeated calls to `util.find_spec`
|
||||
required_pkgs_info: Dict[str, bool] = {}
|
||||
|
||||
only_extended = config.getoption("--only-extended") or False
|
||||
only_core = config.getoption("--only-core") or False
|
||||
|
||||
if only_extended and only_core:
|
||||
raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")
|
||||
|
||||
for item in items:
|
||||
requires_marker = item.get_closest_marker("requires")
|
||||
if requires_marker is not None:
|
||||
if only_core:
|
||||
item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
|
||||
continue
|
||||
|
||||
# Iterate through the list of required packages
|
||||
required_pkgs = requires_marker.args
|
||||
for pkg in required_pkgs:
|
||||
# If we haven't yet checked whether the pkg is installed
|
||||
# let's check it and store the result.
|
||||
if pkg not in required_pkgs_info:
|
||||
try:
|
||||
installed = util.find_spec(pkg) is not None
|
||||
except Exception:
|
||||
installed = False
|
||||
required_pkgs_info[pkg] = installed
|
||||
|
||||
if not required_pkgs_info[pkg]:
|
||||
if only_extended:
|
||||
pytest.fail(
|
||||
f"Package `{pkg}` is not installed but is required for "
|
||||
f"extended tests. Please install the given package and "
|
||||
f"try again.",
|
||||
)
|
||||
|
||||
else:
|
||||
# If the package is not installed, we immediately break
|
||||
# and mark the test as skipped.
|
||||
item.add_marker(
|
||||
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
|
||||
)
|
||||
break
|
||||
else:
|
||||
if only_extended:
|
||||
item.add_marker(
|
||||
pytest.mark.skip(reason="Skipping not an extended test.")
|
||||
)
|
||||
|
||||
|
||||
56
libs/chatchat-server/tests/unit_tests/config/test_config.py
Normal file
56
libs/chatchat-server/tests/unit_tests/config/test_config.py
Normal file
@ -0,0 +1,56 @@
|
||||
from pathlib import Path
|
||||
|
||||
from chatchat.configs import ConfigBasicFactory, ConfigBasic, ConfigWorkSpace
|
||||
import os
|
||||
|
||||
|
||||
def test_config_factory_def():
|
||||
test_config_factory = ConfigBasicFactory()
|
||||
config: ConfigBasic = test_config_factory.get_config()
|
||||
assert config is not None
|
||||
assert config.log_verbose is False
|
||||
assert config.CHATCHAT_ROOT is not None
|
||||
assert config.DATA_PATH is not None
|
||||
assert config.IMG_DIR is not None
|
||||
assert config.NLTK_DATA_PATH is not None
|
||||
assert config.LOG_FORMAT is not None
|
||||
assert config.LOG_PATH is not None
|
||||
assert config.MEDIA_PATH is not None
|
||||
|
||||
assert os.path.exists(os.path.join(config.MEDIA_PATH, "image"))
|
||||
assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio"))
|
||||
assert os.path.exists(os.path.join(config.MEDIA_PATH, "video"))
|
||||
|
||||
|
||||
def test_workspace():
|
||||
config_workspace = ConfigWorkSpace()
|
||||
assert config_workspace.get_config() is not None
|
||||
base_root = os.path.join(Path(__file__).absolute().parent, "chatchat")
|
||||
config_workspace.set_data_path(os.path.join(base_root, "data"))
|
||||
config_workspace.set_log_verbose(True)
|
||||
config_workspace.set_log_format(" %(message)s")
|
||||
|
||||
config: ConfigBasic = config_workspace.get_config()
|
||||
assert config.log_verbose is True
|
||||
assert config.DATA_PATH == os.path.join(base_root, "data")
|
||||
assert config.IMG_DIR is not None
|
||||
assert config.NLTK_DATA_PATH == os.path.join(base_root, "data", "nltk_data")
|
||||
assert config.LOG_FORMAT == " %(message)s"
|
||||
assert config.LOG_PATH == os.path.join(base_root, "data", "logs")
|
||||
assert config.MEDIA_PATH == os.path.join(base_root, "data", "media")
|
||||
|
||||
assert os.path.exists(os.path.join(config.MEDIA_PATH, "image"))
|
||||
assert os.path.exists(os.path.join(config.MEDIA_PATH, "audio"))
|
||||
assert os.path.exists(os.path.join(config.MEDIA_PATH, "video"))
|
||||
config_workspace.clear()
|
||||
|
||||
|
||||
def test_workspace_default():
|
||||
from chatchat.configs import (log_verbose, DATA_PATH, IMG_DIR, NLTK_DATA_PATH, LOG_FORMAT, LOG_PATH, MEDIA_PATH)
|
||||
assert log_verbose is False
|
||||
assert DATA_PATH is not None
|
||||
assert IMG_DIR is not None
|
||||
assert NLTK_DATA_PATH is not None
|
||||
assert LOG_FORMAT is not None
|
||||
assert LOG_PATH is not None
|
||||
assert MEDIA_PATH is not None
|
||||
1
libs/model-providers/.env
Normal file
1
libs/model-providers/.env
Normal file
@ -0,0 +1 @@
|
||||
PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring
|
||||
@ -1,26 +1,26 @@
|
||||
import redis
|
||||
from redis.connection import Connection, SSLConnection
|
||||
|
||||
redis_client = redis.Redis()
|
||||
|
||||
|
||||
def init_app(app):
|
||||
connection_class = Connection
|
||||
if app.config.get("REDIS_USE_SSL", False):
|
||||
connection_class = SSLConnection
|
||||
|
||||
redis_client.connection_pool = redis.ConnectionPool(
|
||||
**{
|
||||
"host": app.config.get("REDIS_HOST", "localhost"),
|
||||
"port": app.config.get("REDIS_PORT", 6379),
|
||||
"username": app.config.get("REDIS_USERNAME", None),
|
||||
"password": app.config.get("REDIS_PASSWORD", None),
|
||||
"db": app.config.get("REDIS_DB", 0),
|
||||
"encoding": "utf-8",
|
||||
"encoding_errors": "strict",
|
||||
"decode_responses": False,
|
||||
},
|
||||
connection_class=connection_class,
|
||||
)
|
||||
|
||||
app.extensions["redis"] = redis_client
|
||||
# import redis
|
||||
# from redis.connection import Connection, SSLConnection
|
||||
#
|
||||
# redis_client = redis.Redis()
|
||||
#
|
||||
#
|
||||
# def init_app(app):
|
||||
# connection_class = Connection
|
||||
# if app.config.get("REDIS_USE_SSL", False):
|
||||
# connection_class = SSLConnection
|
||||
#
|
||||
# redis_client.connection_pool = redis.ConnectionPool(
|
||||
# **{
|
||||
# "host": app.config.get("REDIS_HOST", "localhost"),
|
||||
# "port": app.config.get("REDIS_PORT", 6379),
|
||||
# "username": app.config.get("REDIS_USERNAME", None),
|
||||
# "password": app.config.get("REDIS_PASSWORD", None),
|
||||
# "db": app.config.get("REDIS_DB", 0),
|
||||
# "encoding": "utf-8",
|
||||
# "encoding_errors": "strict",
|
||||
# "decode_responses": False,
|
||||
# },
|
||||
# connection_class=connection_class,
|
||||
# )
|
||||
#
|
||||
# app.extensions["redis"] = redis_client
|
||||
|
||||
@ -7,20 +7,19 @@ readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<3.12,!=3.9.7"
|
||||
transformers = "4.31.0"
|
||||
fastapi = "^0.109.2"
|
||||
transformers = "~4.31.0"
|
||||
fastapi = "~0.109.2"
|
||||
uvicorn = ">=0.27.0.post1"
|
||||
sse-starlette = "^1.8.2"
|
||||
pyyaml = "6.0.1"
|
||||
pydantic = "2.6.4"
|
||||
redis = "4.5.4"
|
||||
sse-starlette = "~1.8.2"
|
||||
pyyaml = "~6.0.1"
|
||||
pydantic ="~2.6.4"
|
||||
# config manage
|
||||
omegaconf = "2.0.6"
|
||||
omegaconf = "~2.0.6"
|
||||
# modle_runtime
|
||||
openai = "1.13.3"
|
||||
tiktoken = "0.5.2"
|
||||
pydub = "0.25.1"
|
||||
boto3 = "1.28.17"
|
||||
openai = "~1.13.3"
|
||||
tiktoken = "~0.5.2"
|
||||
pydub = "~0.25.1"
|
||||
boto3 = "~1.28.17"
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
# The only dependencies that should be added are
|
||||
@ -207,7 +206,7 @@ ignore-words-list = 'momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy
|
||||
|
||||
[tool.poetry.plugins.dotenv]
|
||||
ignore = "false"
|
||||
location = ".env"
|
||||
dotenv = "dotenv:plugin"
|
||||
|
||||
# https://python-poetry.org/docs/repositories/
|
||||
[[tool.poetry.source]]
|
||||
|
||||
@ -25,7 +25,19 @@ langchain-chatchat = { path = "libs/chatchat-server", develop = true }
|
||||
ipykernel = "^6.29.2"
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
|
||||
pytest = "^7.3.0"
|
||||
pytest-cov = "^4.0.0"
|
||||
pytest-dotenv = "^0.5.2"
|
||||
duckdb-engine = "^0.9.2"
|
||||
pytest-watcher = "^0.2.6"
|
||||
freezegun = "^1.2.2"
|
||||
responses = "^0.22.0"
|
||||
pytest-asyncio = "^0.23.2"
|
||||
lark = "^1.1.5"
|
||||
pytest-mock = "^3.10.0"
|
||||
pytest-socket = "^0.6.0"
|
||||
syrupy = "^4.0.2"
|
||||
requests-mock = "^1.11.0"
|
||||
|
||||
[tool.ruff]
|
||||
extend-include = ["*.ipynb"]
|
||||
@ -48,11 +60,11 @@ extend-exclude = [
|
||||
|
||||
[tool.poetry.plugins.dotenv]
|
||||
ignore = "false"
|
||||
location = ".env"
|
||||
dotenv = "dotenv:plugin"
|
||||
|
||||
|
||||
# https://python-poetry.org/docs/repositories/
|
||||
[[tool.poetry.source]]
|
||||
name = "tsinghua"
|
||||
url = "https://pypi.tuna.tsinghua.edu.cn/simple/"
|
||||
priority = "default"
|
||||
priority = "default"
|
||||
|
||||
@ -133,6 +133,8 @@ model_names = list(regs[model_type].keys())
|
||||
model_name = cols[1].selectbox("模型名称:", model_names)
|
||||
|
||||
cur_reg = regs[model_type][model_name]["reg"]
|
||||
model_format = None
|
||||
model_quant = None
|
||||
|
||||
if model_type == "LLM":
|
||||
cur_family = xf_llm.LLMFamilyV1.parse_obj(cur_reg)
|
||||
@ -217,7 +219,7 @@ cols = st.columns(3)
|
||||
|
||||
if cols[0].button("设置模型缓存"):
|
||||
if os.path.isabs(local_path) and os.path.isdir(local_path):
|
||||
cur_spec.model_uri = local_path
|
||||
cur_spec.__dict__["model_uri"] = local_path # embedding spec has no attribute model_uri
|
||||
if os.path.isdir(cache_dir):
|
||||
os.rmdir(cache_dir)
|
||||
if model_type == "LLM":
|
||||
@ -250,10 +252,10 @@ if cols[2].button("注册为自定义模型"):
|
||||
if model_type == "LLM":
|
||||
cur_family.model_name = f"{cur_family.model_name}{model_name_suffix}"
|
||||
cur_family.model_family = "other"
|
||||
model_definition = cur_family.json(indent=2, ensure_ascii=False)
|
||||
model_definition = cur_family.model_dump_json(indent=2, ensure_ascii=False)
|
||||
else:
|
||||
cur_spec.model_name = f"{cur_spec.model_name}{model_name_suffix}"
|
||||
model_definition = cur_spec.json(indent=2, ensure_ascii=False)
|
||||
model_definition = cur_spec.model_dump_json(indent=2, ensure_ascii=False)
|
||||
client.register_model(
|
||||
model_type=model_type,
|
||||
model=model_definition,
|
||||
@ -262,4 +264,3 @@ if cols[2].button("注册为自定义模型"):
|
||||
st.rerun()
|
||||
else:
|
||||
st.error("必须输入存在的绝对路径")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user