diff --git a/.gitignore b/.gitignore
index b5918eeb..a7ef90f8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,5 +3,7 @@
logs
.idea/
__pycache__/
-knowledge_base/
+/knowledge_base/
configs/*.py
+.vscode/
+.pytest_cache/
diff --git a/README.md b/README.md
index c5c2d687..577e7f7b 100644
--- a/README.md
+++ b/README.md
@@ -14,8 +14,8 @@
* [2. 下载模型至本地](README.md#2.-下载模型至本地)
* [3. 设置配置项](README.md#3.-设置配置项)
* [4. 知识库初始化与迁移](README.md#4.-知识库初始化与迁移)
- * [5. 启动 API 服务或 Web UI](README.md#5.-启动-API-服务或-Web-UI)
- * [6. 一键启动](README.md#6.-一键启动)
+ * [5. 一键启动API服务或WebUI服务](README.md#6.-一键启动)
+ * [6. 分步启动 API 服务或 Web UI](README.md#5.-启动-API-服务或-Web-UI)
* [常见问题](README.md#常见问题)
* [路线图](README.md#路线图)
* [项目交流群](README.md#项目交流群)
@@ -226,9 +226,93 @@ embedding_model_dict = {
$ python init_database.py --recreate-vs
```
-### 5. 启动 API 服务或 Web UI
+### 5. 一键启动API 服务或 Web UI
-#### 5.1 启动 LLM 服务
+#### 5.1 启动命令
+
+一键启动脚本 startup.py,一键启动所有 Fastchat 服务、API 服务、WebUI 服务,示例代码:
+
+```shell
+$ python startup.py -a
+```
+
+并可使用 `Ctrl + C` 直接关闭所有运行服务。如果一次结束不了,可以多按几次。
+
+可选参数包括 `-a (或--all-webui)`, `--all-api`, `--llm-api`, `-c (或--controller)`, `--openai-api`,
+`-m (或--model-worker)`, `--api`, `--webui`,其中:
+
+- `--all-webui` 为一键启动 WebUI 所有依赖服务;
+- `--all-api` 为一键启动 API 所有依赖服务;
+- `--llm-api` 为一键启动 Fastchat 所有依赖的 LLM 服务;
+- `--openai-api` 为仅启动 FastChat 的 controller 和 openai-api-server 服务;
+- 其他为单独服务启动选项。
+
+#### 5.2 启动非默认模型
+
+若想指定非默认模型,需要用 `--model-name` 选项,示例:
+
+```shell
+$ python startup.py --all-webui --model-name Qwen-7B-Chat
+```
+
+更多信息可通过 `python startup.py -h`查看。
+
+#### 5.3 多卡加载
+
+项目支持多卡加载,需在 startup.py 中的 create_model_worker_app 函数中,修改如下三个参数:
+
+```python
+gpus=None,
+num_gpus=1,
+max_gpu_memory="20GiB"
+```
+
+其中,`gpus` 控制使用的显卡的ID,例如 "0,1";
+
+`num_gpus` 控制使用的卡数;
+
+`max_gpu_memory` 控制每个卡使用的显存容量。
+
+注1:server_config.py的FSCHAT_MODEL_WORKERS字典中也增加了相关配置,如有需要也可通过修改FSCHAT_MODEL_WORKERS字典中对应参数实现多卡加载。
+
+注2:少数情况下,gpus参数会不生效,此时需要通过设置环境变量CUDA_VISIBLE_DEVICES来指定torch可见的gpu,示例代码:
+
+```shell
+CUDA_VISIBLE_DEVICES=0,1 python startup.py -a
+```
+
+#### 5.4 PEFT 加载(包括lora,p-tuning,prefix tuning, prompt tuning,ia3等)
+
+本项目基于 FastChat 加载 LLM 服务,故需以 FastChat 加载 PEFT 路径,即保证路径名称里必须有 peft 这个词,配置文件的名字为 adapter_config.json,peft 路径下包含.bin 格式的 PEFT 权重,peft路径在startup.py中create_model_worker_app函数的args.model_names中指定,并开启环境变量PEFT_SHARE_BASE_WEIGHTS=true参数。
+
+注:如果上述方式启动失败,则需要以标准的fastchat服务启动方式分步启动,分步启动步骤参考第六节,PEFT加载详细步骤参考[加载lora微调后模型失效](https://github.com/chatchat-space/Langchain-Chatchat/issues/1130#issuecomment-1685291822),
+
+#### **5.5 注意事项:**
+
+**1. startup 脚本用多进程方式启动各模块的服务,可能会导致打印顺序问题,请等待全部服务发起后再调用,并根据默认或指定端口调用服务(默认 LLM API 服务端口:`127.0.0.1:8888`,默认 API 服务端口:`127.0.0.1:7861`,默认 WebUI 服务端口:`本机IP:8501`)**
+
+**2.服务启动时间示设备不同而不同,约 3-10 分钟,如长时间没有启动请前往 `./logs`目录下监控日志,定位问题。**
+
+**3. 在Linux上使用ctrl+C退出可能会由于linux的多进程机制导致multiprocessing遗留孤儿进程,可通过shutdown_all.sh进行退出**
+
+#### 5.6 启动界面示例:
+
+1. FastAPI docs 界面
+
+
+
+2. webui启动界面示例:
+
+- Web UI 对话界面:
+ 
+- Web UI 知识库管理页面:
+ 
+
+### 6 分步启动 API 服务或 Web UI
+
+注意:如使用了一键启动方式,可忽略本节。
+
+#### 6.1 启动 LLM 服务
如需使用开源模型进行本地部署,需首先启动 LLM 服务,启动方式分为三种:
@@ -240,7 +324,7 @@ embedding_model_dict = {
如果启动在线的API服务(如 OPENAI 的 API 接口),则无需启动 LLM 服务,即 5.1 小节的任何命令均无需启动。
-##### 5.1.1 基于多进程脚本 llm_api.py 启动 LLM 服务
+##### 6.1.1 基于多进程脚本 llm_api.py 启动 LLM 服务
在项目根目录下,执行 [server/llm_api.py](server/llm_api.py) 脚本启动 **LLM 模型**服务:
@@ -248,7 +332,7 @@ embedding_model_dict = {
$ python server/llm_api.py
```
-项目支持多卡加载,需在 llm_api.py 中修改 create_model_worker_app 函数中,修改如下三个参数:
+项目支持多卡加载,需在 llm_api.py 中的 create_model_worker_app 函数中,修改如下三个参数:
```python
gpus=None,
@@ -262,11 +346,11 @@ max_gpu_memory="20GiB"
`max_gpu_memory` 控制每个卡使用的显存容量。
-##### 5.1.2 基于命令行脚本 llm_api_stale.py 启动 LLM 服务
+##### 6.1.2 基于命令行脚本 llm_api_stale.py 启动 LLM 服务
⚠️ **注意:**
-**1.llm_api_stale.py脚本原生仅适用于linux,mac设备需要安装对应的linux命令,win平台请使用wls;**
+**1.llm_api_stale.py脚本原生仅适用于linux,mac设备需要安装对应的linux命令,win平台请使用wsl;**
**2.加载非默认模型需要用命令行参数--model-path-address指定模型,不会读取model_config.py配置;**
@@ -302,14 +386,14 @@ $ python server/llm_api_shutdown.py --serve all
亦可单独停止一个 FastChat 服务模块,可选 [`all`, `controller`, `model_worker`, `openai_api_server`]
-##### 5.1.3 PEFT 加载(包括lora,p-tuning,prefix tuning, prompt tuning,ia等)
+##### 6.1.3 PEFT 加载(包括lora,p-tuning,prefix tuning, prompt tuning,ia3等)
本项目基于 FastChat 加载 LLM 服务,故需以 FastChat 加载 PEFT 路径,即保证路径名称里必须有 peft 这个词,配置文件的名字为 adapter_config.json,peft 路径下包含 model.bin 格式的 PEFT 权重。
详细步骤参考[加载lora微调后模型失效](https://github.com/chatchat-space/Langchain-Chatchat/issues/1130#issuecomment-1685291822)

-#### 5.2 启动 API 服务
+#### 6.2 启动 API 服务
本地部署情况下,按照 [5.1 节](README.md#5.1-启动-LLM-服务)**启动 LLM 服务后**,再执行 [server/api.py](server/api.py) 脚本启动 **API** 服务;
@@ -327,7 +411,7 @@ $ python server/api.py

-#### 5.3 启动 Web UI 服务
+#### 6.3 启动 Web UI 服务
按照 [5.2 节](README.md#5.2-启动-API-服务)**启动 API 服务后**,执行 [webui.py](webui.py) 启动 **Web UI** 服务(默认使用端口 `8501`)
@@ -356,41 +440,6 @@ $ streamlit run webui.py --server.port 666
---
-### 6. 一键启动
-
-更新一键启动脚本 startup.py,一键启动所有 Fastchat 服务、API 服务、WebUI 服务,示例代码:
-
-```shell
-$ python startup.py -a
-```
-
-并可使用 `Ctrl + C` 直接关闭所有运行服务。如果一次结束不了,可以多按几次。
-
-可选参数包括 `-a (或--all-webui)`, `--all-api`, `--llm-api`, `-c (或--controller)`, `--openai-api`,
-`-m (或--model-worker)`, `--api`, `--webui`,其中:
-
-- `--all-webui` 为一键启动 WebUI 所有依赖服务;
-- `--all-api` 为一键启动 API 所有依赖服务;
-- `--llm-api` 为一键启动 Fastchat 所有依赖的 LLM 服务;
-- `--openai-api` 为仅启动 FastChat 的 controller 和 openai-api-server 服务;
-- 其他为单独服务启动选项。
-
-若想指定非默认模型,需要用 `--model-name` 选项,示例:
-
-```shell
-$ python startup.py --all-webui --model-name Qwen-7B-Chat
-```
-
-更多信息可通过 `python startup.py -h`查看。
-
-**注意:**
-
-**1. startup 脚本用多进程方式启动各模块的服务,可能会导致打印顺序问题,请等待全部服务发起后再调用,并根据默认或指定端口调用服务(默认 LLM API 服务端口:`127.0.0.1:8888`,默认 API 服务端口:`127.0.0.1:7861`,默认 WebUI 服务端口:`本机IP:8501`)**
-
-**2.服务启动时间示设备不同而不同,约 3-10 分钟,如长时间没有启动请前往 `./logs`目录下监控日志,定位问题。**
-
-**3. 在Linux上使用ctrl+C退出可能会由于linux的多进程机制导致multiprocessing遗留孤儿进程,可通过shutdown_all.sh进行退出**
-
## 常见问题
参见 [常见问题](docs/FAQ.md)。
@@ -433,6 +482,6 @@ $ python startup.py --all-webui --model-name Qwen-7B-Chat
## 项目交流群
-
+
🎉 langchain-ChatGLM 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index 5b2574e9..1acaf1c1 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -1,6 +1,5 @@
import os
import logging
-import torch
# 日志格式
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
logger = logging.getLogger()
@@ -32,9 +31,8 @@ embedding_model_dict = {
# 选用的 Embedding 名称
EMBEDDING_MODEL = "m3e-base"
-# Embedding 模型运行设备
-EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
-
+# Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
+EMBEDDING_DEVICE = "auto"
llm_model_dict = {
"chatglm-6b": {
@@ -69,20 +67,32 @@ llm_model_dict = {
# 如果出现WARNING: Retrying langchain.chat_models.openai.acompletion_with_retry.._completion_with_retry in
# 4.0 seconds as it raised APIConnectionError: Error communicating with OpenAI.
# 需要添加代理访问(正常开的代理软件可能会拦截不上)需要设置配置openai_proxy 或者 使用环境遍历OPENAI_PROXY 进行设置
+ # 比如: "openai_proxy": 'http://127.0.0.1:4780'
"gpt-3.5-turbo": {
- "local_model_path": "gpt-3.5-turbo",
"api_base_url": "https://api.openai.com/v1",
"api_key": os.environ.get("OPENAI_API_KEY"),
"openai_proxy": os.environ.get("OPENAI_PROXY")
},
+ # 线上模型。当前支持智谱AI。
+ # 如果没有设置有效的local_model_path,则认为是在线模型API。
+ # 请在server_config中为每个在线API设置不同的端口
+ # 具体注册及api key获取请前往 http://open.bigmodel.cn
+ "chatglm-api": {
+ "api_base_url": "http://127.0.0.1:8888/v1",
+ "api_key": os.environ.get("ZHIPUAI_API_KEY"),
+ "provider": "ChatGLMWorker",
+ "version": "chatglm_pro", # 可选包括 "chatglm_lite", "chatglm_std", "chatglm_pro"
+ },
}
-
# LLM 名称
LLM_MODEL = "chatglm2-6b"
-# LLM 运行设备
-LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+# 历史对话轮数
+HISTORY_LEN = 3
+
+# LLM 运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
+LLM_DEVICE = "auto"
# 日志存储路径
LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
@@ -164,4 +174,4 @@ BING_SUBSCRIPTION_KEY = ""
# 是否开启中文标题加强,以及标题增强的相关配置
# 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记;
# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
-ZH_TITLE_ENHANCE = False
\ No newline at end of file
+ZH_TITLE_ENHANCE = False
diff --git a/configs/server_config.py.example b/configs/server_config.py.example
index b0f37bf4..ad731e37 100644
--- a/configs/server_config.py.example
+++ b/configs/server_config.py.example
@@ -1,4 +1,8 @@
-from .model_config import LLM_MODEL, LLM_DEVICE
+from .model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE
+import httpx
+
+# httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。
+HTTPX_DEFAULT_TIMEOUT = 300.0
# API 是否开启跨域,默认为False,如果需要开启,请设置为True
# is open cross domain
@@ -29,15 +33,18 @@ FSCHAT_OPENAI_API = {
# 这些模型必须是在model_config.llm_model_dict中正确配置的。
# 在启动startup.py时,可用通过`--model-worker --model-name xxxx`指定模型,不指定则为LLM_MODEL
FSCHAT_MODEL_WORKERS = {
- LLM_MODEL: {
+ # 所有模型共用的默认配置,可在模型专项配置或llm_model_dict中进行覆盖。
+ "default": {
"host": DEFAULT_BIND_HOST,
"port": 20002,
"device": LLM_DEVICE,
- # todo: 多卡加载需要配置的参数
- "gpus": None, # 使用的GPU,以str的格式指定,如"0,1"
- "num_gpus": 1, # 使用GPU的数量
- # 以下为非常用参数,可根据需要配置
+
+ # 多卡加载需要配置的参数
+ # "gpus": None, # 使用的GPU,以str的格式指定,如"0,1"
+ # "num_gpus": 1, # 使用GPU的数量
# "max_gpu_memory": "20GiB", # 每个GPU占用的最大显存
+
+ # 以下为非常用参数,可根据需要配置
# "load_8bit": False, # 开启8bit量化
# "cpu_offloading": None,
# "gptq_ckpt": None,
@@ -53,11 +60,17 @@ FSCHAT_MODEL_WORKERS = {
# "stream_interval": 2,
# "no_register": False,
},
+ "baichuan-7b": { # 使用default中的IP和端口
+ "device": "cpu",
+ },
+ "chatglm-api": { # 请为每个在线API设置不同的端口
+ "port": 20003,
+ },
}
# fastchat multi model worker server
FSCHAT_MULTI_MODEL_WORKERS = {
- # todo
+ # TODO:
}
# fastchat controller server
@@ -66,35 +79,3 @@ FSCHAT_CONTROLLER = {
"port": 20001,
"dispatch_method": "shortest_queue",
}
-
-
-# 以下不要更改
-def fschat_controller_address() -> str:
- host = FSCHAT_CONTROLLER["host"]
- port = FSCHAT_CONTROLLER["port"]
- return f"http://{host}:{port}"
-
-
-def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
- if model := FSCHAT_MODEL_WORKERS.get(model_name):
- host = model["host"]
- port = model["port"]
- return f"http://{host}:{port}"
-
-
-def fschat_openai_api_address() -> str:
- host = FSCHAT_OPENAI_API["host"]
- port = FSCHAT_OPENAI_API["port"]
- return f"http://{host}:{port}"
-
-
-def api_address() -> str:
- host = API_SERVER["host"]
- port = API_SERVER["port"]
- return f"http://{host}:{port}"
-
-
-def webui_address() -> str:
- host = WEBUI_SERVER["host"]
- port = WEBUI_SERVER["port"]
- return f"http://{host}:{port}"
diff --git a/document_loaders/__init__.py b/document_loaders/__init__.py
new file mode 100644
index 00000000..a4d6b28d
--- /dev/null
+++ b/document_loaders/__init__.py
@@ -0,0 +1,2 @@
+from .mypdfloader import RapidOCRPDFLoader
+from .myimgloader import RapidOCRLoader
\ No newline at end of file
diff --git a/document_loaders/myimgloader.py b/document_loaders/myimgloader.py
new file mode 100644
index 00000000..86481924
--- /dev/null
+++ b/document_loaders/myimgloader.py
@@ -0,0 +1,25 @@
+from typing import List
+from langchain.document_loaders.unstructured import UnstructuredFileLoader
+
+
+class RapidOCRLoader(UnstructuredFileLoader):
+ def _get_elements(self) -> List:
+ def img2text(filepath):
+ from rapidocr_onnxruntime import RapidOCR
+ resp = ""
+ ocr = RapidOCR()
+ result, _ = ocr(filepath)
+ if result:
+ ocr_result = [line[1] for line in result]
+ resp += "\n".join(ocr_result)
+ return resp
+
+ text = img2text(self.file_path)
+ from unstructured.partition.text import partition_text
+ return partition_text(text=text, **self.unstructured_kwargs)
+
+
+if __name__ == "__main__":
+ loader = RapidOCRLoader(file_path="../tests/samples/ocr_test.jpg")
+ docs = loader.load()
+ print(docs)
diff --git a/document_loaders/mypdfloader.py b/document_loaders/mypdfloader.py
new file mode 100644
index 00000000..71e063d6
--- /dev/null
+++ b/document_loaders/mypdfloader.py
@@ -0,0 +1,37 @@
+from typing import List
+from langchain.document_loaders.unstructured import UnstructuredFileLoader
+
+
+class RapidOCRPDFLoader(UnstructuredFileLoader):
+ def _get_elements(self) -> List:
+ def pdf2text(filepath):
+ import fitz
+ from rapidocr_onnxruntime import RapidOCR
+ import numpy as np
+ ocr = RapidOCR()
+ doc = fitz.open(filepath)
+ resp = ""
+ for page in doc:
+ # TODO: 依据文本与图片顺序调整处理方式
+ text = page.get_text("")
+ resp += text + "\n"
+
+ img_list = page.get_images()
+ for img in img_list:
+ pix = fitz.Pixmap(doc, img[0])
+ img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1)
+ result, _ = ocr(img_array)
+ if result:
+ ocr_result = [line[1] for line in result]
+ resp += "\n".join(ocr_result)
+ return resp
+
+ text = pdf2text(self.file_path)
+ from unstructured.partition.text import partition_text
+ return partition_text(text=text, **self.unstructured_kwargs)
+
+
+if __name__ == "__main__":
+ loader = RapidOCRPDFLoader(file_path="../tests/samples/ocr_test.pdf")
+ docs = loader.load()
+ print(docs)
diff --git a/img/qr_code_50.jpg b/img/qr_code_50.jpg
deleted file mode 100644
index c0ae20f6..00000000
Binary files a/img/qr_code_50.jpg and /dev/null differ
diff --git a/img/qr_code_51.jpg b/img/qr_code_51.jpg
deleted file mode 100644
index f0993322..00000000
Binary files a/img/qr_code_51.jpg and /dev/null differ
diff --git a/img/qr_code_52.jpg b/img/qr_code_52.jpg
deleted file mode 100644
index 18793d56..00000000
Binary files a/img/qr_code_52.jpg and /dev/null differ
diff --git a/img/qr_code_53.jpg b/img/qr_code_53.jpg
deleted file mode 100644
index 3174ccc1..00000000
Binary files a/img/qr_code_53.jpg and /dev/null differ
diff --git a/img/qr_code_54.jpg b/img/qr_code_54.jpg
deleted file mode 100644
index 1245a164..00000000
Binary files a/img/qr_code_54.jpg and /dev/null differ
diff --git a/img/qr_code_55.jpg b/img/qr_code_55.jpg
deleted file mode 100644
index 8ff046c9..00000000
Binary files a/img/qr_code_55.jpg and /dev/null differ
diff --git a/img/qr_code_56.jpg b/img/qr_code_56.jpg
deleted file mode 100644
index f17458d2..00000000
Binary files a/img/qr_code_56.jpg and /dev/null differ
diff --git a/img/qr_code_58.jpg b/img/qr_code_58.jpg
new file mode 100644
index 00000000..90c861d7
Binary files /dev/null and b/img/qr_code_58.jpg differ
diff --git a/init_database.py b/init_database.py
index 7fc84940..42a18c53 100644
--- a/init_database.py
+++ b/init_database.py
@@ -1,8 +1,9 @@
-from server.knowledge_base.migrate import create_tables, folder2db, recreate_all_vs, list_kbs_from_folder
+from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, recreate_all_vs, list_kbs_from_folder
from configs.model_config import NLTK_DATA_PATH
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
from startup import dump_server_info
+from datetime import datetime
if __name__ == "__main__":
@@ -25,13 +26,19 @@ if __name__ == "__main__":
dump_server_info()
- create_tables()
- print("database talbes created")
+ start_time = datetime.now()
if args.recreate_vs:
+ reset_tables()
+ print("database talbes reseted")
print("recreating all vector stores")
recreate_all_vs()
else:
+ create_tables()
+ print("database talbes created")
print("filling kb infos to database")
for kb in list_kbs_from_folder():
folder2db(kb, "fill_info_only")
+
+ end_time = datetime.now()
+ print(f"总计用时: {end_time-start_time}")
diff --git a/requirements.txt b/requirements.txt
index e40f665a..4271f3af 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,6 +15,8 @@ SQLAlchemy==2.0.19
faiss-cpu
accelerate
spacy
+PyMuPDF==1.22.5
+rapidocr_onnxruntime>=1.3.1
# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3
diff --git a/requirements_api.txt b/requirements_api.txt
index 58dbc0cd..bdecf3c7 100644
--- a/requirements_api.txt
+++ b/requirements_api.txt
@@ -16,6 +16,8 @@ faiss-cpu
nltk
accelerate
spacy
+PyMuPDF==1.22.5
+rapidocr_onnxruntime>=1.3.1
# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3
diff --git a/server/api.py b/server/api.py
index ecadd7cc..37954b7f 100644
--- a/server/api.py
+++ b/server/api.py
@@ -4,20 +4,22 @@ import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import NLTK_DATA_PATH
-from configs.server_config import OPEN_CROSS_DOMAIN
+from configs.model_config import LLM_MODEL, NLTK_DATA_PATH
+from configs.server_config import OPEN_CROSS_DOMAIN, HTTPX_DEFAULT_TIMEOUT
from configs import VERSION
import argparse
import uvicorn
+from fastapi import Body
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat,
search_engine_chat)
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
-from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
+from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc,
update_doc, download_doc, recreate_vector_store,
search_docs, DocumentWithScore)
-from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
+from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address
+import httpx
from typing import List
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@@ -84,11 +86,11 @@ def create_app():
summary="删除知识库"
)(delete_kb)
- app.get("/knowledge_base/list_docs",
+ app.get("/knowledge_base/list_files",
tags=["Knowledge Base Management"],
response_model=ListResponse,
summary="获取知识库内的文件列表"
- )(list_docs)
+ )(list_files)
app.post("/knowledge_base/search_docs",
tags=["Knowledge Base Management"],
@@ -123,6 +125,75 @@ def create_app():
summary="根据content中文档重建向量库,流式输出处理进度。"
)(recreate_vector_store)
+ # LLM模型相关接口
+ @app.post("/llm_model/list_models",
+ tags=["LLM Model Management"],
+ summary="列出当前已加载的模型")
+ def list_models(
+ controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
+ ) -> BaseResponse:
+ '''
+ 从fastchat controller获取已加载模型列表
+ '''
+ try:
+ controller_address = controller_address or fschat_controller_address()
+ r = httpx.post(controller_address + "/list_models")
+ return BaseResponse(data=r.json()["models"])
+ except Exception as e:
+ return BaseResponse(
+ code=500,
+ data=[],
+ msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
+
+ @app.post("/llm_model/stop",
+ tags=["LLM Model Management"],
+ summary="停止指定的LLM模型(Model Worker)",
+ )
+ def stop_llm_model(
+ model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
+ controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
+ ) -> BaseResponse:
+ '''
+ 向fastchat controller请求停止某个LLM模型。
+ 注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。
+ '''
+ try:
+ controller_address = controller_address or fschat_controller_address()
+ r = httpx.post(
+ controller_address + "/release_worker",
+ json={"model_name": model_name},
+ )
+ return r.json()
+ except Exception as e:
+ return BaseResponse(
+ code=500,
+ msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}")
+
+ @app.post("/llm_model/change",
+ tags=["LLM Model Management"],
+ summary="切换指定的LLM模型(Model Worker)",
+ )
+ def change_llm_model(
+ model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]),
+ new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]),
+ controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
+ ):
+ '''
+ 向fastchat controller请求切换LLM模型。
+ '''
+ try:
+ controller_address = controller_address or fschat_controller_address()
+ r = httpx.post(
+ controller_address + "/release_worker",
+ json={"model_name": model_name, "new_model_name": new_model_name},
+ timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
+ )
+ return r.json()
+ except Exception as e:
+ return BaseResponse(
+ code=500,
+ msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
+
return app
diff --git a/server/chat/chat.py b/server/chat/chat.py
index 2e939f1d..ba23a5a1 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -20,11 +20,13 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
{"role": "assistant", "content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
+ model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
):
history = [History.from_data(h) for h in history]
async def chat_iterator(query: str,
history: List[History] = [],
+ model_name: str = LLM_MODEL,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
@@ -32,10 +34,10 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
streaming=True,
verbose=True,
callbacks=[callback],
- openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
- openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
- model_name=LLM_MODEL,
- openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
+ openai_api_key=llm_model_dict[model_name]["api_key"],
+ openai_api_base=llm_model_dict[model_name]["api_base_url"],
+ model_name=model_name,
+ openai_proxy=llm_model_dict[model_name].get("openai_proxy")
)
input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
@@ -61,5 +63,5 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
await task
- return StreamingResponse(chat_iterator(query, history),
+ return StreamingResponse(chat_iterator(query, history, model_name),
media_type="text/event-stream")
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 27745691..69ec25dd 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -31,6 +31,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
+ model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
request: Request = None,
):
@@ -44,16 +45,17 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
kb: KBService,
top_k: int,
history: Optional[List[History]],
+ model_name: str = LLM_MODEL,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
streaming=True,
verbose=True,
callbacks=[callback],
- openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
- openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
- model_name=LLM_MODEL,
- openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
+ openai_api_key=llm_model_dict[model_name]["api_key"],
+ openai_api_base=llm_model_dict[model_name]["api_base_url"],
+ model_name=model_name,
+ openai_proxy=llm_model_dict[model_name].get("openai_proxy")
)
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
context = "\n".join([doc.page_content for doc in docs])
@@ -97,5 +99,5 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
await task
- return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history),
+ return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history, model_name),
media_type="text/event-stream")
diff --git a/server/chat/openai_chat.py b/server/chat/openai_chat.py
index a7ad8074..a799c623 100644
--- a/server/chat/openai_chat.py
+++ b/server/chat/openai_chat.py
@@ -29,23 +29,22 @@ async def openai_chat(msg: OpenAiChatMsgIn):
print(f"{openai.api_base=}")
print(msg)
- async def get_response(msg):
+ def get_response(msg):
data = msg.dict()
- data["streaming"] = True
- data.pop("stream")
try:
response = openai.ChatCompletion.create(**data)
if msg.stream:
- for chunk in response.choices[0].message.content:
- print(chunk)
- yield chunk
+ for data in response:
+ if choices := data.choices:
+ if chunk := choices[0].get("delta", {}).get("content"):
+ print(chunk, end="", flush=True)
+ yield chunk
else:
- answer = ""
- for chunk in response.choices[0].message.content:
- answer += chunk
- print(answer)
- yield(answer)
+ if response.choices:
+ answer = response.choices[0].message.content
+ print(answer)
+ yield(answer)
except Exception as e:
print(type(e))
logger.error(e)
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
index 8a2633bd..8fe7dae6 100644
--- a/server/chat/search_engine_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -69,6 +69,7 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
+ model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
):
if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@@ -82,16 +83,17 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
search_engine_name: str,
top_k: int,
history: Optional[List[History]],
+ model_name: str = LLM_MODEL,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
streaming=True,
verbose=True,
callbacks=[callback],
- openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
- openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
- model_name=LLM_MODEL,
- openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
+ openai_api_key=llm_model_dict[model_name]["api_key"],
+ openai_api_base=llm_model_dict[model_name]["api_base_url"],
+ model_name=model_name,
+ openai_proxy=llm_model_dict[model_name].get("openai_proxy")
)
docs = lookup_search_engine(query, search_engine_name, top_k)
@@ -129,5 +131,5 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
ensure_ascii=False)
await task
- return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history),
+ return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history, model_name),
media_type="text/event-stream")
diff --git a/server/db/base.py b/server/db/base.py
index 3a8529b0..1d911c05 100644
--- a/server/db/base.py
+++ b/server/db/base.py
@@ -3,10 +3,14 @@ from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from configs.model_config import SQLALCHEMY_DATABASE_URI
+import json
-engine = create_engine(SQLALCHEMY_DATABASE_URI)
+
+engine = create_engine(
+ SQLALCHEMY_DATABASE_URI,
+ json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
+)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
-
diff --git a/server/db/models/knowledge_base_model.py b/server/db/models/knowledge_base_model.py
index 37abd4ee..478bc1f3 100644
--- a/server/db/models/knowledge_base_model.py
+++ b/server/db/models/knowledge_base_model.py
@@ -9,9 +9,9 @@ class KnowledgeBaseModel(Base):
"""
__tablename__ = 'knowledge_base'
id = Column(Integer, primary_key=True, autoincrement=True, comment='知识库ID')
- kb_name = Column(String, comment='知识库名称')
- vs_type = Column(String, comment='嵌入模型类型')
- embed_model = Column(String, comment='嵌入模型名称')
+ kb_name = Column(String(50), comment='知识库名称')
+ vs_type = Column(String(50), comment='向量库类型')
+ embed_model = Column(String(50), comment='嵌入模型名称')
file_count = Column(Integer, default=0, comment='文件数量')
create_time = Column(DateTime, default=func.now(), comment='创建时间')
diff --git a/server/db/models/knowledge_file_model.py b/server/db/models/knowledge_file_model.py
index 7fffdfbd..c5784d1a 100644
--- a/server/db/models/knowledge_file_model.py
+++ b/server/db/models/knowledge_file_model.py
@@ -1,4 +1,4 @@
-from sqlalchemy import Column, Integer, String, DateTime, func
+from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, func
from server.db.base import Base
@@ -9,13 +9,32 @@ class KnowledgeFileModel(Base):
"""
__tablename__ = 'knowledge_file'
id = Column(Integer, primary_key=True, autoincrement=True, comment='知识文件ID')
- file_name = Column(String, comment='文件名')
- file_ext = Column(String, comment='文件扩展名')
- kb_name = Column(String, comment='所属知识库名称')
- document_loader_name = Column(String, comment='文档加载器名称')
- text_splitter_name = Column(String, comment='文本分割器名称')
+ file_name = Column(String(255), comment='文件名')
+ file_ext = Column(String(10), comment='文件扩展名')
+ kb_name = Column(String(50), comment='所属知识库名称')
+ document_loader_name = Column(String(50), comment='文档加载器名称')
+ text_splitter_name = Column(String(50), comment='文本分割器名称')
file_version = Column(Integer, default=1, comment='文件版本')
+ file_mtime = Column(Float, default=0.0, comment="文件修改时间")
+ file_size = Column(Integer, default=0, comment="文件大小")
+ custom_docs = Column(Boolean, default=False, comment="是否自定义docs")
+ docs_count = Column(Integer, default=0, comment="切分文档数量")
create_time = Column(DateTime, default=func.now(), comment='创建时间')
def __repr__(self):
return f""
+
+
+class FileDocModel(Base):
+ """
+ 文件-向量库文档模型
+ """
+ __tablename__ = 'file_doc'
+ id = Column(Integer, primary_key=True, autoincrement=True, comment='ID')
+ kb_name = Column(String(50), comment='知识库名称')
+ file_name = Column(String(255), comment='文件名称')
+ doc_id = Column(String(50), comment="向量库文档ID")
+ meta_data = Column(JSON, default={})
+
+ def __repr__(self):
+ return f""
diff --git a/server/db/repository/knowledge_file_repository.py b/server/db/repository/knowledge_file_repository.py
index 404910ff..08417a4d 100644
--- a/server/db/repository/knowledge_file_repository.py
+++ b/server/db/repository/knowledge_file_repository.py
@@ -1,24 +1,101 @@
from server.db.models.knowledge_base_model import KnowledgeBaseModel
-from server.db.models.knowledge_file_model import KnowledgeFileModel
+from server.db.models.knowledge_file_model import KnowledgeFileModel, FileDocModel
from server.db.session import with_session
from server.knowledge_base.utils import KnowledgeFile
+from typing import List, Dict
@with_session
-def list_docs_from_db(session, kb_name):
+def list_docs_from_db(session,
+ kb_name: str,
+ file_name: str = None,
+ metadata: Dict = {},
+ ) -> List[Dict]:
+ '''
+ 列出某知识库某文件对应的所有Document。
+ 返回形式:[{"id": str, "metadata": dict}, ...]
+ '''
+ docs = session.query(FileDocModel).filter_by(kb_name=kb_name)
+ if file_name:
+ docs = docs.filter_by(file_name=file_name)
+ for k, v in metadata.items():
+ docs = docs.filter(FileDocModel.meta_data[k].as_string()==str(v))
+
+ return [{"id": x.doc_id, "metadata": x.metadata} for x in docs.all()]
+
+
+@with_session
+def delete_docs_from_db(session,
+ kb_name: str,
+ file_name: str = None,
+ ) -> List[Dict]:
+ '''
+ 删除某知识库某文件对应的所有Document,并返回被删除的Document。
+ 返回形式:[{"id": str, "metadata": dict}, ...]
+ '''
+ docs = list_docs_from_db(kb_name=kb_name, file_name=file_name)
+ query = session.query(FileDocModel).filter_by(kb_name=kb_name)
+ if file_name:
+ query = query.filter_by(file_name=file_name)
+ query.delete()
+ session.commit()
+ return docs
+
+
+@with_session
+def add_docs_to_db(session,
+ kb_name: str,
+ file_name: str,
+ doc_infos: List[Dict]):
+ '''
+ 将某知识库某文件对应的所有Document信息添加到数据库。
+ doc_infos形式:[{"id": str, "metadata": dict}, ...]
+ '''
+ for d in doc_infos:
+ obj = FileDocModel(
+ kb_name=kb_name,
+ file_name=file_name,
+ doc_id=d["id"],
+ meta_data=d["metadata"],
+ )
+ session.add(obj)
+ return True
+
+
+@with_session
+def count_files_from_db(session, kb_name: str) -> int:
+ return session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).count()
+
+
+@with_session
+def list_files_from_db(session, kb_name):
files = session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).all()
docs = [f.file_name for f in files]
return docs
@with_session
-def add_doc_to_db(session, kb_file: KnowledgeFile):
+def add_file_to_db(session,
+ kb_file: KnowledgeFile,
+ docs_count: int = 0,
+ custom_docs: bool = False,
+ doc_infos: List[str] = [], # 形式:[{"id": str, "metadata": dict}, ...]
+ ):
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
if kb:
- # 如果已经存在该文件,则更新文件版本号
- existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename,
- kb_name=kb_file.kb_name).first()
+ # 如果已经存在该文件,则更新文件信息与版本号
+ existing_file: KnowledgeFileModel = (session.query(KnowledgeFileModel)
+ .filter_by(file_name=kb_file.filename,
+ kb_name=kb_file.kb_name)
+ .first())
+ mtime = kb_file.get_mtime()
+ size = kb_file.get_size()
+
if existing_file:
+ existing_file.file_mtime = mtime
+ existing_file.file_size = size
+ existing_file.docs_count = docs_count
+ existing_file.custom_docs = custom_docs
existing_file.file_version += 1
# 否则,添加新文件
else:
@@ -28,9 +105,14 @@ def add_doc_to_db(session, kb_file: KnowledgeFile):
kb_name=kb_file.kb_name,
document_loader_name=kb_file.document_loader_name,
text_splitter_name=kb_file.text_splitter_name or "SpacyTextSplitter",
+ file_mtime=mtime,
+ file_size=size,
+ docs_count = docs_count,
+ custom_docs=custom_docs,
)
kb.file_count += 1
session.add(new_file)
+ add_docs_to_db(kb_name=kb_file.kb_name, file_name=kb_file.filename, doc_infos=doc_infos)
return True
@@ -40,6 +122,7 @@ def delete_file_from_db(session, kb_file: KnowledgeFile):
kb_name=kb_file.kb_name).first()
if existing_file:
session.delete(existing_file)
+ delete_docs_from_db(kb_name=kb_file.kb_name, file_name=kb_file.filename)
session.commit()
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
@@ -52,7 +135,7 @@ def delete_file_from_db(session, kb_file: KnowledgeFile):
@with_session
def delete_files_from_db(session, knowledge_base_name: str):
session.query(KnowledgeFileModel).filter_by(kb_name=knowledge_base_name).delete()
-
+ session.query(FileDocModel).filter_by(kb_name=knowledge_base_name).delete()
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=knowledge_base_name).first()
if kb:
kb.file_count = 0
@@ -62,7 +145,7 @@ def delete_files_from_db(session, knowledge_base_name: str):
@with_session
-def doc_exists(session, kb_file: KnowledgeFile):
+def file_exists_in_db(session, kb_file: KnowledgeFile):
existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename,
kb_name=kb_file.kb_name).first()
return True if existing_file else False
@@ -82,6 +165,10 @@ def get_file_detail(session, kb_name: str, filename: str) -> dict:
"document_loader": file.document_loader_name,
"text_splitter": file.text_splitter_name,
"create_time": file.create_time,
+ "file_mtime": file.file_mtime,
+ "file_size": file.file_size,
+ "custom_docs": file.custom_docs,
+ "docs_count": file.docs_count,
}
else:
return {}
diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py
index ae027c12..7ea5d271 100644
--- a/server/knowledge_base/kb_doc_api.py
+++ b/server/knowledge_base/kb_doc_api.py
@@ -3,7 +3,7 @@ import urllib
from fastapi import File, Form, Body, Query, UploadFile
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
from server.utils import BaseResponse, ListResponse
-from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile
+from server.knowledge_base.utils import validate_kb_name, list_files_from_folder, KnowledgeFile
from fastapi.responses import StreamingResponse, FileResponse
import json
from server.knowledge_base.kb_service.base import KBServiceFactory
@@ -29,7 +29,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
return data
-async def list_docs(
+async def list_files(
knowledge_base_name: str
) -> ListResponse:
if not validate_kb_name(knowledge_base_name):
@@ -40,7 +40,7 @@ async def list_docs(
if kb is None:
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
else:
- all_doc_names = kb.list_docs()
+ all_doc_names = kb.list_files()
return ListResponse(data=all_doc_names)
@@ -183,14 +183,14 @@ async def recreate_vector_store(
set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents.
'''
- async def output():
+ def output():
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
if not kb.exists() and not allow_empty_kb:
yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
else:
kb.create_kb()
kb.clear_vs()
- docs = list_docs_from_folder(knowledge_base_name)
+ docs = list_files_from_folder(knowledge_base_name)
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, knowledge_base_name)
diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py
index 8d1de489..ca0919e0 100644
--- a/server/knowledge_base/kb_service/base.py
+++ b/server/knowledge_base/kb_service/base.py
@@ -1,25 +1,32 @@
+import operator
from abc import ABC, abstractmethod
import os
+import numpy as np
from langchain.embeddings.base import Embeddings
from langchain.docstore.document import Document
+from sklearn.preprocessing import normalize
+
from server.db.repository.knowledge_base_repository import (
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
load_kb_from_db, get_kb_detail,
)
from server.db.repository.knowledge_file_repository import (
- add_doc_to_db, delete_file_from_db, delete_files_from_db, doc_exists,
- list_docs_from_db, get_file_detail, delete_file_from_db
+ add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db,
+ count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
+ list_docs_from_db,
)
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
- EMBEDDING_DEVICE, EMBEDDING_MODEL)
+ EMBEDDING_MODEL)
from server.knowledge_base.utils import (
get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
- list_kbs_from_folder, list_docs_from_folder,
+ list_kbs_from_folder, list_files_from_folder,
)
+from server.utils import embedding_device
from typing import List, Union, Dict
+from typing import List, Union, Dict, Optional
class SupportedVSType:
@@ -41,7 +48,7 @@ class KBService(ABC):
self.doc_path = get_doc_path(self.kb_name)
self.do_init()
- def _load_embeddings(self, embed_device: str = EMBEDDING_DEVICE) -> Embeddings:
+ def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
return load_embeddings(self.embed_model, embed_device)
def create_kb(self):
@@ -62,7 +69,6 @@ class KBService(ABC):
status = delete_files_from_db(self.kb_name)
return status
-
def drop_kb(self):
"""
删除知识库
@@ -71,16 +77,24 @@ class KBService(ABC):
status = delete_kb_from_db(self.kb_name)
return status
- def add_doc(self, kb_file: KnowledgeFile, **kwargs):
+ def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
"""
向知识库添加文件
+ 如果指定了docs,则不再将文本向量化,并将数据库对应条目标为custom_docs=True
"""
- docs = kb_file.file2text()
+ if docs:
+ custom_docs = True
+ else:
+ docs = kb_file.file2text()
+ custom_docs = False
+
if docs:
self.delete_doc(kb_file)
- embeddings = self._load_embeddings()
- self.do_add_doc(docs, embeddings, **kwargs)
- status = add_doc_to_db(kb_file)
+ doc_infos = self.do_add_doc(docs, **kwargs)
+ status = add_file_to_db(kb_file,
+ custom_docs=custom_docs,
+ docs_count=len(docs),
+ doc_infos=doc_infos)
else:
status = False
return status
@@ -95,20 +109,24 @@ class KBService(ABC):
os.remove(kb_file.filepath)
return status
- def update_doc(self, kb_file: KnowledgeFile, **kwargs):
+ def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
"""
使用content中的文件更新向量库
+ 如果指定了docs,则使用自定义docs,并将数据库对应条目标为custom_docs=True
"""
if os.path.exists(kb_file.filepath):
self.delete_doc(kb_file, **kwargs)
- return self.add_doc(kb_file, **kwargs)
-
+ return self.add_doc(kb_file, docs=docs, **kwargs)
+
def exist_doc(self, file_name: str):
- return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
+ return file_exists_in_db(KnowledgeFile(knowledge_base_name=self.kb_name,
filename=file_name))
- def list_docs(self):
- return list_docs_from_db(self.kb_name)
+ def list_files(self):
+ return list_files_from_db(self.kb_name)
+
+ def count_files(self):
+ return count_files_from_db(self.kb_name)
def search_docs(self,
query: str,
@@ -119,6 +137,18 @@ class KBService(ABC):
docs = self.do_search(query, top_k, score_threshold, embeddings)
return docs
+ # TODO: milvus/pg需要实现该方法
+ def get_doc_by_id(self, id: str) -> Optional[Document]:
+ return None
+
+ def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[Document]:
+ '''
+ 通过file_name或metadata检索Document
+ '''
+ doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
+ docs = [self.get_doc_by_id(x["id"]) for x in doc_infos]
+ return docs
+
@abstractmethod
def do_create_kb(self):
"""
@@ -168,8 +198,7 @@ class KBService(ABC):
@abstractmethod
def do_add_doc(self,
docs: List[Document],
- embeddings: Embeddings,
- ):
+ ) -> List[Dict]:
"""
向知识库添加文档子类实自己逻辑
"""
@@ -208,8 +237,9 @@ class KBServiceFactory:
return PGKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.MILVUS == vector_store_type:
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
- return MilvusKBService(kb_name, embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
- elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
+ return MilvusKBService(kb_name,
+ embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
+ elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
return DefaultKBService(kb_name)
@@ -217,7 +247,7 @@ class KBServiceFactory:
def get_service_by_name(kb_name: str
) -> KBService:
_, vs_type, embed_model = load_kb_from_db(kb_name)
- if vs_type is None and os.path.isdir(get_kb_path(kb_name)): # faiss knowledge base not in db
+ if vs_type is None and os.path.isdir(get_kb_path(kb_name)): # faiss knowledge base not in db
vs_type = "faiss"
return KBServiceFactory.get_service(kb_name, vs_type, embed_model)
@@ -256,29 +286,30 @@ def get_kb_details() -> List[Dict]:
for i, v in enumerate(result.values()):
v['No'] = i + 1
data.append(v)
-
+
return data
-def get_kb_doc_details(kb_name: str) -> List[Dict]:
+def get_kb_file_details(kb_name: str) -> List[Dict]:
kb = KBServiceFactory.get_service_by_name(kb_name)
- docs_in_folder = list_docs_from_folder(kb_name)
- docs_in_db = kb.list_docs()
+ files_in_folder = list_files_from_folder(kb_name)
+ files_in_db = kb.list_files()
result = {}
- for doc in docs_in_folder:
+ for doc in files_in_folder:
result[doc] = {
"kb_name": kb_name,
"file_name": doc,
"file_ext": os.path.splitext(doc)[-1],
"file_version": 0,
"document_loader": "",
+ "docs_count": 0,
"text_splitter": "",
"create_time": None,
"in_folder": True,
"in_db": False,
}
- for doc in docs_in_db:
+ for doc in files_in_db:
doc_detail = get_file_detail(kb_name, doc)
if doc_detail:
doc_detail["in_db"] = True
@@ -292,5 +323,39 @@ def get_kb_doc_details(kb_name: str) -> List[Dict]:
for i, v in enumerate(result.values()):
v['No'] = i + 1
data.append(v)
-
+
return data
+
+
+class EmbeddingsFunAdapter(Embeddings):
+
+ def __init__(self, embeddings: Embeddings):
+ self.embeddings = embeddings
+
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
+ return normalize(self.embeddings.embed_documents(texts))
+
+ def embed_query(self, text: str) -> List[float]:
+ query_embed = self.embeddings.embed_query(text)
+ query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组
+ normalized_query_embed = normalize(query_embed_2d)
+ return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
+
+ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+ return await normalize(self.embeddings.aembed_documents(texts))
+
+ async def aembed_query(self, text: str) -> List[float]:
+ return await normalize(self.embeddings.aembed_query(text))
+
+
+def score_threshold_process(score_threshold, k, docs):
+ if score_threshold is not None:
+ cmp = (
+ operator.le
+ )
+ docs = [
+ (doc, similarity)
+ for doc, similarity in docs
+ if cmp(similarity, score_threshold)
+ ]
+ return docs[:k]
diff --git a/server/knowledge_base/kb_service/default_kb_service.py b/server/knowledge_base/kb_service/default_kb_service.py
index 922e39be..a68d59c5 100644
--- a/server/knowledge_base/kb_service/default_kb_service.py
+++ b/server/knowledge_base/kb_service/default_kb_service.py
@@ -13,7 +13,7 @@ class DefaultKBService(KBService):
def do_drop_kb(self):
pass
- def do_add_doc(self, docs: List[Document], embeddings: Embeddings):
+ def do_add_doc(self, docs: List[Document]):
pass
def do_clear_vs(self):
diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py
index b3f5439b..15cc790b 100644
--- a/server/knowledge_base/kb_service/faiss_kb_service.py
+++ b/server/knowledge_base/kb_service/faiss_kb_service.py
@@ -5,7 +5,6 @@ from configs.model_config import (
KB_ROOT_PATH,
CACHED_VS_NUM,
EMBEDDING_MODEL,
- EMBEDDING_DEVICE,
SCORE_THRESHOLD
)
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
@@ -13,40 +12,22 @@ from functools import lru_cache
from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
from langchain.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings,HuggingFaceBgeEmbeddings
-from langchain.embeddings.openai import OpenAIEmbeddings
-from typing import List
+from typing import List, Dict, Optional
from langchain.docstore.document import Document
-from server.utils import torch_gc
-
-
-# make HuggingFaceEmbeddings hashable
-def _embeddings_hash(self):
- if isinstance(self, HuggingFaceEmbeddings):
- return hash(self.model_name)
- elif isinstance(self, HuggingFaceBgeEmbeddings):
- return hash(self.model_name)
- elif isinstance(self, OpenAIEmbeddings):
- return hash(self.model)
-
-HuggingFaceEmbeddings.__hash__ = _embeddings_hash
-OpenAIEmbeddings.__hash__ = _embeddings_hash
-HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash
-
-_VECTOR_STORE_TICKS = {}
+from server.utils import torch_gc, embedding_device
_VECTOR_STORE_TICKS = {}
@lru_cache(CACHED_VS_NUM)
-def load_vector_store(
+def load_faiss_vector_store(
knowledge_base_name: str,
embed_model: str = EMBEDDING_MODEL,
- embed_device: str = EMBEDDING_DEVICE,
+ embed_device: str = embedding_device(),
embeddings: Embeddings = None,
tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed.
-):
+) -> FAISS:
print(f"loading vector store in '{knowledge_base_name}'.")
vs_path = get_vs_path(knowledge_base_name)
if embeddings is None:
@@ -86,22 +67,39 @@ class FaissKBService(KBService):
def vs_type(self) -> str:
return SupportedVSType.FAISS
- @staticmethod
- def get_vs_path(knowledge_base_name: str):
- return os.path.join(FaissKBService.get_kb_path(knowledge_base_name), "vector_store")
+ def get_vs_path(self):
+ return os.path.join(self.get_kb_path(), "vector_store")
- @staticmethod
- def get_kb_path(knowledge_base_name: str):
- return os.path.join(KB_ROOT_PATH, knowledge_base_name)
+ def get_kb_path(self):
+ return os.path.join(KB_ROOT_PATH, self.kb_name)
+
+ def load_vector_store(self) -> FAISS:
+ return load_faiss_vector_store(
+ knowledge_base_name=self.kb_name,
+ embed_model=self.embed_model,
+ tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0),
+ )
+
+ def save_vector_store(self, vector_store: FAISS = None):
+ vector_store = vector_store or self.load_vector_store()
+ vector_store.save_local(self.vs_path)
+ return vector_store
+
+ def refresh_vs_cache(self):
+ refresh_vs_cache(self.kb_name)
+
+ def get_doc_by_id(self, id: str) -> Optional[Document]:
+ vector_store = self.load_vector_store()
+ return vector_store.docstore._dict.get(id)
def do_init(self):
- self.kb_path = FaissKBService.get_kb_path(self.kb_name)
- self.vs_path = FaissKBService.get_vs_path(self.kb_name)
+ self.kb_path = self.get_kb_path()
+ self.vs_path = self.get_vs_path()
def do_create_kb(self):
if not os.path.exists(self.vs_path):
os.makedirs(self.vs_path)
- load_vector_store(self.kb_name)
+ self.load_vector_store()
def do_drop_kb(self):
self.clear_vs()
@@ -113,33 +111,27 @@ class FaissKBService(KBService):
score_threshold: float = SCORE_THRESHOLD,
embeddings: Embeddings = None,
) -> List[Document]:
- search_index = load_vector_store(self.kb_name,
- embeddings=embeddings,
- tick=_VECTOR_STORE_TICKS.get(self.kb_name))
+ search_index = self.load_vector_store()
docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
return docs
def do_add_doc(self,
docs: List[Document],
- embeddings: Embeddings,
**kwargs,
- ):
- vector_store = load_vector_store(self.kb_name,
- embeddings=embeddings,
- tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
- vector_store.add_documents(docs)
+ ) -> List[Dict]:
+ vector_store = self.load_vector_store()
+ ids = vector_store.add_documents(docs)
+ doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
torch_gc()
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
- refresh_vs_cache(self.kb_name)
+ self.refresh_vs_cache()
+ return doc_infos
def do_delete_doc(self,
kb_file: KnowledgeFile,
**kwargs):
- embeddings = self._load_embeddings()
- vector_store = load_vector_store(self.kb_name,
- embeddings=embeddings,
- tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
+ vector_store = self.load_vector_store()
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
if len(ids) == 0:
@@ -148,14 +140,14 @@ class FaissKBService(KBService):
vector_store.delete(ids)
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
- refresh_vs_cache(self.kb_name)
+ self.refresh_vs_cache()
- return True
+ return vector_store
def do_clear_vs(self):
shutil.rmtree(self.vs_path)
os.makedirs(self.vs_path)
- refresh_vs_cache(self.kb_name)
+ self.refresh_vs_cache()
def exist_doc(self, file_name: str):
if super().exist_doc(file_name):
@@ -166,3 +158,11 @@ class FaissKBService(KBService):
return "in_folder"
else:
return False
+
+
+if __name__ == '__main__':
+ faissService = FaissKBService("test")
+ faissService.add_doc(KnowledgeFile("README.md", "test"))
+ faissService.delete_doc(KnowledgeFile("README.md", "test"))
+ faissService.do_drop_kb()
+ print(faissService.search_docs("如何启动api服务"))
\ No newline at end of file
diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py
index f2c798cf..444765f6 100644
--- a/server/knowledge_base/kb_service/milvus_kb_service.py
+++ b/server/knowledge_base/kb_service/milvus_kb_service.py
@@ -1,12 +1,16 @@
-from typing import List
+from typing import List, Dict, Optional
+import numpy as np
+from faiss import normalize_L2
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import Milvus
+from sklearn.preprocessing import normalize
from configs.model_config import SCORE_THRESHOLD, kbs_config
-from server.knowledge_base.kb_service.base import KBService, SupportedVSType
+from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
+ score_threshold_process
from server.knowledge_base.utils import KnowledgeFile
@@ -18,6 +22,14 @@ class MilvusKBService(KBService):
from pymilvus import Collection
return Collection(milvus_name)
+ def get_doc_by_id(self, id: str) -> Optional[Document]:
+ if self.milvus.col:
+ data_list = self.milvus.col.query(expr=f'pk == {id}', output_fields=["*"])
+ if len(data_list) > 0:
+ data = data_list[0]
+ text = data.pop("text")
+ return Document(page_content=text, metadata=data)
+
@staticmethod
def search(milvus_name, content, limit=3):
search_params = {
@@ -36,38 +48,31 @@ class MilvusKBService(KBService):
def _load_milvus(self, embeddings: Embeddings = None):
if embeddings is None:
embeddings = self._load_embeddings()
- self.milvus = Milvus(embedding_function=embeddings,
+ self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(embeddings),
collection_name=self.kb_name, connection_args=kbs_config.get("milvus"))
def do_init(self):
self._load_milvus()
def do_drop_kb(self):
- self.milvus.col.drop()
+ if self.milvus.col:
+ self.milvus.col.drop()
- def do_search(self, query: str, top_k: int,score_threshold: float, embeddings: Embeddings):
- # todo: support score threshold
- self._load_milvus(embeddings=embeddings)
- return self.milvus.similarity_search_with_score(query, top_k)
+ def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
+ self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings))
+ return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
- def add_doc(self, kb_file: KnowledgeFile, **kwargs):
- """
- 向知识库添加文件
- """
- docs = kb_file.file2text()
- self.milvus.add_documents(docs)
- from server.db.repository.knowledge_file_repository import add_doc_to_db
- status = add_doc_to_db(kb_file)
- return status
-
- def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs):
- pass
+ def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
+ ids = self.milvus.add_documents(docs)
+ doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
+ return doc_infos
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
- filepath = kb_file.filepath.replace('\\', '\\\\')
- delete_list = [item.get("pk") for item in
- self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
- self.milvus.col.delete(expr=f'pk in {delete_list}')
+ if self.milvus.col:
+ filepath = kb_file.filepath.replace('\\', '\\\\')
+ delete_list = [item.get("pk") for item in
+ self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
+ self.milvus.col.delete(expr=f'pk in {delete_list}')
def do_clear_vs(self):
if self.milvus.col:
@@ -80,7 +85,9 @@ if __name__ == '__main__':
Base.metadata.create_all(bind=engine)
milvusService = MilvusKBService("test")
- milvusService.add_doc(KnowledgeFile("README.md", "test"))
- milvusService.delete_doc(KnowledgeFile("README.md", "test"))
- milvusService.do_drop_kb()
- print(milvusService.search_docs("测试"))
+ # milvusService.add_doc(KnowledgeFile("README.md", "test"))
+
+ print(milvusService.get_doc_by_id("444022434274215486"))
+ # milvusService.delete_doc(KnowledgeFile("README.md", "test"))
+ # milvusService.do_drop_kb()
+ # print(milvusService.search_docs("如何启动api服务"))
diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py
index 6876bd86..e6381fac 100644
--- a/server/knowledge_base/kb_service/pg_kb_service.py
+++ b/server/knowledge_base/kb_service/pg_kb_service.py
@@ -1,25 +1,38 @@
-from typing import List
+import json
+from typing import List, Dict, Optional
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import PGVector
+from langchain.vectorstores.pgvector import DistanceStrategy
from sqlalchemy import text
from configs.model_config import EMBEDDING_DEVICE, kbs_config
-from server.knowledge_base.kb_service.base import SupportedVSType, KBService
+
+from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
+ score_threshold_process
from server.knowledge_base.utils import load_embeddings, KnowledgeFile
+from server.utils import embedding_device as get_embedding_device
class PGKBService(KBService):
pg_vector: PGVector
- def _load_pg_vector(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None):
+ def _load_pg_vector(self, embedding_device: str = get_embedding_device(), embeddings: Embeddings = None):
_embeddings = embeddings
if _embeddings is None:
_embeddings = load_embeddings(self.embed_model, embedding_device)
- self.pg_vector = PGVector(embedding_function=_embeddings,
+ self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(_embeddings),
collection_name=self.kb_name,
+ distance_strategy=DistanceStrategy.EUCLIDEAN,
connection_string=kbs_config.get("pg").get("connection_uri"))
+ def get_doc_by_id(self, id: str) -> Optional[Document]:
+ with self.pg_vector.connect() as connect:
+ stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id=:id")
+ results = [Document(page_content=row[0], metadata=row[1]) for row in
+ connect.execute(stmt, parameters={'id': id}).fetchall()]
+ if len(results) > 0:
+ return results[0]
def do_init(self):
self._load_pg_vector()
@@ -44,22 +57,14 @@ class PGKBService(KBService):
connect.commit()
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
- # todo: support score threshold
self._load_pg_vector(embeddings=embeddings)
- return self.pg_vector.similarity_search_with_score(query, top_k)
+ return score_threshold_process(score_threshold, top_k,
+ self.pg_vector.similarity_search_with_score(query, top_k))
- def add_doc(self, kb_file: KnowledgeFile, **kwargs):
- """
- 向知识库添加文件
- """
- docs = kb_file.file2text()
- self.pg_vector.add_documents(docs)
- from server.db.repository.knowledge_file_repository import add_doc_to_db
- status = add_doc_to_db(kb_file)
- return status
-
- def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs):
- pass
+ def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
+ ids = self.pg_vector.add_documents(docs)
+ doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
+ return doc_infos
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
with self.pg_vector.connect() as connect:
@@ -77,10 +82,11 @@ class PGKBService(KBService):
if __name__ == '__main__':
from server.db.base import Base, engine
- Base.metadata.create_all(bind=engine)
+ # Base.metadata.create_all(bind=engine)
pGKBService = PGKBService("test")
- pGKBService.create_kb()
- pGKBService.add_doc(KnowledgeFile("README.md", "test"))
- pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
- pGKBService.drop_kb()
- print(pGKBService.search_docs("测试"))
+ # pGKBService.create_kb()
+ # pGKBService.add_doc(KnowledgeFile("README.md", "test"))
+ # pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
+ # pGKBService.drop_kb()
+ print(pGKBService.get_doc_by_id("f1e51390-3029-4a19-90dc-7118aaa25772"))
+ # print(pGKBService.search_docs("如何启动api服务"))
diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py
index c96d3867..129dd53c 100644
--- a/server/knowledge_base/migrate.py
+++ b/server/knowledge_base/migrate.py
@@ -1,10 +1,17 @@
from configs.model_config import EMBEDDING_MODEL, DEFAULT_VS_TYPE
-from server.knowledge_base.utils import get_file_path, list_kbs_from_folder, list_docs_from_folder, KnowledgeFile
-from server.knowledge_base.kb_service.base import KBServiceFactory
-from server.db.repository.knowledge_file_repository import add_doc_to_db
+from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
+ list_files_from_folder, run_in_thread_pool,
+ files2docs_in_thread,
+ KnowledgeFile,)
+from server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType
+from server.db.repository.knowledge_file_repository import add_file_to_db
from server.db.base import Base, engine
import os
-from typing import Literal, Callable, Any
+from concurrent.futures import ThreadPoolExecutor
+from typing import Literal, Any, List
+
+
+pool = ThreadPoolExecutor(os.cpu_count())
def create_tables():
@@ -16,13 +23,22 @@ def reset_tables():
create_tables()
+def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]:
+ kb_files = []
+ for file in files:
+ try:
+ kb_file = KnowledgeFile(filename=file, knowledge_base_name=kb_name)
+ kb_files.append(kb_file)
+ except Exception as e:
+ print(f"{e},已跳过")
+ return kb_files
+
+
def folder2db(
kb_name: str,
mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"],
vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
- callback_before: Callable = None,
- callback_after: Callable = None,
):
'''
use existed files in local folder to populate database and/or vector store.
@@ -36,70 +52,62 @@ def folder2db(
kb.create_kb()
if mode == "recreate_vs":
+ files_count = kb.count_files()
+ print(f"知识库 {kb_name} 中共有 {files_count} 个文档。\n即将清除向量库。")
kb.clear_vs()
- docs = list_docs_from_folder(kb_name)
- for i, doc in enumerate(docs):
- try:
- kb_file = KnowledgeFile(doc, kb_name)
- if callable(callback_before):
- callback_before(kb_file, i, docs)
- if i == len(docs) - 1:
- not_refresh_vs_cache = False
- else:
- not_refresh_vs_cache = True
- kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
- if callable(callback_after):
- callback_after(kb_file, i, docs)
- except Exception as e:
- print(e)
+ files_count = kb.count_files()
+ print(f"清理后,知识库 {kb_name} 中共有 {files_count} 个文档。")
+
+ kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name))
+ for success, result in files2docs_in_thread(kb_files, pool=pool):
+ if success:
+ _, filename, docs = result
+ print(f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档")
+ kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
+ kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
+ else:
+ print(result)
+
+ if kb.vs_type() == SupportedVSType.FAISS:
+ kb.save_vector_store()
+ kb.refresh_vs_cache()
elif mode == "fill_info_only":
- docs = list_docs_from_folder(kb_name)
- for i, doc in enumerate(docs):
- try:
- kb_file = KnowledgeFile(doc, kb_name)
- if callable(callback_before):
- callback_before(kb_file, i, docs)
- add_doc_to_db(kb_file)
- if callable(callback_after):
- callback_after(kb_file, i, docs)
- except Exception as e:
- print(e)
+ files = list_files_from_folder(kb_name)
+ kb_files = file_to_kbfile(kb_name, files)
+
+ for kb_file in kb_file:
+ add_file_to_db(kb_file)
+ print(f"已将 {kb_name}/{kb_file.filename} 添加到数据库")
elif mode == "update_in_db":
- docs = kb.list_docs()
- for i, doc in enumerate(docs):
- try:
- kb_file = KnowledgeFile(doc, kb_name)
- if callable(callback_before):
- callback_before(kb_file, i, docs)
- if i == len(docs) - 1:
- not_refresh_vs_cache = False
- else:
- not_refresh_vs_cache = True
- kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
- if callable(callback_after):
- callback_after(kb_file, i, docs)
- except Exception as e:
- print(e)
+ files = kb.list_files()
+ kb_files = file_to_kbfile(kb_name, files)
+
+ for kb_file in kb_files:
+ kb.update_doc(kb_file, not_refresh_vs_cache=True)
+
+ if kb.vs_type() == SupportedVSType.FAISS:
+ kb.save_vector_store()
+ kb.refresh_vs_cache()
elif mode == "increament":
- db_docs = kb.list_docs()
- folder_docs = list_docs_from_folder(kb_name)
- docs = list(set(folder_docs) - set(db_docs))
- for i, doc in enumerate(docs):
- try:
- kb_file = KnowledgeFile(doc, kb_name)
- if callable(callback_before):
- callback_before(kb_file, i, docs)
- if i == len(docs) - 1:
- not_refresh_vs_cache = False
- else:
- not_refresh_vs_cache = True
- kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
- if callable(callback_after):
- callback_after(kb_file, i, docs)
- except Exception as e:
- print(e)
+ db_files = kb.list_files()
+ folder_files = list_files_from_folder(kb_name)
+ files = list(set(folder_files) - set(db_files))
+ kb_files = file_to_kbfile(kb_name, files)
+
+ for success, result in files2docs_in_thread(kb_files, pool=pool):
+ if success:
+ _, filename, docs = result
+ print(f"正在将 {kb_name}/{filename} 添加到向量库")
+ kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
+ kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
+ else:
+ print(result)
+
+ if kb.vs_type() == SupportedVSType.FAISS:
+ kb.save_vector_store()
+ kb.refresh_vs_cache()
else:
- raise ValueError(f"unspported migrate mode: {mode}")
+ print(f"unspported migrate mode: {mode}")
def recreate_all_vs(
@@ -114,30 +122,34 @@ def recreate_all_vs(
folder2db(kb_name, "recreate_vs", vs_type, embed_mode, **kwargs)
-def prune_db_docs(kb_name: str):
+def prune_db_files(kb_name: str):
'''
- delete docs in database that not existed in local folder.
- it is used to delete database docs after user deleted some doc files in file browser
+ delete files in database that not existed in local folder.
+ it is used to delete database files after user deleted some doc files in file browser
'''
kb = KBServiceFactory.get_service_by_name(kb_name)
if kb.exists():
- docs_in_db = kb.list_docs()
- docs_in_folder = list_docs_from_folder(kb_name)
- docs = list(set(docs_in_db) - set(docs_in_folder))
- for doc in docs:
- kb.delete_doc(KnowledgeFile(doc, kb_name))
- return docs
+ files_in_db = kb.list_files()
+ files_in_folder = list_files_from_folder(kb_name)
+ files = list(set(files_in_db) - set(files_in_folder))
+ kb_files = file_to_kbfile(kb_name, files)
+ for kb_file in kb_files:
+ kb.delete_doc(kb_file, not_refresh_vs_cache=True)
+ if kb.vs_type() == SupportedVSType.FAISS:
+ kb.save_vector_store()
+ kb.refresh_vs_cache()
+ return kb_files
-def prune_folder_docs(kb_name: str):
+def prune_folder_files(kb_name: str):
'''
delete doc files in local folder that not existed in database.
is is used to free local disk space by delete unused doc files.
'''
kb = KBServiceFactory.get_service_by_name(kb_name)
if kb.exists():
- docs_in_db = kb.list_docs()
- docs_in_folder = list_docs_from_folder(kb_name)
- docs = list(set(docs_in_folder) - set(docs_in_db))
- for doc in docs:
- os.remove(get_file_path(kb_name, doc))
- return docs
+ files_in_db = kb.list_files()
+ files_in_folder = list_files_from_folder(kb_name)
+ files = list(set(files_in_folder) - set(files_in_db))
+ for file in files:
+ os.remove(get_file_path(kb_name, file))
+ return files
diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index 34f20832..a8a9bcca 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -12,6 +12,26 @@ from configs.model_config import (
from functools import lru_cache
import importlib
from text_splitter import zh_title_enhance
+import langchain.document_loaders
+from langchain.docstore.document import Document
+from pathlib import Path
+import json
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
+
+
+# make HuggingFaceEmbeddings hashable
+def _embeddings_hash(self):
+ if isinstance(self, HuggingFaceEmbeddings):
+ return hash(self.model_name)
+ elif isinstance(self, HuggingFaceBgeEmbeddings):
+ return hash(self.model_name)
+ elif isinstance(self, OpenAIEmbeddings):
+ return hash(self.model)
+
+HuggingFaceEmbeddings.__hash__ = _embeddings_hash
+OpenAIEmbeddings.__hash__ = _embeddings_hash
+HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash
def validate_kb_name(knowledge_base_id: str) -> bool:
@@ -20,27 +40,34 @@ def validate_kb_name(knowledge_base_id: str) -> bool:
return False
return True
+
def get_kb_path(knowledge_base_name: str):
return os.path.join(KB_ROOT_PATH, knowledge_base_name)
+
def get_doc_path(knowledge_base_name: str):
return os.path.join(get_kb_path(knowledge_base_name), "content")
+
def get_vs_path(knowledge_base_name: str):
return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
+
def get_file_path(knowledge_base_name: str, doc_name: str):
return os.path.join(get_doc_path(knowledge_base_name), doc_name)
+
def list_kbs_from_folder():
return [f for f in os.listdir(KB_ROOT_PATH)
if os.path.isdir(os.path.join(KB_ROOT_PATH, f))]
-def list_docs_from_folder(kb_name: str):
+
+def list_files_from_folder(kb_name: str):
doc_path = get_doc_path(kb_name)
return [file for file in os.listdir(doc_path)
if os.path.isfile(os.path.join(doc_path, file))]
+
@lru_cache(1)
def load_embeddings(model: str, device: str):
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
@@ -56,16 +83,90 @@ def load_embeddings(model: str, device: str):
return embeddings
-
-LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst',
- '.rtf', '.txt', '.xml',
- '.doc', '.docx', '.epub', '.odt', '.pdf',
- '.ppt', '.pptx', '.tsv'], # '.pdf', '.xlsx', '.csv'
+LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
+ "UnstructuredMarkdownLoader": ['.md'],
+ "CustomJSONLoader": [".json"],
"CSVLoader": [".csv"],
- "PyPDFLoader": [".pdf"],
+ "RapidOCRPDFLoader": [".pdf"],
+ "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
+ "UnstructuredFileLoader": ['.eml', '.msg', '.rst',
+ '.rtf', '.txt', '.xml',
+ '.doc', '.docx', '.epub', '.odt',
+ '.ppt', '.pptx', '.tsv'], # '.xlsx'
}
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
+
+class CustomJSONLoader(langchain.document_loaders.JSONLoader):
+ '''
+ langchain的JSONLoader需要jq,在win上使用不便,进行替代。
+ '''
+
+ def __init__(
+ self,
+ file_path: Union[str, Path],
+ content_key: Optional[str] = None,
+ metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
+ text_content: bool = True,
+ json_lines: bool = False,
+ ):
+ """Initialize the JSONLoader.
+
+ Args:
+ file_path (Union[str, Path]): The path to the JSON or JSON Lines file.
+ content_key (str): The key to use to extract the content from the JSON if
+ results to a list of objects (dict).
+ metadata_func (Callable[Dict, Dict]): A function that takes in the JSON
+ object extracted by the jq_schema and the default metadata and returns
+ a dict of the updated metadata.
+ text_content (bool): Boolean flag to indicate whether the content is in
+ string format, default to True.
+ json_lines (bool): Boolean flag to indicate whether the input is in
+ JSON Lines format.
+ """
+ self.file_path = Path(file_path).resolve()
+ self._content_key = content_key
+ self._metadata_func = metadata_func
+ self._text_content = text_content
+ self._json_lines = json_lines
+
+ # TODO: langchain's JSONLoader.load has a encoding bug, raise gbk encoding error on windows.
+ # This is a workaround for langchain==0.0.266. I have make a pr(#9785) to langchain, it should be deleted after langchain upgraded.
+ def load(self) -> List[Document]:
+ """Load and return documents from the JSON file."""
+ docs: List[Document] = []
+ if self._json_lines:
+ with self.file_path.open(encoding="utf-8") as f:
+ for line in f:
+ line = line.strip()
+ if line:
+ self._parse(line, docs)
+ else:
+ self._parse(self.file_path.read_text(encoding="utf-8"), docs)
+ return docs
+
+ def _parse(self, content: str, docs: List[Document]) -> None:
+ """Convert given content to documents."""
+ data = json.loads(content)
+
+ # Perform some validation
+ # This is not a perfect validation, but it should catch most cases
+ # and prevent the user from getting a cryptic error later on.
+ if self._content_key is not None:
+ self._validate_content_key(data)
+
+ for i, sample in enumerate(data, len(docs) + 1):
+ metadata = dict(
+ source=str(self.file_path),
+ seq_num=i,
+ )
+ text = self._get_text(sample=sample, metadata=metadata)
+ docs.append(Document(page_content=text, metadata=metadata))
+
+
+langchain.document_loaders.CustomJSONLoader = CustomJSONLoader
+
+
def get_LoaderClass(file_extension):
for LoaderClass, extensions in LOADER_DICT.items():
if file_extension in extensions:
@@ -90,10 +191,16 @@ class KnowledgeFile:
# TODO: 增加依据文件格式匹配text_splitter
self.text_splitter_name = None
- def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE):
- print(self.document_loader_name)
+ def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE, refresh: bool = False):
+ if self.docs is not None and not refresh:
+ return self.docs
+
+ print(f"{self.document_loader_name} used for {self.filepath}")
try:
- document_loaders_module = importlib.import_module('langchain.document_loaders')
+ if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
+ document_loaders_module = importlib.import_module('document_loaders')
+ else:
+ document_loaders_module = importlib.import_module('langchain.document_loaders')
DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
except Exception as e:
print(e)
@@ -101,6 +208,16 @@ class KnowledgeFile:
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
if self.document_loader_name == "UnstructuredFileLoader":
loader = DocumentLoader(self.filepath, autodetect_encoding=True)
+ elif self.document_loader_name == "CSVLoader":
+ loader = DocumentLoader(self.filepath, encoding="utf-8")
+ elif self.document_loader_name == "JSONLoader":
+ loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False)
+ elif self.document_loader_name == "CustomJSONLoader":
+ loader = DocumentLoader(self.filepath, text_content=False)
+ elif self.document_loader_name == "UnstructuredMarkdownLoader":
+ loader = DocumentLoader(self.filepath, mode="elements")
+ elif self.document_loader_name == "UnstructuredHTMLLoader":
+ loader = DocumentLoader(self.filepath, mode="elements")
else:
loader = DocumentLoader(self.filepath)
@@ -136,4 +253,63 @@ class KnowledgeFile:
print(docs[0])
if using_zh_title_enhance:
docs = zh_title_enhance(docs)
+ self.docs = docs
return docs
+
+ def get_mtime(self):
+ return os.path.getmtime(self.filepath)
+
+ def get_size(self):
+ return os.path.getsize(self.filepath)
+
+
+def run_in_thread_pool(
+ func: Callable,
+ params: List[Dict] = [],
+ pool: ThreadPoolExecutor = None,
+) -> Generator:
+ '''
+ 在线程池中批量运行任务,并将运行结果以生成器的形式返回。
+ 请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。
+ '''
+ tasks = []
+ if pool is None:
+ pool = ThreadPoolExecutor()
+
+ for kwargs in params:
+ thread = pool.submit(func, **kwargs)
+ tasks.append(thread)
+
+ for obj in as_completed(tasks):
+ yield obj.result()
+
+
+def files2docs_in_thread(
+ files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
+ pool: ThreadPoolExecutor = None,
+) -> Generator:
+ '''
+ 利用多线程批量将文件转化成langchain Document.
+ 生成器返回值为{(kb_name, file_name): docs}
+ '''
+ def task(*, file: KnowledgeFile, **kwargs) -> Dict[Tuple[str, str], List[Document]]:
+ try:
+ return True, (file.kb_name, file.filename, file.file2text(**kwargs))
+ except Exception as e:
+ return False, e
+
+ kwargs_list = []
+ for i, file in enumerate(files):
+ kwargs = {}
+ if isinstance(file, tuple) and len(file) >= 2:
+ files[i] = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
+ elif isinstance(file, dict):
+ filename = file.pop("filename")
+ kb_name = file.pop("kb_name")
+ files[i] = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
+ kwargs = file
+ kwargs["file"] = file
+ kwargs_list.append(kwargs)
+
+ for result in run_in_thread_pool(func=task, params=kwargs_list, pool=pool):
+ yield result
diff --git a/server/llm_api.py b/server/llm_api.py
index ab71b3db..d9667e4f 100644
--- a/server/llm_api.py
+++ b/server/llm_api.py
@@ -4,8 +4,8 @@ import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
-from server.utils import MakeFastAPIOffline
+from configs.model_config import llm_model_dict, LLM_MODEL, LOG_PATH, logger
+from server.utils import MakeFastAPIOffline, set_httpx_timeout, llm_device
host_ip = "0.0.0.0"
@@ -15,13 +15,6 @@ openai_api_port = 8888
base_url = "http://127.0.0.1:{}"
-def set_httpx_timeout(timeout=60.0):
- import httpx
- httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
- httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
- httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
-
-
def create_controller_app(
dispatch_method="shortest_queue",
):
@@ -41,7 +34,7 @@ def create_model_worker_app(
worker_address=base_url.format(model_worker_port),
controller_address=base_url.format(controller_port),
model_path=llm_model_dict[LLM_MODEL].get("local_model_path"),
- device=LLM_DEVICE,
+ device=llm_device(),
gpus=None,
max_gpu_memory="20GiB",
load_8bit=False,
diff --git a/server/model_workers/__init__.py b/server/model_workers/__init__.py
new file mode 100644
index 00000000..932c2f3b
--- /dev/null
+++ b/server/model_workers/__init__.py
@@ -0,0 +1 @@
+from .zhipu import ChatGLMWorker
diff --git a/server/model_workers/base.py b/server/model_workers/base.py
new file mode 100644
index 00000000..b72f6839
--- /dev/null
+++ b/server/model_workers/base.py
@@ -0,0 +1,71 @@
+from configs.model_config import LOG_PATH
+import fastchat.constants
+fastchat.constants.LOGDIR = LOG_PATH
+from fastchat.serve.model_worker import BaseModelWorker
+import uuid
+import json
+import sys
+from pydantic import BaseModel
+import fastchat
+import threading
+from typing import Dict, List
+
+
+# 恢复被fastchat覆盖的标准输出
+sys.stdout = sys.__stdout__
+sys.stderr = sys.__stderr__
+
+
+class ApiModelOutMsg(BaseModel):
+ error_code: int = 0
+ text: str
+
+class ApiModelWorker(BaseModelWorker):
+ BASE_URL: str
+ SUPPORT_MODELS: List
+
+ def __init__(
+ self,
+ model_names: List[str],
+ controller_addr: str,
+ worker_addr: str,
+ context_len: int = 2048,
+ **kwargs,
+ ):
+ kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
+ kwargs.setdefault("model_path", "")
+ kwargs.setdefault("limit_worker_concurrency", 5)
+ super().__init__(model_names=model_names,
+ controller_addr=controller_addr,
+ worker_addr=worker_addr,
+ **kwargs)
+ self.context_len = context_len
+ self.init_heart_beat()
+
+ def count_token(self, params):
+ # TODO:需要完善
+ print("count token")
+ print(params)
+ prompt = params["prompt"]
+ return {"count": len(str(prompt)), "error_code": 0}
+
+ def generate_stream_gate(self, params):
+ self.call_ct += 1
+
+ def generate_gate(self, params):
+ for x in self.generate_stream_gate(params):
+ pass
+ return json.loads(x[:-1].decode())
+
+ def get_embeddings(self, params):
+ print("embedding")
+ print(params)
+
+ # workaround to make program exit with Ctrl+c
+ # it should be deleted after pr is merged by fastchat
+ def init_heart_beat(self):
+ self.register_to_controller()
+ self.heart_beat_thread = threading.Thread(
+ target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
+ )
+ self.heart_beat_thread.start()
diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py
new file mode 100644
index 00000000..4e4e15e0
--- /dev/null
+++ b/server/model_workers/zhipu.py
@@ -0,0 +1,75 @@
+import zhipuai
+from server.model_workers.base import ApiModelWorker
+from fastchat import conversation as conv
+import sys
+import json
+from typing import List, Literal
+
+
+class ChatGLMWorker(ApiModelWorker):
+ BASE_URL = "https://open.bigmodel.cn/api/paas/v3/model-api"
+ SUPPORT_MODELS = ["chatglm_pro", "chatglm_std", "chatglm_lite"]
+
+ def __init__(
+ self,
+ *,
+ model_names: List[str] = ["chatglm-api"],
+ version: Literal["chatglm_pro", "chatglm_std", "chatglm_lite"] = "chatglm_std",
+ controller_addr: str,
+ worker_addr: str,
+ **kwargs,
+ ):
+ kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
+ kwargs.setdefault("context_len", 32768)
+ super().__init__(**kwargs)
+ self.version = version
+
+ # 这里的是chatglm api的模板,其它API的conv_template需要定制
+ self.conv = conv.Conversation(
+ name="chatglm-api",
+ system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
+ messages=[],
+ roles=["Human", "Assistant"],
+ sep="\n### ",
+ stop_str="###",
+ )
+
+ def generate_stream_gate(self, params):
+ # TODO: 支持stream参数,维护request_id,传过来的prompt也有问题
+ from server.utils import get_model_worker_config
+
+ super().generate_stream_gate(params)
+ zhipuai.api_key = get_model_worker_config("chatglm-api").get("api_key")
+
+ response = zhipuai.model_api.sse_invoke(
+ model=self.version,
+ prompt=[{"role": "user", "content": params["prompt"]}],
+ temperature=params.get("temperature"),
+ top_p=params.get("top_p"),
+ incremental=False,
+ )
+ for e in response.events():
+ if e.event == "add":
+ yield json.dumps({"error_code": 0, "text": e.data}, ensure_ascii=False).encode() + b"\0"
+ # TODO: 更健壮的消息处理
+ # elif e.event == "finish":
+ # ...
+
+ def get_embeddings(self, params):
+ # TODO: 支持embeddings
+ print("embedding")
+ print(params)
+
+
+if __name__ == "__main__":
+ import uvicorn
+ from server.utils import MakeFastAPIOffline
+ from fastchat.serve.model_worker import app
+
+ worker = ChatGLMWorker(
+ controller_addr="http://127.0.0.1:20001",
+ worker_addr="http://127.0.0.1:20003",
+ )
+ sys.modules["fastchat.serve.model_worker"].worker = worker
+ MakeFastAPIOffline(app)
+ uvicorn.run(app, port=20003)
diff --git a/server/utils.py b/server/utils.py
index 4a887225..0e53e3df 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -1,16 +1,20 @@
import pydantic
from pydantic import BaseModel
from typing import List
-import torch
from fastapi import FastAPI
from pathlib import Path
import asyncio
-from typing import Any, Optional
+from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE
+from configs.server_config import FSCHAT_MODEL_WORKERS
+import os
+from server import model_workers
+from typing import Literal, Optional, Any
class BaseResponse(BaseModel):
code: int = pydantic.Field(200, description="API status code")
msg: str = pydantic.Field("success", description="API status message")
+ data: Any = pydantic.Field(None, description="API data")
class Config:
schema_extra = {
@@ -68,6 +72,7 @@ class ChatMessage(BaseModel):
}
def torch_gc():
+ import torch
if torch.cuda.is_available():
# with torch.cuda.device(DEVICE):
torch.cuda.empty_cache()
@@ -186,3 +191,117 @@ def MakeFastAPIOffline(
with_google_fonts=False,
redoc_favicon_url=favicon,
)
+
+
+# 从server_config中获取服务信息
+def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
+ '''
+ 加载model worker的配置项。
+ 优先级:FSCHAT_MODEL_WORKERS[model_name] > llm_model_dict[model_name] > FSCHAT_MODEL_WORKERS["default"]
+ '''
+ from configs.server_config import FSCHAT_MODEL_WORKERS
+ from configs.model_config import llm_model_dict
+
+ config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
+ config.update(llm_model_dict.get(model_name, {}))
+ config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
+
+ # 如果没有设置有效的local_model_path,则认为是在线模型API
+ if not os.path.isdir(config.get("local_model_path", "")):
+ config["online_api"] = True
+ if provider := config.get("provider"):
+ try:
+ config["worker_class"] = getattr(model_workers, provider)
+ except Exception as e:
+ print(f"在线模型 ‘{model_name}’ 的provider没有正确配置")
+
+ config["device"] = llm_device(config.get("device") or LLM_DEVICE)
+ return config
+
+
+def get_all_model_worker_configs() -> dict:
+ result = {}
+ model_names = set(llm_model_dict.keys()) | set(FSCHAT_MODEL_WORKERS.keys())
+ for name in model_names:
+ if name != "default":
+ result[name] = get_model_worker_config(name)
+ return result
+
+
+def fschat_controller_address() -> str:
+ from configs.server_config import FSCHAT_CONTROLLER
+
+ host = FSCHAT_CONTROLLER["host"]
+ port = FSCHAT_CONTROLLER["port"]
+ return f"http://{host}:{port}"
+
+
+def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
+ if model := get_model_worker_config(model_name):
+ host = model["host"]
+ port = model["port"]
+ return f"http://{host}:{port}"
+ return ""
+
+
+def fschat_openai_api_address() -> str:
+ from configs.server_config import FSCHAT_OPENAI_API
+
+ host = FSCHAT_OPENAI_API["host"]
+ port = FSCHAT_OPENAI_API["port"]
+ return f"http://{host}:{port}"
+
+
+def api_address() -> str:
+ from configs.server_config import API_SERVER
+
+ host = API_SERVER["host"]
+ port = API_SERVER["port"]
+ return f"http://{host}:{port}"
+
+
+def webui_address() -> str:
+ from configs.server_config import WEBUI_SERVER
+
+ host = WEBUI_SERVER["host"]
+ port = WEBUI_SERVER["port"]
+ return f"http://{host}:{port}"
+
+
+def set_httpx_timeout(timeout: float = None):
+ '''
+ 设置httpx默认timeout。
+ httpx默认timeout是5秒,在请求LLM回答时不够用。
+ '''
+ import httpx
+ from configs.server_config import HTTPX_DEFAULT_TIMEOUT
+
+ timeout = timeout or HTTPX_DEFAULT_TIMEOUT
+ httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
+ httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
+ httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
+
+
+# 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch
+def detect_device() -> Literal["cuda", "mps", "cpu"]:
+ try:
+ import torch
+ if torch.cuda.is_available():
+ return "cuda"
+ if torch.backends.mps.is_available():
+ return "mps"
+ except:
+ pass
+ return "cpu"
+
+
+def llm_device(device: str = LLM_DEVICE) -> Literal["cuda", "mps", "cpu"]:
+ if device not in ["cuda", "mps", "cpu"]:
+ device = detect_device()
+ return device
+
+
+def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]:
+ if device not in ["cuda", "mps", "cpu"]:
+ device = detect_device()
+ return device
diff --git a/startup.py b/startup.py
index c8706e28..591e3b10 100644
--- a/startup.py
+++ b/startup.py
@@ -1,59 +1,60 @@
-from multiprocessing import Process, Queue
+import asyncio
import multiprocessing as mp
+import os
import subprocess
import sys
-import os
+from multiprocessing import Process, Queue
from pprint import pprint
# 设置numexpr最大线程数,默认为CPU核心数
try:
import numexpr
+
n_cores = numexpr.utils.detect_number_of_cores()
os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores)
except:
pass
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \
+from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
logger
-from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
- FSCHAT_OPENAI_API, fschat_controller_address, fschat_model_worker_address,
- fschat_openai_api_address, )
-from server.utils import MakeFastAPIOffline, FastAPI
+from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
+ FSCHAT_OPENAI_API, )
+from server.utils import (fschat_controller_address, fschat_model_worker_address,
+ fschat_openai_api_address, set_httpx_timeout,
+ get_model_worker_config, get_all_model_worker_configs,
+ MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
import argparse
-from typing import Tuple, List
+from typing import Tuple, List, Dict
from configs import VERSION
-def set_httpx_timeout(timeout=60.0):
- import httpx
- httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
- httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
- httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
-
-
def create_controller_app(
dispatch_method: str,
+ log_level: str = "INFO",
) -> FastAPI:
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
- from fastchat.serve.controller import app, Controller
+ from fastchat.serve.controller import app, Controller, logger
+ logger.setLevel(log_level)
controller = Controller(dispatch_method)
sys.modules["fastchat.serve.controller"].controller = controller
MakeFastAPIOffline(app)
app.title = "FastChat Controller"
+ app._controller = controller
return app
-def create_model_worker_app(**kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
+def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
- from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
+ from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger
import argparse
import threading
import fastchat.serve.model_worker
+ logger.setLevel(log_level)
# workaround to make program exit with Ctrl+c
# it should be deleted after pr is merged by fastchat
@@ -99,53 +100,76 @@ def create_model_worker_app(**kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
- gptq_config = GptqConfig(
- ckpt=args.gptq_ckpt or args.model_path,
- wbits=args.gptq_wbits,
- groupsize=args.gptq_groupsize,
- act_order=args.gptq_act_order,
- )
- awq_config = AWQConfig(
- ckpt=args.awq_ckpt or args.model_path,
- wbits=args.awq_wbits,
- groupsize=args.awq_groupsize,
- )
+ # 在线模型API
+ if worker_class := kwargs.get("worker_class"):
+ worker = worker_class(model_names=args.model_names,
+ controller_addr=args.controller_address,
+ worker_addr=args.worker_address)
+ # 本地模型
+ else:
+ # workaround to make program exit with Ctrl+c
+ # it should be deleted after pr is merged by fastchat
+ def _new_init_heart_beat(self):
+ self.register_to_controller()
+ self.heart_beat_thread = threading.Thread(
+ target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
+ )
+ self.heart_beat_thread.start()
- worker = ModelWorker(
- controller_addr=args.controller_address,
- worker_addr=args.worker_address,
- worker_id=worker_id,
- model_path=args.model_path,
- model_names=args.model_names,
- limit_worker_concurrency=args.limit_worker_concurrency,
- no_register=args.no_register,
- device=args.device,
- num_gpus=args.num_gpus,
- max_gpu_memory=args.max_gpu_memory,
- load_8bit=args.load_8bit,
- cpu_offloading=args.cpu_offloading,
- gptq_config=gptq_config,
- awq_config=awq_config,
- stream_interval=args.stream_interval,
- conv_template=args.conv_template,
- )
+ ModelWorker.init_heart_beat = _new_init_heart_beat
+
+ gptq_config = GptqConfig(
+ ckpt=args.gptq_ckpt or args.model_path,
+ wbits=args.gptq_wbits,
+ groupsize=args.gptq_groupsize,
+ act_order=args.gptq_act_order,
+ )
+ awq_config = AWQConfig(
+ ckpt=args.awq_ckpt or args.model_path,
+ wbits=args.awq_wbits,
+ groupsize=args.awq_groupsize,
+ )
+
+ worker = ModelWorker(
+ controller_addr=args.controller_address,
+ worker_addr=args.worker_address,
+ worker_id=worker_id,
+ model_path=args.model_path,
+ model_names=args.model_names,
+ limit_worker_concurrency=args.limit_worker_concurrency,
+ no_register=args.no_register,
+ device=args.device,
+ num_gpus=args.num_gpus,
+ max_gpu_memory=args.max_gpu_memory,
+ load_8bit=args.load_8bit,
+ cpu_offloading=args.cpu_offloading,
+ gptq_config=gptq_config,
+ awq_config=awq_config,
+ stream_interval=args.stream_interval,
+ conv_template=args.conv_template,
+ )
+ sys.modules["fastchat.serve.model_worker"].args = args
+ sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
sys.modules["fastchat.serve.model_worker"].worker = worker
- sys.modules["fastchat.serve.model_worker"].args = args
- sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
MakeFastAPIOffline(app)
- app.title = f"FastChat LLM Server ({LLM_MODEL})"
+ app.title = f"FastChat LLM Server ({args.model_names[0]})"
+ app._worker = worker
return app
def create_openai_api_app(
controller_address: str,
api_keys: List = [],
+ log_level: str = "INFO",
) -> FastAPI:
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
+ from fastchat.utils import build_logger
+ logger = build_logger("openai_api", "openai_api.log")
+ logger.setLevel(log_level)
app.add_middleware(
CORSMiddleware,
@@ -155,6 +179,7 @@ def create_openai_api_app(
allow_headers=["*"],
)
+ sys.modules["fastchat.serve.openai_api_server"].logger = logger
app_settings.controller_address = controller_address
app_settings.api_keys = api_keys
@@ -164,6 +189,9 @@ def create_openai_api_app(
def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
+ if q is None or not isinstance(run_seq, int):
+ return
+
if run_seq == 1:
@app.on_event("startup")
async def on_startup():
@@ -182,15 +210,90 @@ def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
q.put(run_seq)
-def run_controller(q: Queue, run_seq: int = 1):
+def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
import uvicorn
+ import httpx
+ from fastapi import Body
+ import time
+ import sys
- app = create_controller_app(FSCHAT_CONTROLLER.get("dispatch_method"))
+ app = create_controller_app(
+ dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
+ log_level=log_level,
+ )
_set_app_seq(app, q, run_seq)
+ @app.on_event("startup")
+ def on_startup():
+ if e is not None:
+ e.set()
+
+ # add interface to release and load model worker
+ @app.post("/release_worker")
+ def release_worker(
+ model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
+ # worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[fschat_controller_address()]),
+ new_model_name: str = Body(None, description="释放后加载该模型"),
+ keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
+ ) -> Dict:
+ available_models = app._controller.list_models()
+ if new_model_name in available_models:
+ msg = f"要切换的LLM模型 {new_model_name} 已经存在"
+ logger.info(msg)
+ return {"code": 500, "msg": msg}
+
+ if new_model_name:
+ logger.info(f"开始切换LLM模型:从 {model_name} 到 {new_model_name}")
+ else:
+ logger.info(f"即将停止LLM模型: {model_name}")
+
+ if model_name not in available_models:
+ msg = f"the model {model_name} is not available"
+ logger.error(msg)
+ return {"code": 500, "msg": msg}
+
+ worker_address = app._controller.get_worker_address(model_name)
+ if not worker_address:
+ msg = f"can not find model_worker address for {model_name}"
+ logger.error(msg)
+ return {"code": 500, "msg": msg}
+
+ r = httpx.post(worker_address + "/release",
+ json={"new_model_name": new_model_name, "keep_origin": keep_origin})
+ if r.status_code != 200:
+ msg = f"failed to release model: {model_name}"
+ logger.error(msg)
+ return {"code": 500, "msg": msg}
+
+ if new_model_name:
+ timer = 300 # wait 5 minutes for new model_worker register
+ while timer > 0:
+ models = app._controller.list_models()
+ if new_model_name in models:
+ break
+ time.sleep(1)
+ timer -= 1
+ if timer > 0:
+ msg = f"sucess change model from {model_name} to {new_model_name}"
+ logger.info(msg)
+ return {"code": 200, "msg": msg}
+ else:
+ msg = f"failed change model from {model_name} to {new_model_name}"
+ logger.error(msg)
+ return {"code": 500, "msg": msg}
+ else:
+ msg = f"sucess to release model: {model_name}"
+ logger.info(msg)
+ return {"code": 200, "msg": msg}
+
host = FSCHAT_CONTROLLER["host"]
port = FSCHAT_CONTROLLER["port"]
- uvicorn.run(app, host=host, port=port)
+
+ if log_level == "ERROR":
+ sys.stdout = sys.__stdout__
+ sys.stderr = sys.__stderr__
+
+ uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
def run_model_worker(
@@ -198,33 +301,59 @@ def run_model_worker(
controller_address: str = "",
q: Queue = None,
run_seq: int = 2,
+ log_level: str = "INFO",
):
import uvicorn
+ from fastapi import Body
+ import sys
- kwargs = FSCHAT_MODEL_WORKERS[model_name].copy()
+ kwargs = get_model_worker_config(model_name)
host = kwargs.pop("host")
port = kwargs.pop("port")
- model_path = llm_model_dict[model_name].get("local_model_path", "")
- kwargs["model_path"] = model_path
kwargs["model_names"] = [model_name]
kwargs["controller_address"] = controller_address or fschat_controller_address()
- kwargs["worker_address"] = fschat_model_worker_address()
+ kwargs["worker_address"] = fschat_model_worker_address(model_name)
+ model_path = kwargs.get("local_model_path", "")
+ kwargs["model_path"] = model_path
- app = create_model_worker_app(**kwargs)
+ app = create_model_worker_app(log_level=log_level, **kwargs)
_set_app_seq(app, q, run_seq)
+ if log_level == "ERROR":
+ sys.stdout = sys.__stdout__
+ sys.stderr = sys.__stderr__
- uvicorn.run(app, host=host, port=port)
+ # add interface to release and load model
+ @app.post("/release")
+ def release_model(
+ new_model_name: str = Body(None, description="释放后加载该模型"),
+ keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
+ ) -> Dict:
+ if keep_origin:
+ if new_model_name:
+ q.put(["start", new_model_name])
+ else:
+ if new_model_name:
+ q.put(["replace", new_model_name])
+ else:
+ q.put(["stop"])
+ return {"code": 200, "msg": "done"}
+
+ uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
-def run_openai_api(q: Queue, run_seq: int = 3):
+def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"):
import uvicorn
+ import sys
controller_addr = fschat_controller_address()
- app = create_openai_api_app(controller_addr) # todo: not support keys yet.
+ app = create_openai_api_app(controller_addr, log_level=log_level) # TODO: not support keys yet.
_set_app_seq(app, q, run_seq)
host = FSCHAT_OPENAI_API["host"]
port = FSCHAT_OPENAI_API["port"]
+ if log_level == "ERROR":
+ sys.stdout = sys.__stdout__
+ sys.stderr = sys.__stderr__
uvicorn.run(app, host=host, port=port)
@@ -244,13 +373,15 @@ def run_api_server(q: Queue, run_seq: int = 4):
def run_webui(q: Queue, run_seq: int = 5):
host = WEBUI_SERVER["host"]
port = WEBUI_SERVER["port"]
- while True:
- no = q.get()
- if no != run_seq - 1:
- q.put(no)
- else:
- break
- q.put(run_seq)
+
+ if q is not None and isinstance(run_seq, int):
+ while True:
+ no = q.get()
+ if no != run_seq - 1:
+ q.put(no)
+ else:
+ break
+ q.put(run_seq)
p = subprocess.Popen(["streamlit", "run", "webui.py",
"--server.address", host,
"--server.port", str(port)])
@@ -313,6 +444,13 @@ def parse_args() -> argparse.ArgumentParser:
help="run api.py server",
dest="api",
)
+ parser.add_argument(
+ "-p",
+ "--api-worker",
+ action="store_true",
+ help="run online model api such as zhipuai",
+ dest="api_worker",
+ )
parser.add_argument(
"-w",
"--webui",
@@ -320,26 +458,38 @@ def parse_args() -> argparse.ArgumentParser:
help="run webui.py server",
dest="webui",
)
+ parser.add_argument(
+ "-q",
+ "--quiet",
+ action="store_true",
+ help="减少fastchat服务log信息",
+ dest="quiet",
+ )
args = parser.parse_args()
- return args
+ return args, parser
-def dump_server_info(after_start=False):
+def dump_server_info(after_start=False, args=None):
import platform
import langchain
import fastchat
- from configs.server_config import api_address, webui_address
+ from server.utils import api_address, webui_address
- print("\n\n")
+ print("\n")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
print(f"操作系统:{platform.platform()}.")
print(f"python版本:{sys.version}")
print(f"项目版本:{VERSION}")
print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
print("\n")
- print(f"当前LLM模型:{LLM_MODEL} @ {LLM_DEVICE}")
- pprint(llm_model_dict[LLM_MODEL])
- print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}")
+
+ model = LLM_MODEL
+ if args and args.model_name:
+ model = args.model_name
+ print(f"当前LLM模型:{model} @ {llm_device()}")
+ pprint(llm_model_dict[model])
+ print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")
+
if after_start:
print("\n")
print(f"服务端运行信息:")
@@ -351,110 +501,231 @@ def dump_server_info(after_start=False):
if args.webui:
print(f" Chatchat WEBUI Server: {webui_address()}")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
- print("\n\n")
+ print("\n")
-if __name__ == "__main__":
+async def start_main_server():
import time
+ import signal
+
+ def handler(signalname):
+ """
+ Python 3.9 has `signal.strsignal(signalnum)` so this closure would not be needed.
+ Also, 3.8 includes `signal.valid_signals()` that can be used to create a mapping for the same purpose.
+ """
+ def f(signal_received, frame):
+ raise KeyboardInterrupt(f"{signalname} received")
+ return f
+
+ # This will be inherited by the child process if it is forked (not spawned)
+ signal.signal(signal.SIGINT, handler("SIGINT"))
+ signal.signal(signal.SIGTERM, handler("SIGTERM"))
mp.set_start_method("spawn")
- queue = Queue()
- args = parse_args()
+ manager = mp.Manager()
+
+ queue = manager.Queue()
+ args, parser = parse_args()
+
if args.all_webui:
args.openai_api = True
args.model_worker = True
args.api = True
+ args.api_worker = True
args.webui = True
elif args.all_api:
args.openai_api = True
args.model_worker = True
args.api = True
+ args.api_worker = True
args.webui = False
elif args.llm_api:
args.openai_api = True
args.model_worker = True
+ args.api_worker = True
args.api = False
args.webui = False
- dump_server_info()
- logger.info(f"正在启动服务:")
- logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
+ dump_server_info(args=args)
- processes = {}
+ if len(sys.argv) > 1:
+ logger.info(f"正在启动服务:")
+ logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
+ processes = {"online-api": []}
+
+ def process_count():
+ return len(processes) + len(processes["online-api"]) - 1
+
+ if args.quiet:
+ log_level = "ERROR"
+ else:
+ log_level = "INFO"
+
+ controller_started = manager.Event()
if args.openai_api:
process = Process(
target=run_controller,
- name=f"controller({os.getpid()})",
- args=(queue, len(processes) + 1),
+ name=f"controller",
+ args=(queue, process_count() + 1, log_level, controller_started),
daemon=True,
)
- process.start()
+
processes["controller"] = process
process = Process(
target=run_openai_api,
- name=f"openai_api({os.getpid()})",
- args=(queue, len(processes) + 1),
+ name=f"openai_api",
+ args=(queue, process_count() + 1),
daemon=True,
)
- process.start()
processes["openai_api"] = process
if args.model_worker:
- process = Process(
- target=run_model_worker,
- name=f"model_worker({os.getpid()})",
- args=(args.model_name, args.controller_address, queue, len(processes) + 1),
- daemon=True,
- )
- process.start()
- processes["model_worker"] = process
+ config = get_model_worker_config(args.model_name)
+ if not config.get("online_api"):
+ process = Process(
+ target=run_model_worker,
+ name=f"model_worker - {args.model_name}",
+ args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level),
+ daemon=True,
+ )
+
+ processes["model_worker"] = process
+
+ if args.api_worker:
+ configs = get_all_model_worker_configs()
+ for model_name, config in configs.items():
+ if config.get("online_api") and config.get("worker_class"):
+ process = Process(
+ target=run_model_worker,
+ name=f"model_worker - {model_name}",
+ args=(model_name, args.controller_address, queue, process_count() + 1, log_level),
+ daemon=True,
+ )
+
+ processes["online-api"].append(process)
if args.api:
process = Process(
target=run_api_server,
- name=f"API Server{os.getpid()})",
- args=(queue, len(processes) + 1),
+ name=f"API Server",
+ args=(queue, process_count() + 1),
daemon=True,
)
- process.start()
+
processes["api"] = process
if args.webui:
process = Process(
target=run_webui,
- name=f"WEBUI Server{os.getpid()})",
- args=(queue, len(processes) + 1),
+ name=f"WEBUI Server",
+ args=(queue, process_count() + 1),
daemon=True,
)
- process.start()
+
processes["webui"] = process
- try:
- # log infors
- while True:
- no = queue.get()
- if no == len(processes):
- time.sleep(0.5)
- dump_server_info(True)
- break
- else:
- queue.put(no)
+ if process_count() == 0:
+ parser.print_help()
+ else:
+ try:
+ # 保证任务收到SIGINT后,能够正常退出
+ if p:= processes.get("controller"):
+ p.start()
+ p.name = f"{p.name} ({p.pid})"
+ controller_started.wait()
- if model_worker_process := processes.get("model_worker"):
- model_worker_process.join()
- for name, process in processes.items():
- if name != "model_worker":
+ if p:= processes.get("openai_api"):
+ p.start()
+ p.name = f"{p.name} ({p.pid})"
+
+ if p:= processes.get("model_worker"):
+ p.start()
+ p.name = f"{p.name} ({p.pid})"
+
+ for p in processes.get("online-api", []):
+ p.start()
+ p.name = f"{p.name} ({p.pid})"
+
+ if p:= processes.get("api"):
+ p.start()
+ p.name = f"{p.name} ({p.pid})"
+
+ if p:= processes.get("webui"):
+ p.start()
+ p.name = f"{p.name} ({p.pid})"
+
+ while True:
+ no = queue.get()
+ if no == process_count():
+ time.sleep(0.5)
+ dump_server_info(after_start=True, args=args)
+ break
+ else:
+ queue.put(no)
+
+ if model_worker_process := processes.get("model_worker"):
+ model_worker_process.join()
+ for process in processes.get("online-api", []):
process.join()
- except:
- if model_worker_process := processes.get("model_worker"):
- model_worker_process.terminate()
- for name, process in processes.items():
- if name != "model_worker":
- process.terminate()
+ for name, process in processes.items():
+ if name not in ["model_worker", "online-api"]:
+ if isinstance(p, list):
+ for work_process in p:
+ work_process.join()
+ else:
+ process.join()
+ except Exception as e:
+ # if model_worker_process := processes.pop("model_worker", None):
+ # model_worker_process.terminate()
+ # for process in processes.pop("online-api", []):
+ # process.terminate()
+ # for process in processes.values():
+ #
+ # if isinstance(process, list):
+ # for work_process in process:
+ # work_process.terminate()
+ # else:
+ # process.terminate()
+ logger.error(e)
+ logger.warning("Caught KeyboardInterrupt! Setting stop event...")
+ finally:
+ # Send SIGINT if process doesn't exit quickly enough, and kill it as last resort
+ # .is_alive() also implicitly joins the process (good practice in linux)
+ # while alive_procs := [p for p in processes.values() if p.is_alive()]:
+
+ for p in processes.values():
+ logger.warning("Sending SIGKILL to %s", p)
+ # Queues and other inter-process communication primitives can break when
+ # process is killed, but we don't care here
+
+ if isinstance(p, list):
+ for process in p:
+ process.kill()
+
+ else:
+ p.kill()
+
+ for p in processes.values():
+ logger.info("Process status: %s", p)
+
+if __name__ == "__main__":
+
+ if sys.version_info < (3, 10):
+ loop = asyncio.get_event_loop()
+ else:
+ try:
+ loop = asyncio.get_running_loop()
+ except RuntimeError:
+ loop = asyncio.new_event_loop()
+
+ asyncio.set_event_loop(loop)
+ # 同步调用协程代码
+ loop.run_until_complete(start_main_server())
+
# 服务启动后接口调用示例:
# import openai
diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py
index 5a8b97d2..51bbac19 100644
--- a/tests/api/test_kb_api.py
+++ b/tests/api/test_kb_api.py
@@ -1,4 +1,3 @@
-from doctest import testfile
import requests
import json
import sys
@@ -6,7 +5,7 @@ from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
-from configs.server_config import api_address
+from server.utils import api_address
from configs.model_config import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path
@@ -112,7 +111,7 @@ def test_upload_doc(api="/knowledge_base/upload_doc"):
assert data["msg"] == f"成功上传文件 {name}"
-def test_list_docs(api="/knowledge_base/list_docs"):
+def test_list_files(api="/knowledge_base/list_files"):
url = api_base_url + api
print("\n获取知识库中文件列表:")
r = requests.get(url, params={"knowledge_base_name": kb})
diff --git a/tests/api/test_llm_api.py b/tests/api/test_llm_api.py
new file mode 100644
index 00000000..f348fe74
--- /dev/null
+++ b/tests/api/test_llm_api.py
@@ -0,0 +1,74 @@
+import requests
+import json
+import sys
+from pathlib import Path
+
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+from configs.server_config import api_address, FSCHAT_MODEL_WORKERS
+from configs.model_config import LLM_MODEL, llm_model_dict
+
+from pprint import pprint
+import random
+
+
+def get_configured_models():
+ model_workers = list(FSCHAT_MODEL_WORKERS)
+ if "default" in model_workers:
+ model_workers.remove("default")
+
+ llm_dict = list(llm_model_dict)
+
+ return model_workers, llm_dict
+
+
+api_base_url = api_address()
+
+
+def get_running_models(api="/llm_model/list_models"):
+ url = api_base_url + api
+ r = requests.post(url)
+ if r.status_code == 200:
+ return r.json()["data"]
+ return []
+
+
+def test_running_models(api="/llm_model/list_models"):
+ url = api_base_url + api
+ r = requests.post(url)
+ assert r.status_code == 200
+ print("\n获取当前正在运行的模型列表:")
+ pprint(r.json())
+ assert isinstance(r.json()["data"], list)
+ assert len(r.json()["data"]) > 0
+
+
+# 不建议使用stop_model功能。按现在的实现,停止了就只能手动再启动
+# def test_stop_model(api="/llm_model/stop"):
+# url = api_base_url + api
+# r = requests.post(url, json={""})
+
+
+def test_change_model(api="/llm_model/change"):
+ url = api_base_url + api
+
+ running_models = get_running_models()
+ assert len(running_models) > 0
+
+ model_workers, llm_dict = get_configured_models()
+
+ availabel_new_models = set(model_workers) - set(running_models)
+ if len(availabel_new_models) == 0:
+ availabel_new_models = set(llm_dict) - set(running_models)
+ availabel_new_models = list(availabel_new_models)
+ assert len(availabel_new_models) > 0
+ print(availabel_new_models)
+
+ model_name = random.choice(running_models)
+ new_model_name = random.choice(availabel_new_models)
+ print(f"\n尝试将模型从 {model_name} 切换到 {new_model_name}")
+ r = requests.post(url, json={"model_name": model_name, "new_model_name": new_model_name})
+ assert r.status_code == 200
+
+ running_models = get_running_models()
+ assert new_model_name in running_models
diff --git a/tests/api/test_stream_chat_api.py b/tests/api/test_stream_chat_api.py
index ad9d3d89..4c2d5faf 100644
--- a/tests/api/test_stream_chat_api.py
+++ b/tests/api/test_stream_chat_api.py
@@ -5,7 +5,7 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent))
from configs.model_config import BING_SUBSCRIPTION_KEY
-from configs.server_config import API_SERVER, api_address
+from server.utils import api_address
from pprint import pprint
diff --git a/tests/document_loader/test_imgloader.py b/tests/document_loader/test_imgloader.py
new file mode 100644
index 00000000..8bba7da9
--- /dev/null
+++ b/tests/document_loader/test_imgloader.py
@@ -0,0 +1,21 @@
+import sys
+from pathlib import Path
+
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+from pprint import pprint
+
+test_files = {
+ "ocr_test.pdf": str(root_path / "tests" / "samples" / "ocr_test.pdf"),
+}
+
+def test_rapidocrpdfloader():
+ pdf_path = test_files["ocr_test.pdf"]
+ from document_loaders import RapidOCRPDFLoader
+
+ loader = RapidOCRPDFLoader(pdf_path)
+ docs = loader.load()
+ pprint(docs)
+ assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str)
+
+
diff --git a/tests/document_loader/test_pdfloader.py b/tests/document_loader/test_pdfloader.py
new file mode 100644
index 00000000..92460cb4
--- /dev/null
+++ b/tests/document_loader/test_pdfloader.py
@@ -0,0 +1,21 @@
+import sys
+from pathlib import Path
+
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+from pprint import pprint
+
+test_files = {
+ "ocr_test.jpg": str(root_path / "tests" / "samples" / "ocr_test.jpg"),
+}
+
+def test_rapidocrloader():
+ img_path = test_files["ocr_test.jpg"]
+ from document_loaders import RapidOCRLoader
+
+ loader = RapidOCRLoader(img_path)
+ docs = loader.load()
+ pprint(docs)
+ assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str)
+
+
diff --git a/tests/kb_vector_db/__init__.py b/tests/kb_vector_db/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/kb_vector_db/test_faiss_kb.py b/tests/kb_vector_db/test_faiss_kb.py
new file mode 100644
index 00000000..9c329c8b
--- /dev/null
+++ b/tests/kb_vector_db/test_faiss_kb.py
@@ -0,0 +1,35 @@
+from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
+from server.knowledge_base.migrate import create_tables
+from server.knowledge_base.utils import KnowledgeFile
+
+kbService = FaissKBService("test")
+test_kb_name = "test"
+test_file_name = "README.md"
+testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name)
+search_content = "如何启动api服务"
+
+
+def test_init():
+ create_tables()
+
+
+def test_create_db():
+ assert kbService.create_kb()
+
+
+def test_add_doc():
+ assert kbService.add_doc(testKnowledgeFile)
+
+
+def test_search_db():
+ result = kbService.search_docs(search_content)
+ assert len(result) > 0
+def test_delete_doc():
+ assert kbService.delete_doc(testKnowledgeFile)
+
+
+
+
+
+def test_delete_db():
+ assert kbService.drop_kb()
diff --git a/tests/kb_vector_db/test_milvus_db.py b/tests/kb_vector_db/test_milvus_db.py
new file mode 100644
index 00000000..ed723806
--- /dev/null
+++ b/tests/kb_vector_db/test_milvus_db.py
@@ -0,0 +1,31 @@
+from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
+from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
+from server.knowledge_base.kb_service.pg_kb_service import PGKBService
+from server.knowledge_base.migrate import create_tables
+from server.knowledge_base.utils import KnowledgeFile
+
+kbService = MilvusKBService("test")
+
+test_kb_name = "test"
+test_file_name = "README.md"
+testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name)
+search_content = "如何启动api服务"
+
+def test_init():
+ create_tables()
+
+
+def test_create_db():
+ assert kbService.create_kb()
+
+
+def test_add_doc():
+ assert kbService.add_doc(testKnowledgeFile)
+
+
+def test_search_db():
+ result = kbService.search_docs(search_content)
+ assert len(result) > 0
+def test_delete_doc():
+ assert kbService.delete_doc(testKnowledgeFile)
+
diff --git a/tests/kb_vector_db/test_pg_db.py b/tests/kb_vector_db/test_pg_db.py
new file mode 100644
index 00000000..12448d05
--- /dev/null
+++ b/tests/kb_vector_db/test_pg_db.py
@@ -0,0 +1,31 @@
+from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
+from server.knowledge_base.kb_service.pg_kb_service import PGKBService
+from server.knowledge_base.migrate import create_tables
+from server.knowledge_base.utils import KnowledgeFile
+
+kbService = PGKBService("test")
+
+test_kb_name = "test"
+test_file_name = "README.md"
+testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name)
+search_content = "如何启动api服务"
+
+
+def test_init():
+ create_tables()
+
+
+def test_create_db():
+ assert kbService.create_kb()
+
+
+def test_add_doc():
+ assert kbService.add_doc(testKnowledgeFile)
+
+
+def test_search_db():
+ result = kbService.search_docs(search_content)
+ assert len(result) > 0
+def test_delete_doc():
+ assert kbService.delete_doc(testKnowledgeFile)
+
diff --git a/tests/samples/ocr_test.jpg b/tests/samples/ocr_test.jpg
new file mode 100644
index 00000000..70c199b7
Binary files /dev/null and b/tests/samples/ocr_test.jpg differ
diff --git a/tests/samples/ocr_test.pdf b/tests/samples/ocr_test.pdf
new file mode 100644
index 00000000..3a137ad1
Binary files /dev/null and b/tests/samples/ocr_test.pdf differ
diff --git a/webui.py b/webui.py
index 58fc0e39..0cda9ebc 100644
--- a/webui.py
+++ b/webui.py
@@ -10,8 +10,10 @@ from streamlit_option_menu import option_menu
from webui_pages import *
import os
from configs import VERSION
+from server.utils import api_address
-api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)
+
+api = ApiRequest(base_url=api_address())
if __name__ == "__main__":
st.set_page_config(
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 04ece7da..25b885b3 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -1,10 +1,14 @@
import streamlit as st
+from configs.server_config import FSCHAT_MODEL_WORKERS
from webui_pages.utils import *
from streamlit_chatbox import *
from datetime import datetime
from server.chat.search_engine_chat import SEARCH_ENGINES
-from typing import List, Dict
import os
+from configs.model_config import llm_model_dict, LLM_MODEL
+from server.utils import get_model_worker_config
+from typing import List, Dict
+
chat_box = ChatBox(
assistant_avatar=os.path.join(
@@ -59,9 +63,39 @@ def dialogue_page(api: ApiRequest):
on_change=on_mode_change,
key="dialogue_mode",
)
- history_len = st.number_input("历史对话轮数:", 0, 10, 3)
- # todo: support history len
+ def on_llm_change():
+ st.session_state["prev_llm_model"] = llm_model
+
+ def llm_model_format_func(x):
+ if x in running_models:
+ return f"{x} (Running)"
+ return x
+
+ running_models = api.list_running_models()
+ config_models = api.list_config_models()
+ for x in running_models:
+ if x in config_models:
+ config_models.remove(x)
+ llm_models = running_models + config_models
+ if "prev_llm_model" not in st.session_state:
+ index = llm_models.index(LLM_MODEL)
+ else:
+ index = 0
+ llm_model = st.selectbox("选择LLM模型:",
+ llm_models,
+ index,
+ format_func=llm_model_format_func,
+ on_change=on_llm_change,
+ # key="llm_model",
+ )
+ if (st.session_state.get("prev_llm_model") != llm_model
+ and not get_model_worker_config(llm_model).get("online_api")):
+ with st.spinner(f"正在加载模型: {llm_model}"):
+ r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
+ st.session_state["prev_llm_model"] = llm_model
+
+ history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)
def on_kb_change():
st.toast(f"已加载知识库: {st.session_state.selected_kb}")
@@ -75,7 +109,7 @@ def dialogue_page(api: ApiRequest):
on_change=on_kb_change,
key="selected_kb",
)
- kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3)
+ kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
score_threshold = st.number_input("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
# chunk_content = st.checkbox("关联上下文", False, disabled=True)
# chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
@@ -87,13 +121,13 @@ def dialogue_page(api: ApiRequest):
options=search_engine_list,
index=search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0,
)
- se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, 3)
+ se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K)
# Display chat messages from history on app rerun
chat_box.output_messages()
- chat_input_placeholder = "请输入对话内容,换行请使用Ctrl+Enter "
+ chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter "
if prompt := st.chat_input(chat_input_placeholder, key="prompt"):
history = get_messages_history(history_len)
@@ -101,7 +135,7 @@ def dialogue_page(api: ApiRequest):
if dialogue_mode == "LLM 对话":
chat_box.ai_say("正在思考...")
text = ""
- r = api.chat_chat(prompt, history)
+ r = api.chat_chat(prompt, history=history, model=llm_model)
for t in r:
if error_msg := check_error_msg(t): # check whether error occured
st.error(error_msg)
@@ -116,7 +150,7 @@ def dialogue_page(api: ApiRequest):
Markdown("...", in_expander=True, title="知识库匹配结果"),
])
text = ""
- for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history):
+ for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history, model=llm_model):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
text += d["answer"]
@@ -129,8 +163,8 @@ def dialogue_page(api: ApiRequest):
Markdown("...", in_expander=True, title="网络搜索结果"),
])
text = ""
- for d in api.search_engine_chat(prompt, search_engine, se_top_k):
- if error_msg := check_error_msg(d): # check whether error occured
+ for d in api.search_engine_chat(prompt, search_engine, se_top_k, model=llm_model):
+ if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
else:
text += d["answer"]
diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py
index 4351e956..0889ca54 100644
--- a/webui_pages/knowledge_base/knowledge_base.py
+++ b/webui_pages/knowledge_base/knowledge_base.py
@@ -4,7 +4,7 @@ from st_aggrid import AgGrid, JsCode
from st_aggrid.grid_options_builder import GridOptionsBuilder
import pandas as pd
from server.knowledge_base.utils import get_file_path, LOADER_DICT
-from server.knowledge_base.kb_service.base import get_kb_details, get_kb_doc_details
+from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
from typing import Literal, Dict, Tuple
from configs.model_config import embedding_model_dict, kbs_config, EMBEDDING_MODEL, DEFAULT_VS_TYPE
import os
@@ -127,7 +127,7 @@ def knowledge_base_page(api: ApiRequest):
# 上传文件
# sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
- files = st.file_uploader("上传知识文件",
+ files = st.file_uploader("上传知识文件(暂不支持扫描PDF)",
[i for ls in LOADER_DICT.values() for i in ls],
accept_multiple_files=True,
)
@@ -152,7 +152,7 @@ def knowledge_base_page(api: ApiRequest):
# 知识库详情
# st.info("请选择文件,点击按钮进行操作。")
- doc_details = pd.DataFrame(get_kb_doc_details(kb))
+ doc_details = pd.DataFrame(get_kb_file_details(kb))
if not len(doc_details):
st.info(f"知识库 `{kb}` 中暂无文件")
else:
@@ -160,7 +160,7 @@ def knowledge_base_page(api: ApiRequest):
st.info("知识库中包含源文件与向量库,请从下表中选择文件后操作")
doc_details.drop(columns=["kb_name"], inplace=True)
doc_details = doc_details[[
- "No", "file_name", "document_loader", "text_splitter", "in_folder", "in_db",
+ "No", "file_name", "document_loader", "docs_count", "in_folder", "in_db",
]]
# doc_details["in_folder"] = doc_details["in_folder"].replace(True, "✓").replace(False, "×")
# doc_details["in_db"] = doc_details["in_db"].replace(True, "✓").replace(False, "×")
@@ -172,7 +172,8 @@ def knowledge_base_page(api: ApiRequest):
# ("file_ext", "文档类型"): {},
# ("file_version", "文档版本"): {},
("document_loader", "文档加载器"): {},
- ("text_splitter", "分词器"): {},
+ ("docs_count", "文档数量"): {},
+ # ("text_splitter", "分词器"): {},
# ("create_time", "创建时间"): {},
("in_folder", "源文件"): {"cellRenderer": cell_renderer},
("in_db", "向量库"): {"cellRenderer": cell_renderer},
@@ -244,7 +245,6 @@ def knowledge_base_page(api: ApiRequest):
cols = st.columns(3)
- # todo: freezed
if cols[0].button(
"依据源文件重建向量库",
# help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。",
@@ -258,7 +258,7 @@ def knowledge_base_page(api: ApiRequest):
if msg := check_error_msg(d):
st.toast(msg)
else:
- empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
+ empty.progress(d["finished"] / d["total"], d["msg"])
st.experimental_rerun()
if cols[2].button(
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index c666d458..08511044 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -6,11 +6,14 @@ from configs.model_config import (
DEFAULT_VS_TYPE,
KB_ROOT_PATH,
LLM_MODEL,
+ llm_model_dict,
+ HISTORY_LEN,
SCORE_THRESHOLD,
VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K,
logger,
)
+from configs.server_config import HTTPX_DEFAULT_TIMEOUT
import httpx
import asyncio
from server.chat.openai_chat import OpenAiChatMsgIn
@@ -20,21 +23,12 @@ import json
from io import BytesIO
from server.db.repository.knowledge_base_repository import get_kb_detail
from server.db.repository.knowledge_file_repository import get_file_detail
-from server.utils import run_async, iter_over_async
+from server.utils import run_async, iter_over_async, set_httpx_timeout
from configs.model_config import NLTK_DATA_PATH
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
-
-
-def set_httpx_timeout(timeout=60.0):
- '''
- 设置httpx默认timeout到60秒。
- httpx默认timeout是5秒,在请求LLM回答时不够用。
- '''
- httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
- httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
- httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
+from pprint import pprint
KB_ROOT_PATH = Path(KB_ROOT_PATH)
@@ -50,7 +44,7 @@ class ApiRequest:
def __init__(
self,
base_url: str = "http://127.0.0.1:7861",
- timeout: float = 60.0,
+ timeout: float = HTTPX_DEFAULT_TIMEOUT,
no_remote_api: bool = False, # call api view function directly
):
self.base_url = base_url
@@ -224,9 +218,17 @@ class ApiRequest:
try:
with response as r:
for chunk in r.iter_text(None):
- if as_json and chunk:
- yield json.loads(chunk)
- elif chunk.strip():
+ if not chunk: # fastchat api yield empty bytes on start and end
+ continue
+ if as_json:
+ try:
+ data = json.loads(chunk)
+ pprint(data, depth=1)
+ yield data
+ except Exception as e:
+ logger.error(f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。")
+ else:
+ print(chunk, end="", flush=True)
yield chunk
except httpx.ConnectError as e:
msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。"
@@ -274,6 +276,9 @@ class ApiRequest:
return self._fastapi_stream2generator(response)
else:
data = msg.dict(exclude_unset=True, exclude_none=True)
+ print(f"received input message:")
+ pprint(data)
+
response = self.post(
"/chat/fastchat",
json=data,
@@ -286,6 +291,7 @@ class ApiRequest:
query: str,
history: List[Dict] = [],
stream: bool = True,
+ model: str = LLM_MODEL,
no_remote_api: bool = None,
):
'''
@@ -298,8 +304,12 @@ class ApiRequest:
"query": query,
"history": history,
"stream": stream,
+ "model_name": model,
}
+ print(f"received input message:")
+ pprint(data)
+
if no_remote_api:
from server.chat.chat import chat
response = chat(**data)
@@ -316,6 +326,7 @@ class ApiRequest:
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
+ model: str = LLM_MODEL,
no_remote_api: bool = None,
):
'''
@@ -331,9 +342,13 @@ class ApiRequest:
"score_threshold": score_threshold,
"history": history,
"stream": stream,
+ "model_name": model,
"local_doc_url": no_remote_api,
}
+ print(f"received input message:")
+ pprint(data)
+
if no_remote_api:
from server.chat.knowledge_base_chat import knowledge_base_chat
response = knowledge_base_chat(**data)
@@ -352,6 +367,7 @@ class ApiRequest:
search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K,
stream: bool = True,
+ model: str = LLM_MODEL,
no_remote_api: bool = None,
):
'''
@@ -365,8 +381,12 @@ class ApiRequest:
"search_engine_name": search_engine_name,
"top_k": top_k,
"stream": stream,
+ "model_name": model,
}
+ print(f"received input message:")
+ pprint(data)
+
if no_remote_api:
from server.chat.search_engine_chat import search_engine_chat
response = search_engine_chat(**data)
@@ -473,18 +493,18 @@ class ApiRequest:
no_remote_api: bool = None,
):
'''
- 对应api.py/knowledge_base/list_docs接口
+ 对应api.py/knowledge_base/list_files接口
'''
if no_remote_api is None:
no_remote_api = self.no_remote_api
if no_remote_api:
- from server.knowledge_base.kb_doc_api import list_docs
- response = run_async(list_docs(knowledge_base_name))
+ from server.knowledge_base.kb_doc_api import list_files
+ response = run_async(list_files(knowledge_base_name))
return response.data
else:
response = self.get(
- "/knowledge_base/list_docs",
+ "/knowledge_base/list_files",
params={"knowledge_base_name": knowledge_base_name}
)
data = self._check_httpx_json_response(response)
@@ -633,6 +653,84 @@ class ApiRequest:
)
return self._httpx_stream2generator(response, as_json=True)
+ def list_running_models(self, controller_address: str = None):
+ '''
+ 获取Fastchat中正运行的模型列表
+ '''
+ r = self.post(
+ "/llm_model/list_models",
+ )
+ return r.json().get("data", [])
+
+ def list_config_models(self):
+ '''
+ 获取configs中配置的模型列表
+ '''
+ return list(llm_model_dict.keys())
+
+ def stop_llm_model(
+ self,
+ model_name: str,
+ controller_address: str = None,
+ ):
+ '''
+ 停止某个LLM模型。
+ 注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。
+ '''
+ data = {
+ "model_name": model_name,
+ "controller_address": controller_address,
+ }
+ r = self.post(
+ "/llm_model/stop",
+ json=data,
+ )
+ return r.json()
+
+ def change_llm_model(
+ self,
+ model_name: str,
+ new_model_name: str,
+ controller_address: str = None,
+ ):
+ '''
+ 向fastchat controller请求切换LLM模型。
+ '''
+ if not model_name or not new_model_name:
+ return
+
+ if new_model_name == model_name:
+ return {
+ "code": 200,
+ "msg": "什么都不用做"
+ }
+
+ running_models = self.list_running_models()
+ if model_name not in running_models:
+ return {
+ "code": 500,
+ "msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}"
+ }
+
+ config_models = self.list_config_models()
+ if new_model_name not in config_models:
+ return {
+ "code": 500,
+ "msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
+ }
+
+ data = {
+ "model_name": model_name,
+ "new_model_name": new_model_name,
+ "controller_address": controller_address,
+ }
+ r = self.post(
+ "/llm_model/change",
+ json=data,
+ timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
+ )
+ return r.json()
+
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
'''