Merge branch 'dev' into pre-release
4
.gitignore
vendored
@@ -3,5 +3,7 @@
logs
.idea/
__pycache__/
knowledge_base/
/knowledge_base/
configs/*.py
.vscode/
.pytest_cache/
143
README.md
@@ -14,8 +14,8 @@
* [2. 下载模型至本地](README.md#2.-下载模型至本地)
* [3. 设置配置项](README.md#3.-设置配置项)
* [4. 知识库初始化与迁移](README.md#4.-知识库初始化与迁移)
* [5. 启动 API 服务或 Web UI](README.md#5.-启动-API-服务或-Web-UI)
* [6. 一键启动](README.md#6.-一键启动)
* [5. 一键启动 API 服务或 WebUI 服务](README.md#6.-一键启动)
* [6. 分步启动 API 服务或 Web UI](README.md#5.-启动-API-服务或-Web-UI)
* [常见问题](README.md#常见问题)
* [路线图](README.md#路线图)
* [项目交流群](README.md#项目交流群)

@@ -226,9 +226,93 @@ embedding_model_dict = {
$ python init_database.py --recreate-vs
```

### 5. 启动 API 服务或 Web UI
### 5. 一键启动 API 服务或 Web UI

#### 5.1 启动 LLM 服务
#### 5.1 启动命令

一键启动脚本 startup.py,一键启动所有 Fastchat 服务、API 服务、WebUI 服务,示例代码:

```shell
$ python startup.py -a
```

并可使用 `Ctrl + C` 直接关闭所有运行服务。如果一次结束不了,可以多按几次。

可选参数包括 `-a (或--all-webui)`, `--all-api`, `--llm-api`, `-c (或--controller)`, `--openai-api`,
`-m (或--model-worker)`, `--api`, `--webui`,其中:

- `--all-webui` 为一键启动 WebUI 所有依赖服务;
- `--all-api` 为一键启动 API 所有依赖服务;
- `--llm-api` 为一键启动 Fastchat 所有依赖的 LLM 服务;
- `--openai-api` 为仅启动 FastChat 的 controller 和 openai-api-server 服务;
- 其他为单独服务启动选项。

#### 5.2 启动非默认模型

若想指定非默认模型,需要用 `--model-name` 选项,示例:

```shell
$ python startup.py --all-webui --model-name Qwen-7B-Chat
```

更多信息可通过 `python startup.py -h` 查看。

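注意:通过 `--model-name` 指定的非默认模型,需要已在 model_config.py 的 llm_model_dict 中正确配置。下面给出一个配置节选的示意(模型路径等取值仅为示例,请以实际环境为准):

```python
# 示意:model_config.py 中为非默认模型(如 Qwen-7B-Chat)添加 llm_model_dict 条目
llm_model_dict = {
    # ...其他模型配置...
    "Qwen-7B-Chat": {
        "local_model_path": "/path/to/Qwen-7B-Chat",  # 本地模型路径(示例值)
        "api_base_url": "http://localhost:8888/v1",   # FastChat openai-api-server 地址(默认端口 8888)
        "api_key": "EMPTY",                           # 本地模型通常无需真实 key(示例值)
    },
}
```
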
#### 5.3 多卡加载

项目支持多卡加载,需在 startup.py 中的 create_model_worker_app 函数中,修改如下三个参数:

```python
gpus=None,
num_gpus=1,
max_gpu_memory="20GiB"
```

其中,`gpus` 控制使用的显卡的 ID,例如 "0,1";

`num_gpus` 控制使用的卡数;

`max_gpu_memory` 控制每张卡使用的显存容量。

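例如,希望使用 0、1 两张显卡加载模型时,可将上述三个参数修改为类似下面的取值(仅为示意,请以 startup.py 中的实际代码为准):

```python
# 示意:startup.py 中 create_model_worker_app 的多卡加载参数
gpus="0,1",              # 使用 ID 为 0 和 1 的两张显卡
num_gpus=2,              # 共使用 2 张卡
max_gpu_memory="20GiB",  # 每张卡最多占用 20GiB 显存
```
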
注1:server_config.py 的 FSCHAT_MODEL_WORKERS 字典中也增加了相关配置,如有需要也可通过修改 FSCHAT_MODEL_WORKERS 字典中对应参数实现多卡加载。

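该配置大致形如下面的节选(字段名取自本次改动后的 server_config.py,取值仅为示例):

```python
# 示意:server_config.py 中通过 FSCHAT_MODEL_WORKERS 配置多卡加载(取消注释并按需修改)
FSCHAT_MODEL_WORKERS = {
    "default": {
        "host": DEFAULT_BIND_HOST,
        "port": 20002,
        "device": LLM_DEVICE,
        # 多卡加载需要配置的参数
        "gpus": "0,1",              # 使用的GPU,以str的格式指定,如"0,1"
        "num_gpus": 2,              # 使用GPU的数量
        "max_gpu_memory": "20GiB",  # 每个GPU占用的最大显存
    },
}
```
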
注2:少数情况下,gpus 参数会不生效,此时需要通过设置环境变量 CUDA_VISIBLE_DEVICES 来指定 torch 可见的 GPU,示例代码:

```shell
CUDA_VISIBLE_DEVICES=0,1 python startup.py -a
```

#### 5.4 PEFT 加载(包括 lora,p-tuning,prefix tuning,prompt tuning,ia3 等)

本项目基于 FastChat 加载 LLM 服务,故需以 FastChat 的方式加载 PEFT 路径:路径名称里必须包含 peft 这个词,配置文件的名字为 adapter_config.json,peft 路径下包含 .bin 格式的 PEFT 权重;peft 路径在 startup.py 中 create_model_worker_app 函数的 args.model_names 中指定,并需开启环境变量 PEFT_SHARE_BASE_WEIGHTS=true。

注:如果上述方式启动失败,则需要以标准的 fastchat 服务启动方式分步启动,分步启动步骤参考第六节;PEFT 加载详细步骤参考[加载lora微调后模型失效](https://github.com/chatchat-space/Langchain-Chatchat/issues/1130#issuecomment-1685291822)。

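一个最小示意:假设微调产物保存在名为 chatglm2-6b-peft 的目录中(目录名为假设值,仅需包含 peft 字样),可在启动前按如下方式开启基座权重共享(也可直接在 shell 中 export 该环境变量):

```python
# 示意:在运行 startup.py 前开启 PEFT 基座权重共享
import os

os.environ["PEFT_SHARE_BASE_WEIGHTS"] = "true"
# peft 目录(名称需包含 "peft",内含 adapter_config.json 与 .bin 格式权重)
# 随后在 startup.py 的 create_model_worker_app 中,通过 args.model_names 指向该目录
```
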
#### **5.5 注意事项:**

**1. startup 脚本用多进程方式启动各模块的服务,可能会导致打印顺序问题,请等待全部服务发起后再调用,并根据默认或指定端口调用服务(默认 LLM API 服务端口:`127.0.0.1:8888`,默认 API 服务端口:`127.0.0.1:7861`,默认 WebUI 服务端口:`本机IP:8501`)**

**2. 服务启动时间视设备不同而不同,约 3-10 分钟,如长时间没有启动请前往 `./logs` 目录下监控日志,定位问题。**

**3. 在 Linux 上使用 Ctrl+C 退出可能会由于 Linux 的多进程机制导致 multiprocessing 遗留孤儿进程,可通过 shutdown_all.sh 进行退出。**

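全部服务启动完成后,可按上述端口访问服务。例如,调用 API 服务(默认 `127.0.0.1:7861`)中本次新增的 `/llm_model/list_models` 接口,查看当前已加载的模型(示例仅作参考,端口请以实际配置为准):

```python
# 示意:查询当前已加载的 LLM 模型(接口定义见本次改动中的 server/api.py)
import httpx

resp = httpx.post("http://127.0.0.1:7861/llm_model/list_models")
print(resp.json())  # BaseResponse 形式,形如 {"code": 200, "msg": "success", "data": [...]}
```
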
#### 5.6 启动界面示例:

1. FastAPI docs 界面



2. Web UI 启动界面示例:

- Web UI 对话界面:



- Web UI 知识库管理页面:



### 6. 分步启动 API 服务或 Web UI

注意:如使用了一键启动方式,可忽略本节。

#### 6.1 启动 LLM 服务

如需使用开源模型进行本地部署,需首先启动 LLM 服务,启动方式分为三种:

@@ -240,7 +324,7 @@ embedding_model_dict = {

如果启动在线的API服务(如 OPENAI 的 API 接口),则无需启动 LLM 服务,即 5.1 小节的任何命令均无需启动。

##### 5.1.1 基于多进程脚本 llm_api.py 启动 LLM 服务
##### 6.1.1 基于多进程脚本 llm_api.py 启动 LLM 服务

在项目根目录下,执行 [server/llm_api.py](server/llm_api.py) 脚本启动 **LLM 模型**服务:

@@ -248,7 +332,7 @@ embedding_model_dict = {
$ python server/llm_api.py
```

项目支持多卡加载,需在 llm_api.py 中修改 create_model_worker_app 函数中,修改如下三个参数:
项目支持多卡加载,需在 llm_api.py 中的 create_model_worker_app 函数中,修改如下三个参数:

```python
gpus=None,
@@ -262,11 +346,11 @@ max_gpu_memory="20GiB"

`max_gpu_memory` 控制每个卡使用的显存容量。

##### 5.1.2 基于命令行脚本 llm_api_stale.py 启动 LLM 服务
##### 6.1.2 基于命令行脚本 llm_api_stale.py 启动 LLM 服务

⚠️ **注意:**

**1.llm_api_stale.py脚本原生仅适用于linux,mac设备需要安装对应的linux命令,win平台请使用wls;**
**1.llm_api_stale.py脚本原生仅适用于linux,mac设备需要安装对应的linux命令,win平台请使用wsl;**

**2.加载非默认模型需要用命令行参数--model-path-address指定模型,不会读取model_config.py配置;**

@@ -302,14 +386,14 @@ $ python server/llm_api_shutdown.py --serve all

亦可单独停止一个 FastChat 服务模块,可选 [`all`, `controller`, `model_worker`, `openai_api_server`]

##### 5.1.3 PEFT 加载(包括lora,p-tuning,prefix tuning, prompt tuning,ia等)
##### 6.1.3 PEFT 加载(包括lora,p-tuning,prefix tuning, prompt tuning,ia3等)

本项目基于 FastChat 加载 LLM 服务,故需以 FastChat 加载 PEFT 路径,即保证路径名称里必须有 peft 这个词,配置文件的名字为 adapter_config.json,peft 路径下包含 model.bin 格式的 PEFT 权重。
详细步骤参考[加载lora微调后模型失效](https://github.com/chatchat-space/Langchain-Chatchat/issues/1130#issuecomment-1685291822)



#### 5.2 启动 API 服务
#### 6.2 启动 API 服务

本地部署情况下,按照 [5.1 节](README.md#5.1-启动-LLM-服务)**启动 LLM 服务后**,再执行 [server/api.py](server/api.py) 脚本启动 **API** 服务;

@@ -327,7 +411,7 @@ $ python server/api.py



#### 5.3 启动 Web UI 服务
#### 6.3 启动 Web UI 服务

按照 [5.2 节](README.md#5.2-启动-API-服务)**启动 API 服务后**,执行 [webui.py](webui.py) 启动 **Web UI** 服务(默认使用端口 `8501`)

@@ -356,41 +440,6 @@ $ streamlit run webui.py --server.port 666

---

### 6. 一键启动

更新一键启动脚本 startup.py,一键启动所有 Fastchat 服务、API 服务、WebUI 服务,示例代码:

```shell
$ python startup.py -a
```

并可使用 `Ctrl + C` 直接关闭所有运行服务。如果一次结束不了,可以多按几次。

可选参数包括 `-a (或--all-webui)`, `--all-api`, `--llm-api`, `-c (或--controller)`, `--openai-api`,
`-m (或--model-worker)`, `--api`, `--webui`,其中:

- `--all-webui` 为一键启动 WebUI 所有依赖服务;
- `--all-api` 为一键启动 API 所有依赖服务;
- `--llm-api` 为一键启动 Fastchat 所有依赖的 LLM 服务;
- `--openai-api` 为仅启动 FastChat 的 controller 和 openai-api-server 服务;
- 其他为单独服务启动选项。

若想指定非默认模型,需要用 `--model-name` 选项,示例:

```shell
$ python startup.py --all-webui --model-name Qwen-7B-Chat
```

更多信息可通过 `python startup.py -h`查看。

**注意:**

**1. startup 脚本用多进程方式启动各模块的服务,可能会导致打印顺序问题,请等待全部服务发起后再调用,并根据默认或指定端口调用服务(默认 LLM API 服务端口:`127.0.0.1:8888`,默认 API 服务端口:`127.0.0.1:7861`,默认 WebUI 服务端口:`本机IP:8501`)**

**2.服务启动时间示设备不同而不同,约 3-10 分钟,如长时间没有启动请前往 `./logs`目录下监控日志,定位问题。**

**3. 在Linux上使用ctrl+C退出可能会由于linux的多进程机制导致multiprocessing遗留孤儿进程,可通过shutdown_all.sh进行退出**

## 常见问题

参见 [常见问题](docs/FAQ.md)。

@@ -433,6 +482,6 @@ $ python startup.py --all-webui --model-name Qwen-7B-Chat

## 项目交流群

<img src="img/qr_code_56.jpg" alt="二维码" width="300" height="300" />
<img src="img/qr_code_58.jpg" alt="二维码" width="300" height="300" />

🎉 langchain-ChatGLM 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。

@ -1,6 +1,5 @@
|
||||
import os
|
||||
import logging
|
||||
import torch
|
||||
# 日志格式
|
||||
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
|
||||
logger = logging.getLogger()
|
||||
@ -32,9 +31,8 @@ embedding_model_dict = {
|
||||
# 选用的 Embedding 名称
|
||||
EMBEDDING_MODEL = "m3e-base"
|
||||
|
||||
# Embedding 模型运行设备
|
||||
EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
|
||||
# Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
|
||||
EMBEDDING_DEVICE = "auto"
|
||||
|
||||
llm_model_dict = {
|
||||
"chatglm-6b": {
|
||||
@ -69,20 +67,32 @@ llm_model_dict = {
|
||||
# 如果出现WARNING: Retrying langchain.chat_models.openai.acompletion_with_retry.<locals>._completion_with_retry in
|
||||
# 4.0 seconds as it raised APIConnectionError: Error communicating with OpenAI.
|
||||
# 需要添加代理访问(正常开的代理软件可能会拦截不上),需要设置配置 openai_proxy 或者使用环境变量 OPENAI_PROXY 进行设置
|
||||
# 比如: "openai_proxy": 'http://127.0.0.1:4780'
|
||||
"gpt-3.5-turbo": {
|
||||
"local_model_path": "gpt-3.5-turbo",
|
||||
"api_base_url": "https://api.openai.com/v1",
|
||||
"api_key": os.environ.get("OPENAI_API_KEY"),
|
||||
"openai_proxy": os.environ.get("OPENAI_PROXY")
|
||||
},
|
||||
# 线上模型。当前支持智谱AI。
|
||||
# 如果没有设置有效的local_model_path,则认为是在线模型API。
|
||||
# 请在server_config中为每个在线API设置不同的端口
|
||||
# 具体注册及api key获取请前往 http://open.bigmodel.cn
|
||||
"chatglm-api": {
|
||||
"api_base_url": "http://127.0.0.1:8888/v1",
|
||||
"api_key": os.environ.get("ZHIPUAI_API_KEY"),
|
||||
"provider": "ChatGLMWorker",
|
||||
"version": "chatglm_pro", # 可选包括 "chatglm_lite", "chatglm_std", "chatglm_pro"
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# LLM 名称
|
||||
LLM_MODEL = "chatglm2-6b"
|
||||
|
||||
# LLM 运行设备
|
||||
LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
# 历史对话轮数
|
||||
HISTORY_LEN = 3
|
||||
|
||||
# LLM 运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
|
||||
LLM_DEVICE = "auto"
|
||||
|
||||
# 日志存储路径
|
||||
LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
|
||||
@ -164,4 +174,4 @@ BING_SUBSCRIPTION_KEY = ""
|
||||
# 是否开启中文标题加强,以及标题增强的相关配置
|
||||
# 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记;
|
||||
# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
|
||||
ZH_TITLE_ENHANCE = False
|
||||
ZH_TITLE_ENHANCE = False
|
||||
|
||||
@ -1,4 +1,8 @@
|
||||
from .model_config import LLM_MODEL, LLM_DEVICE
|
||||
from .model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE
|
||||
import httpx
|
||||
|
||||
# httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。
|
||||
HTTPX_DEFAULT_TIMEOUT = 300.0
|
||||
|
||||
# API 是否开启跨域,默认为False,如果需要开启,请设置为True
|
||||
# is open cross domain
|
||||
@ -29,15 +33,18 @@ FSCHAT_OPENAI_API = {
|
||||
# 这些模型必须是在model_config.llm_model_dict中正确配置的。
|
||||
# 在启动startup.py时,可以通过`--model-worker --model-name xxxx`指定模型,不指定则为LLM_MODEL
|
||||
FSCHAT_MODEL_WORKERS = {
|
||||
LLM_MODEL: {
|
||||
# 所有模型共用的默认配置,可在模型专项配置或llm_model_dict中进行覆盖。
|
||||
"default": {
|
||||
"host": DEFAULT_BIND_HOST,
|
||||
"port": 20002,
|
||||
"device": LLM_DEVICE,
|
||||
# todo: 多卡加载需要配置的参数
|
||||
"gpus": None, # 使用的GPU,以str的格式指定,如"0,1"
|
||||
"num_gpus": 1, # 使用GPU的数量
|
||||
# 以下为非常用参数,可根据需要配置
|
||||
|
||||
# 多卡加载需要配置的参数
|
||||
# "gpus": None, # 使用的GPU,以str的格式指定,如"0,1"
|
||||
# "num_gpus": 1, # 使用GPU的数量
|
||||
# "max_gpu_memory": "20GiB", # 每个GPU占用的最大显存
|
||||
|
||||
# 以下为非常用参数,可根据需要配置
|
||||
# "load_8bit": False, # 开启8bit量化
|
||||
# "cpu_offloading": None,
|
||||
# "gptq_ckpt": None,
|
||||
@ -53,11 +60,17 @@ FSCHAT_MODEL_WORKERS = {
|
||||
# "stream_interval": 2,
|
||||
# "no_register": False,
|
||||
},
|
||||
"baichuan-7b": { # 使用default中的IP和端口
|
||||
"device": "cpu",
|
||||
},
|
||||
"chatglm-api": { # 请为每个在线API设置不同的端口
|
||||
"port": 20003,
|
||||
},
|
||||
}
|
||||
|
||||
# fastchat multi model worker server
|
||||
FSCHAT_MULTI_MODEL_WORKERS = {
|
||||
# todo
|
||||
# TODO:
|
||||
}
|
||||
|
||||
# fastchat controller server
|
||||
@ -66,35 +79,3 @@ FSCHAT_CONTROLLER = {
|
||||
"port": 20001,
|
||||
"dispatch_method": "shortest_queue",
|
||||
}
|
||||
|
||||
|
||||
# 以下不要更改
|
||||
def fschat_controller_address() -> str:
|
||||
host = FSCHAT_CONTROLLER["host"]
|
||||
port = FSCHAT_CONTROLLER["port"]
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
|
||||
if model := FSCHAT_MODEL_WORKERS.get(model_name):
|
||||
host = model["host"]
|
||||
port = model["port"]
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def fschat_openai_api_address() -> str:
|
||||
host = FSCHAT_OPENAI_API["host"]
|
||||
port = FSCHAT_OPENAI_API["port"]
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def api_address() -> str:
|
||||
host = API_SERVER["host"]
|
||||
port = API_SERVER["port"]
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def webui_address() -> str:
|
||||
host = WEBUI_SERVER["host"]
|
||||
port = WEBUI_SERVER["port"]
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
2
document_loaders/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .mypdfloader import RapidOCRPDFLoader
|
||||
from .myimgloader import RapidOCRLoader
|
||||
25
document_loaders/myimgloader.py
Normal file
@ -0,0 +1,25 @@
|
||||
from typing import List
|
||||
from langchain.document_loaders.unstructured import UnstructuredFileLoader
|
||||
|
||||
|
||||
class RapidOCRLoader(UnstructuredFileLoader):
|
||||
def _get_elements(self) -> List:
|
||||
def img2text(filepath):
|
||||
from rapidocr_onnxruntime import RapidOCR
|
||||
resp = ""
|
||||
ocr = RapidOCR()
|
||||
result, _ = ocr(filepath)
|
||||
if result:
|
||||
ocr_result = [line[1] for line in result]
|
||||
resp += "\n".join(ocr_result)
|
||||
return resp
|
||||
|
||||
text = img2text(self.file_path)
|
||||
from unstructured.partition.text import partition_text
|
||||
return partition_text(text=text, **self.unstructured_kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
loader = RapidOCRLoader(file_path="../tests/samples/ocr_test.jpg")
|
||||
docs = loader.load()
|
||||
print(docs)
|
||||
37
document_loaders/mypdfloader.py
Normal file
@ -0,0 +1,37 @@
|
||||
from typing import List
|
||||
from langchain.document_loaders.unstructured import UnstructuredFileLoader
|
||||
|
||||
|
||||
class RapidOCRPDFLoader(UnstructuredFileLoader):
|
||||
def _get_elements(self) -> List:
|
||||
def pdf2text(filepath):
|
||||
import fitz
|
||||
from rapidocr_onnxruntime import RapidOCR
|
||||
import numpy as np
|
||||
ocr = RapidOCR()
|
||||
doc = fitz.open(filepath)
|
||||
resp = ""
|
||||
for page in doc:
|
||||
# TODO: 依据文本与图片顺序调整处理方式
|
||||
text = page.get_text("")
|
||||
resp += text + "\n"
|
||||
|
||||
img_list = page.get_images()
|
||||
for img in img_list:
|
||||
pix = fitz.Pixmap(doc, img[0])
|
||||
img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1)
|
||||
result, _ = ocr(img_array)
|
||||
if result:
|
||||
ocr_result = [line[1] for line in result]
|
||||
resp += "\n".join(ocr_result)
|
||||
return resp
|
||||
|
||||
text = pdf2text(self.file_path)
|
||||
from unstructured.partition.text import partition_text
|
||||
return partition_text(text=text, **self.unstructured_kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
loader = RapidOCRPDFLoader(file_path="../tests/samples/ocr_test.pdf")
|
||||
docs = loader.load()
|
||||
print(docs)
|
||||
BIN: 7 image files deleted (200–292 KiB)
BIN
img/qr_code_58.jpg
Normal file
After Width: | Height: | Size: 249 KiB
@ -1,8 +1,9 @@
|
||||
from server.knowledge_base.migrate import create_tables, folder2db, recreate_all_vs, list_kbs_from_folder
|
||||
from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, recreate_all_vs, list_kbs_from_folder
|
||||
from configs.model_config import NLTK_DATA_PATH
|
||||
import nltk
|
||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||
from startup import dump_server_info
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -25,13 +26,19 @@ if __name__ == "__main__":
|
||||
|
||||
dump_server_info()
|
||||
|
||||
create_tables()
|
||||
print("database talbes created")
|
||||
start_time = datetime.now()
|
||||
|
||||
if args.recreate_vs:
|
||||
reset_tables()
|
||||
print("database talbes reseted")
|
||||
print("recreating all vector stores")
|
||||
recreate_all_vs()
|
||||
else:
|
||||
create_tables()
|
||||
print("database talbes created")
|
||||
print("filling kb infos to database")
|
||||
for kb in list_kbs_from_folder():
|
||||
folder2db(kb, "fill_info_only")
|
||||
|
||||
end_time = datetime.now()
|
||||
print(f"总计用时: {end_time-start_time}")
|
||||
|
||||
@ -15,6 +15,8 @@ SQLAlchemy==2.0.19
|
||||
faiss-cpu
|
||||
accelerate
|
||||
spacy
|
||||
PyMuPDF==1.22.5
|
||||
rapidocr_onnxruntime>=1.3.1
|
||||
|
||||
# uncomment libs if you want to use corresponding vector store
|
||||
# pymilvus==2.1.3 # requires milvus==2.1.3
|
||||
|
||||
@ -16,6 +16,8 @@ faiss-cpu
|
||||
nltk
|
||||
accelerate
|
||||
spacy
|
||||
PyMuPDF==1.22.5
|
||||
rapidocr_onnxruntime>=1.3.1
|
||||
|
||||
# uncomment libs if you want to use corresponding vector store
|
||||
# pymilvus==2.1.3 # requires milvus==2.1.3
|
||||
|
||||
@ -4,20 +4,22 @@ import os
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from configs.model_config import NLTK_DATA_PATH
|
||||
from configs.server_config import OPEN_CROSS_DOMAIN
|
||||
from configs.model_config import LLM_MODEL, NLTK_DATA_PATH
|
||||
from configs.server_config import OPEN_CROSS_DOMAIN, HTTPX_DEFAULT_TIMEOUT
|
||||
from configs import VERSION
|
||||
import argparse
|
||||
import uvicorn
|
||||
from fastapi import Body
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.responses import RedirectResponse
|
||||
from server.chat import (chat, knowledge_base_chat, openai_chat,
|
||||
search_engine_chat)
|
||||
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
|
||||
from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
|
||||
from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc,
|
||||
update_doc, download_doc, recreate_vector_store,
|
||||
search_docs, DocumentWithScore)
|
||||
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
|
||||
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address
|
||||
import httpx
|
||||
from typing import List
|
||||
|
||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||
@ -84,11 +86,11 @@ def create_app():
|
||||
summary="删除知识库"
|
||||
)(delete_kb)
|
||||
|
||||
app.get("/knowledge_base/list_docs",
|
||||
app.get("/knowledge_base/list_files",
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=ListResponse,
|
||||
summary="获取知识库内的文件列表"
|
||||
)(list_docs)
|
||||
)(list_files)
|
||||
|
||||
app.post("/knowledge_base/search_docs",
|
||||
tags=["Knowledge Base Management"],
|
||||
@ -123,6 +125,75 @@ def create_app():
|
||||
summary="根据content中文档重建向量库,流式输出处理进度。"
|
||||
)(recreate_vector_store)
|
||||
|
||||
# LLM模型相关接口
|
||||
@app.post("/llm_model/list_models",
|
||||
tags=["LLM Model Management"],
|
||||
summary="列出当前已加载的模型")
|
||||
def list_models(
|
||||
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
从fastchat controller获取已加载模型列表
|
||||
'''
|
||||
try:
|
||||
controller_address = controller_address or fschat_controller_address()
|
||||
r = httpx.post(controller_address + "/list_models")
|
||||
return BaseResponse(data=r.json()["models"])
|
||||
except Exception as e:
|
||||
return BaseResponse(
|
||||
code=500,
|
||||
data=[],
|
||||
msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
|
||||
|
||||
@app.post("/llm_model/stop",
|
||||
tags=["LLM Model Management"],
|
||||
summary="停止指定的LLM模型(Model Worker)",
|
||||
)
|
||||
def stop_llm_model(
|
||||
model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
|
||||
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
向fastchat controller请求停止某个LLM模型。
|
||||
注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。
|
||||
'''
|
||||
try:
|
||||
controller_address = controller_address or fschat_controller_address()
|
||||
r = httpx.post(
|
||||
controller_address + "/release_worker",
|
||||
json={"model_name": model_name},
|
||||
)
|
||||
return r.json()
|
||||
except Exception as e:
|
||||
return BaseResponse(
|
||||
code=500,
|
||||
msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}")
|
||||
|
||||
@app.post("/llm_model/change",
|
||||
tags=["LLM Model Management"],
|
||||
summary="切换指定的LLM模型(Model Worker)",
|
||||
)
|
||||
def change_llm_model(
|
||||
model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]),
|
||||
new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]),
|
||||
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
|
||||
):
|
||||
'''
|
||||
向fastchat controller请求切换LLM模型。
|
||||
'''
|
||||
try:
|
||||
controller_address = controller_address or fschat_controller_address()
|
||||
r = httpx.post(
|
||||
controller_address + "/release_worker",
|
||||
json={"model_name": model_name, "new_model_name": new_model_name},
|
||||
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
|
||||
)
|
||||
return r.json()
|
||||
except Exception as e:
|
||||
return BaseResponse(
|
||||
code=500,
|
||||
msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
||||
@ -20,11 +20,13 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
|
||||
{"role": "assistant", "content": "虎头虎脑"}]]
|
||||
),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
|
||||
):
|
||||
history = [History.from_data(h) for h in history]
|
||||
|
||||
async def chat_iterator(query: str,
|
||||
history: List[History] = [],
|
||||
model_name: str = LLM_MODEL,
|
||||
) -> AsyncIterable[str]:
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
|
||||
@ -32,10 +34,10 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
|
||||
streaming=True,
|
||||
verbose=True,
|
||||
callbacks=[callback],
|
||||
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
model_name=LLM_MODEL,
|
||||
openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
|
||||
openai_api_key=llm_model_dict[model_name]["api_key"],
|
||||
openai_api_base=llm_model_dict[model_name]["api_base_url"],
|
||||
model_name=model_name,
|
||||
openai_proxy=llm_model_dict[model_name].get("openai_proxy")
|
||||
)
|
||||
|
||||
input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
|
||||
@ -61,5 +63,5 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
|
||||
|
||||
await task
|
||||
|
||||
return StreamingResponse(chat_iterator(query, history),
|
||||
return StreamingResponse(chat_iterator(query, history, model_name),
|
||||
media_type="text/event-stream")
|
||||
|
||||
@ -31,6 +31,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
|
||||
"content": "虎头虎脑"}]]
|
||||
),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
|
||||
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
|
||||
request: Request = None,
|
||||
):
|
||||
@ -44,16 +45,17 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
|
||||
kb: KBService,
|
||||
top_k: int,
|
||||
history: Optional[List[History]],
|
||||
model_name: str = LLM_MODEL,
|
||||
) -> AsyncIterable[str]:
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
model = ChatOpenAI(
|
||||
streaming=True,
|
||||
verbose=True,
|
||||
callbacks=[callback],
|
||||
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
model_name=LLM_MODEL,
|
||||
openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
|
||||
openai_api_key=llm_model_dict[model_name]["api_key"],
|
||||
openai_api_base=llm_model_dict[model_name]["api_base_url"],
|
||||
model_name=model_name,
|
||||
openai_proxy=llm_model_dict[model_name].get("openai_proxy")
|
||||
)
|
||||
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
|
||||
context = "\n".join([doc.page_content for doc in docs])
|
||||
@ -97,5 +99,5 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
|
||||
|
||||
await task
|
||||
|
||||
return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history),
|
||||
return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history, model_name),
|
||||
media_type="text/event-stream")
|
||||
|
||||
@ -29,23 +29,22 @@ async def openai_chat(msg: OpenAiChatMsgIn):
|
||||
print(f"{openai.api_base=}")
|
||||
print(msg)
|
||||
|
||||
async def get_response(msg):
|
||||
def get_response(msg):
|
||||
data = msg.dict()
|
||||
data["streaming"] = True
|
||||
data.pop("stream")
|
||||
|
||||
try:
|
||||
response = openai.ChatCompletion.create(**data)
|
||||
if msg.stream:
|
||||
for chunk in response.choices[0].message.content:
|
||||
print(chunk)
|
||||
yield chunk
|
||||
for data in response:
|
||||
if choices := data.choices:
|
||||
if chunk := choices[0].get("delta", {}).get("content"):
|
||||
print(chunk, end="", flush=True)
|
||||
yield chunk
|
||||
else:
|
||||
answer = ""
|
||||
for chunk in response.choices[0].message.content:
|
||||
answer += chunk
|
||||
print(answer)
|
||||
yield(answer)
|
||||
if response.choices:
|
||||
answer = response.choices[0].message.content
|
||||
print(answer)
|
||||
yield(answer)
|
||||
except Exception as e:
|
||||
print(type(e))
|
||||
logger.error(e)
|
||||
|
||||
@ -69,6 +69,7 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
|
||||
"content": "虎头虎脑"}]]
|
||||
),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
|
||||
):
|
||||
if search_engine_name not in SEARCH_ENGINES.keys():
|
||||
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
|
||||
@ -82,16 +83,17 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
|
||||
search_engine_name: str,
|
||||
top_k: int,
|
||||
history: Optional[List[History]],
|
||||
model_name: str = LLM_MODEL,
|
||||
) -> AsyncIterable[str]:
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
model = ChatOpenAI(
|
||||
streaming=True,
|
||||
verbose=True,
|
||||
callbacks=[callback],
|
||||
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
|
||||
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
model_name=LLM_MODEL,
|
||||
openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
|
||||
openai_api_key=llm_model_dict[model_name]["api_key"],
|
||||
openai_api_base=llm_model_dict[model_name]["api_base_url"],
|
||||
model_name=model_name,
|
||||
openai_proxy=llm_model_dict[model_name].get("openai_proxy")
|
||||
)
|
||||
|
||||
docs = lookup_search_engine(query, search_engine_name, top_k)
|
||||
@ -129,5 +131,5 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
|
||||
ensure_ascii=False)
|
||||
await task
|
||||
|
||||
return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history),
|
||||
return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history, model_name),
|
||||
media_type="text/event-stream")
|
||||
|
||||
@ -3,10 +3,14 @@ from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs.model_config import SQLALCHEMY_DATABASE_URI
|
||||
import json
|
||||
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URI)
|
||||
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URI,
|
||||
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
|
||||
)
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
@ -9,9 +9,9 @@ class KnowledgeBaseModel(Base):
|
||||
"""
|
||||
__tablename__ = 'knowledge_base'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True, comment='知识库ID')
|
||||
kb_name = Column(String, comment='知识库名称')
|
||||
vs_type = Column(String, comment='嵌入模型类型')
|
||||
embed_model = Column(String, comment='嵌入模型名称')
|
||||
kb_name = Column(String(50), comment='知识库名称')
|
||||
vs_type = Column(String(50), comment='向量库类型')
|
||||
embed_model = Column(String(50), comment='嵌入模型名称')
|
||||
file_count = Column(Integer, default=0, comment='文件数量')
|
||||
create_time = Column(DateTime, default=func.now(), comment='创建时间')
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime, func
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, func
|
||||
|
||||
from server.db.base import Base
|
||||
|
||||
@ -9,13 +9,32 @@ class KnowledgeFileModel(Base):
|
||||
"""
|
||||
__tablename__ = 'knowledge_file'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True, comment='知识文件ID')
|
||||
file_name = Column(String, comment='文件名')
|
||||
file_ext = Column(String, comment='文件扩展名')
|
||||
kb_name = Column(String, comment='所属知识库名称')
|
||||
document_loader_name = Column(String, comment='文档加载器名称')
|
||||
text_splitter_name = Column(String, comment='文本分割器名称')
|
||||
file_name = Column(String(255), comment='文件名')
|
||||
file_ext = Column(String(10), comment='文件扩展名')
|
||||
kb_name = Column(String(50), comment='所属知识库名称')
|
||||
document_loader_name = Column(String(50), comment='文档加载器名称')
|
||||
text_splitter_name = Column(String(50), comment='文本分割器名称')
|
||||
file_version = Column(Integer, default=1, comment='文件版本')
|
||||
file_mtime = Column(Float, default=0.0, comment="文件修改时间")
|
||||
file_size = Column(Integer, default=0, comment="文件大小")
|
||||
custom_docs = Column(Boolean, default=False, comment="是否自定义docs")
|
||||
docs_count = Column(Integer, default=0, comment="切分文档数量")
|
||||
create_time = Column(DateTime, default=func.now(), comment='创建时间')
|
||||
|
||||
def __repr__(self):
|
||||
return f"<KnowledgeFile(id='{self.id}', file_name='{self.file_name}', file_ext='{self.file_ext}', kb_name='{self.kb_name}', document_loader_name='{self.document_loader_name}', text_splitter_name='{self.text_splitter_name}', file_version='{self.file_version}', create_time='{self.create_time}')>"
|
||||
|
||||
|
||||
class FileDocModel(Base):
|
||||
"""
|
||||
文件-向量库文档模型
|
||||
"""
|
||||
__tablename__ = 'file_doc'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True, comment='ID')
|
||||
kb_name = Column(String(50), comment='知识库名称')
|
||||
file_name = Column(String(255), comment='文件名称')
|
||||
doc_id = Column(String(50), comment="向量库文档ID")
|
||||
meta_data = Column(JSON, default={})
|
||||
|
||||
def __repr__(self):
|
||||
return f"<FileDoc(id='{self.id}', kb_name='{self.kb_name}', file_name='{self.file_name}', doc_id='{self.doc_id}', metadata='{self.metadata}')>"
|
||||
|
||||
@ -1,24 +1,101 @@
|
||||
from server.db.models.knowledge_base_model import KnowledgeBaseModel
|
||||
from server.db.models.knowledge_file_model import KnowledgeFileModel
|
||||
from server.db.models.knowledge_file_model import KnowledgeFileModel, FileDocModel
|
||||
from server.db.session import with_session
|
||||
from server.knowledge_base.utils import KnowledgeFile
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
@with_session
|
||||
def list_docs_from_db(session, kb_name):
|
||||
def list_docs_from_db(session,
|
||||
kb_name: str,
|
||||
file_name: str = None,
|
||||
metadata: Dict = {},
|
||||
) -> List[Dict]:
|
||||
'''
|
||||
列出某知识库某文件对应的所有Document。
|
||||
返回形式:[{"id": str, "metadata": dict}, ...]
|
||||
'''
|
||||
docs = session.query(FileDocModel).filter_by(kb_name=kb_name)
|
||||
if file_name:
|
||||
docs = docs.filter_by(file_name=file_name)
|
||||
for k, v in metadata.items():
|
||||
docs = docs.filter(FileDocModel.meta_data[k].as_string()==str(v))
|
||||
|
||||
return [{"id": x.doc_id, "metadata": x.metadata} for x in docs.all()]
|
||||
|
||||
|
||||
@with_session
|
||||
def delete_docs_from_db(session,
|
||||
kb_name: str,
|
||||
file_name: str = None,
|
||||
) -> List[Dict]:
|
||||
'''
|
||||
删除某知识库某文件对应的所有Document,并返回被删除的Document。
|
||||
返回形式:[{"id": str, "metadata": dict}, ...]
|
||||
'''
|
||||
docs = list_docs_from_db(kb_name=kb_name, file_name=file_name)
|
||||
query = session.query(FileDocModel).filter_by(kb_name=kb_name)
|
||||
if file_name:
|
||||
query = query.filter_by(file_name=file_name)
|
||||
query.delete()
|
||||
session.commit()
|
||||
return docs
|
||||
|
||||
|
||||
@with_session
|
||||
def add_docs_to_db(session,
|
||||
kb_name: str,
|
||||
file_name: str,
|
||||
doc_infos: List[Dict]):
|
||||
'''
|
||||
将某知识库某文件对应的所有Document信息添加到数据库。
|
||||
doc_infos形式:[{"id": str, "metadata": dict}, ...]
|
||||
'''
|
||||
for d in doc_infos:
|
||||
obj = FileDocModel(
|
||||
kb_name=kb_name,
|
||||
file_name=file_name,
|
||||
doc_id=d["id"],
|
||||
meta_data=d["metadata"],
|
||||
)
|
||||
session.add(obj)
|
||||
return True
|
||||
|
||||
|
||||
@with_session
|
||||
def count_files_from_db(session, kb_name: str) -> int:
|
||||
return session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).count()
|
||||
|
||||
|
||||
@with_session
|
||||
def list_files_from_db(session, kb_name):
|
||||
files = session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).all()
|
||||
docs = [f.file_name for f in files]
|
||||
return docs
|
||||
|
||||
|
||||
@with_session
|
||||
def add_doc_to_db(session, kb_file: KnowledgeFile):
|
||||
def add_file_to_db(session,
|
||||
kb_file: KnowledgeFile,
|
||||
docs_count: int = 0,
|
||||
custom_docs: bool = False,
|
||||
doc_infos: List[str] = [], # 形式:[{"id": str, "metadata": dict}, ...]
|
||||
):
|
||||
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
|
||||
if kb:
|
||||
# 如果已经存在该文件,则更新文件版本号
|
||||
existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename,
|
||||
kb_name=kb_file.kb_name).first()
|
||||
# 如果已经存在该文件,则更新文件信息与版本号
|
||||
existing_file: KnowledgeFileModel = (session.query(KnowledgeFileModel)
|
||||
.filter_by(file_name=kb_file.filename,
|
||||
kb_name=kb_file.kb_name)
|
||||
.first())
|
||||
mtime = kb_file.get_mtime()
|
||||
size = kb_file.get_size()
|
||||
|
||||
if existing_file:
|
||||
existing_file.file_mtime = mtime
|
||||
existing_file.file_size = size
|
||||
existing_file.docs_count = docs_count
|
||||
existing_file.custom_docs = custom_docs
|
||||
existing_file.file_version += 1
|
||||
# 否则,添加新文件
|
||||
else:
|
||||
@ -28,9 +105,14 @@ def add_doc_to_db(session, kb_file: KnowledgeFile):
|
||||
kb_name=kb_file.kb_name,
|
||||
document_loader_name=kb_file.document_loader_name,
|
||||
text_splitter_name=kb_file.text_splitter_name or "SpacyTextSplitter",
|
||||
file_mtime=mtime,
|
||||
file_size=size,
|
||||
docs_count = docs_count,
|
||||
custom_docs=custom_docs,
|
||||
)
|
||||
kb.file_count += 1
|
||||
session.add(new_file)
|
||||
add_docs_to_db(kb_name=kb_file.kb_name, file_name=kb_file.filename, doc_infos=doc_infos)
|
||||
return True
|
||||
|
||||
|
||||
@ -40,6 +122,7 @@ def delete_file_from_db(session, kb_file: KnowledgeFile):
|
||||
kb_name=kb_file.kb_name).first()
|
||||
if existing_file:
|
||||
session.delete(existing_file)
|
||||
delete_docs_from_db(kb_name=kb_file.kb_name, file_name=kb_file.filename)
|
||||
session.commit()
|
||||
|
||||
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
|
||||
@ -52,7 +135,7 @@ def delete_file_from_db(session, kb_file: KnowledgeFile):
|
||||
@with_session
|
||||
def delete_files_from_db(session, knowledge_base_name: str):
|
||||
session.query(KnowledgeFileModel).filter_by(kb_name=knowledge_base_name).delete()
|
||||
|
||||
session.query(FileDocModel).filter_by(kb_name=knowledge_base_name).delete()
|
||||
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=knowledge_base_name).first()
|
||||
if kb:
|
||||
kb.file_count = 0
|
||||
@ -62,7 +145,7 @@ def delete_files_from_db(session, knowledge_base_name: str):
|
||||
|
||||
|
||||
@with_session
|
||||
def doc_exists(session, kb_file: KnowledgeFile):
|
||||
def file_exists_in_db(session, kb_file: KnowledgeFile):
|
||||
existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename,
|
||||
kb_name=kb_file.kb_name).first()
|
||||
return True if existing_file else False
|
||||
@ -82,6 +165,10 @@ def get_file_detail(session, kb_name: str, filename: str) -> dict:
|
||||
"document_loader": file.document_loader_name,
|
||||
"text_splitter": file.text_splitter_name,
|
||||
"create_time": file.create_time,
|
||||
"file_mtime": file.file_mtime,
|
||||
"file_size": file.file_size,
|
||||
"custom_docs": file.custom_docs,
|
||||
"docs_count": file.docs_count,
|
||||
}
|
||||
else:
|
||||
return {}
|
||||
|
||||
@ -3,7 +3,7 @@ import urllib
|
||||
from fastapi import File, Form, Body, Query, UploadFile
|
||||
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
from server.utils import BaseResponse, ListResponse
|
||||
from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile
|
||||
from server.knowledge_base.utils import validate_kb_name, list_files_from_folder, KnowledgeFile
|
||||
from fastapi.responses import StreamingResponse, FileResponse
|
||||
import json
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
@ -29,7 +29,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
|
||||
return data
|
||||
|
||||
|
||||
async def list_docs(
|
||||
async def list_files(
|
||||
knowledge_base_name: str
|
||||
) -> ListResponse:
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
@ -40,7 +40,7 @@ async def list_docs(
|
||||
if kb is None:
|
||||
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
|
||||
else:
|
||||
all_doc_names = kb.list_docs()
|
||||
all_doc_names = kb.list_files()
|
||||
return ListResponse(data=all_doc_names)
|
||||
|
||||
|
||||
@ -183,14 +183,14 @@ async def recreate_vector_store(
|
||||
set allow_empty_kb to True to apply this to an empty knowledge base that is not in info.db or has no documents.
|
||||
'''
|
||||
|
||||
async def output():
|
||||
def output():
|
||||
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
|
||||
if not kb.exists() and not allow_empty_kb:
|
||||
yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
|
||||
else:
|
||||
kb.create_kb()
|
||||
kb.clear_vs()
|
||||
docs = list_docs_from_folder(knowledge_base_name)
|
||||
docs = list_files_from_folder(knowledge_base_name)
|
||||
for i, doc in enumerate(docs):
|
||||
try:
|
||||
kb_file = KnowledgeFile(doc, knowledge_base_name)
|
||||
|
||||
@ -1,25 +1,32 @@
|
||||
import operator
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.docstore.document import Document
|
||||
from sklearn.preprocessing import normalize
|
||||
|
||||
from server.db.repository.knowledge_base_repository import (
|
||||
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
|
||||
load_kb_from_db, get_kb_detail,
|
||||
)
|
||||
from server.db.repository.knowledge_file_repository import (
|
||||
add_doc_to_db, delete_file_from_db, delete_files_from_db, doc_exists,
|
||||
list_docs_from_db, get_file_detail, delete_file_from_db
|
||||
add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db,
|
||||
count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
|
||||
list_docs_from_db,
|
||||
)
|
||||
|
||||
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||
EMBEDDING_DEVICE, EMBEDDING_MODEL)
|
||||
EMBEDDING_MODEL)
|
||||
from server.knowledge_base.utils import (
|
||||
get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
|
||||
list_kbs_from_folder, list_docs_from_folder,
|
||||
list_kbs_from_folder, list_files_from_folder,
|
||||
)
|
||||
from server.utils import embedding_device
|
||||
from typing import List, Union, Dict
|
||||
from typing import List, Union, Dict, Optional
|
||||
|
||||
|
||||
class SupportedVSType:
|
||||
@ -41,7 +48,7 @@ class KBService(ABC):
|
||||
self.doc_path = get_doc_path(self.kb_name)
|
||||
self.do_init()
|
||||
|
||||
def _load_embeddings(self, embed_device: str = EMBEDDING_DEVICE) -> Embeddings:
|
||||
def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
|
||||
return load_embeddings(self.embed_model, embed_device)
|
||||
|
||||
def create_kb(self):
|
||||
@ -62,7 +69,6 @@ class KBService(ABC):
|
||||
status = delete_files_from_db(self.kb_name)
|
||||
return status
|
||||
|
||||
|
||||
def drop_kb(self):
|
||||
"""
|
||||
删除知识库
|
||||
@ -71,16 +77,24 @@ class KBService(ABC):
|
||||
status = delete_kb_from_db(self.kb_name)
|
||||
return status
|
||||
|
||||
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
|
||||
def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
|
||||
"""
|
||||
向知识库添加文件
|
||||
如果指定了docs,则不再将文本向量化,并将数据库对应条目标为custom_docs=True
|
||||
"""
|
||||
docs = kb_file.file2text()
|
||||
if docs:
|
||||
custom_docs = True
|
||||
else:
|
||||
docs = kb_file.file2text()
|
||||
custom_docs = False
|
||||
|
||||
if docs:
|
||||
self.delete_doc(kb_file)
|
||||
embeddings = self._load_embeddings()
|
||||
self.do_add_doc(docs, embeddings, **kwargs)
|
||||
status = add_doc_to_db(kb_file)
|
||||
doc_infos = self.do_add_doc(docs, **kwargs)
|
||||
status = add_file_to_db(kb_file,
|
||||
custom_docs=custom_docs,
|
||||
docs_count=len(docs),
|
||||
doc_infos=doc_infos)
|
||||
else:
|
||||
status = False
|
||||
return status
|
||||
@ -95,20 +109,24 @@ class KBService(ABC):
|
||||
os.remove(kb_file.filepath)
|
||||
return status
|
||||
|
||||
def update_doc(self, kb_file: KnowledgeFile, **kwargs):
|
||||
def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
|
||||
"""
|
||||
使用content中的文件更新向量库
|
||||
如果指定了docs,则使用自定义docs,并将数据库对应条目标为custom_docs=True
|
||||
"""
|
||||
if os.path.exists(kb_file.filepath):
|
||||
self.delete_doc(kb_file, **kwargs)
|
||||
return self.add_doc(kb_file, **kwargs)
|
||||
|
||||
return self.add_doc(kb_file, docs=docs, **kwargs)
|
||||
|
||||
def exist_doc(self, file_name: str):
|
||||
return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
|
||||
return file_exists_in_db(KnowledgeFile(knowledge_base_name=self.kb_name,
|
||||
filename=file_name))
|
||||
|
||||
def list_docs(self):
|
||||
return list_docs_from_db(self.kb_name)
|
||||
def list_files(self):
|
||||
return list_files_from_db(self.kb_name)
|
||||
|
||||
def count_files(self):
|
||||
return count_files_from_db(self.kb_name)
|
||||
|
||||
def search_docs(self,
|
||||
query: str,
|
||||
@ -119,6 +137,18 @@ class KBService(ABC):
|
||||
docs = self.do_search(query, top_k, score_threshold, embeddings)
|
||||
return docs
|
||||
|
||||
# TODO: milvus/pg需要实现该方法
|
||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
||||
return None
|
||||
|
||||
def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[Document]:
|
||||
'''
|
||||
通过file_name或metadata检索Document
|
||||
'''
|
||||
doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
|
||||
docs = [self.get_doc_by_id(x["id"]) for x in doc_infos]
|
||||
return docs
|
||||
|
||||
@abstractmethod
|
||||
def do_create_kb(self):
|
||||
"""
|
||||
@ -168,8 +198,7 @@ class KBService(ABC):
|
||||
@abstractmethod
|
||||
def do_add_doc(self,
|
||||
docs: List[Document],
|
||||
embeddings: Embeddings,
|
||||
):
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
向知识库添加文档,由子类实现自己的逻辑
|
||||
"""
|
||||
@ -208,8 +237,9 @@ class KBServiceFactory:
|
||||
return PGKBService(kb_name, embed_model=embed_model)
|
||||
elif SupportedVSType.MILVUS == vector_store_type:
|
||||
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||
return MilvusKBService(kb_name, embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
|
||||
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
|
||||
return MilvusKBService(kb_name,
|
||||
embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
|
||||
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
|
||||
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
|
||||
return DefaultKBService(kb_name)
|
||||
|
||||
@ -217,7 +247,7 @@ class KBServiceFactory:
|
||||
def get_service_by_name(kb_name: str
|
||||
) -> KBService:
|
||||
_, vs_type, embed_model = load_kb_from_db(kb_name)
|
||||
if vs_type is None and os.path.isdir(get_kb_path(kb_name)): # faiss knowledge base not in db
|
||||
if vs_type is None and os.path.isdir(get_kb_path(kb_name)): # faiss knowledge base not in db
|
||||
vs_type = "faiss"
|
||||
return KBServiceFactory.get_service(kb_name, vs_type, embed_model)
|
||||
|
||||
@ -256,29 +286,30 @@ def get_kb_details() -> List[Dict]:
|
||||
for i, v in enumerate(result.values()):
|
||||
v['No'] = i + 1
|
||||
data.append(v)
|
||||
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def get_kb_doc_details(kb_name: str) -> List[Dict]:
|
||||
def get_kb_file_details(kb_name: str) -> List[Dict]:
|
||||
kb = KBServiceFactory.get_service_by_name(kb_name)
|
||||
docs_in_folder = list_docs_from_folder(kb_name)
|
||||
docs_in_db = kb.list_docs()
|
||||
files_in_folder = list_files_from_folder(kb_name)
|
||||
files_in_db = kb.list_files()
|
||||
result = {}
|
||||
|
||||
for doc in docs_in_folder:
|
||||
for doc in files_in_folder:
|
||||
result[doc] = {
|
||||
"kb_name": kb_name,
|
||||
"file_name": doc,
|
||||
"file_ext": os.path.splitext(doc)[-1],
|
||||
"file_version": 0,
|
||||
"document_loader": "",
|
||||
"docs_count": 0,
|
||||
"text_splitter": "",
|
||||
"create_time": None,
|
||||
"in_folder": True,
|
||||
"in_db": False,
|
||||
}
|
||||
for doc in docs_in_db:
|
||||
for doc in files_in_db:
|
||||
doc_detail = get_file_detail(kb_name, doc)
|
||||
if doc_detail:
|
||||
doc_detail["in_db"] = True
|
||||
@ -292,5 +323,39 @@ def get_kb_doc_details(kb_name: str) -> List[Dict]:
|
||||
for i, v in enumerate(result.values()):
|
||||
v['No'] = i + 1
|
||||
data.append(v)
|
||||
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class EmbeddingsFunAdapter(Embeddings):
|
||||
|
||||
def __init__(self, embeddings: Embeddings):
|
||||
self.embeddings = embeddings
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
return normalize(self.embeddings.embed_documents(texts))
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
query_embed = self.embeddings.embed_query(text)
|
||||
query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组
|
||||
normalized_query_embed = normalize(query_embed_2d)
|
||||
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
|
||||
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
return await normalize(self.embeddings.aembed_documents(texts))
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
return await normalize(self.embeddings.aembed_query(text))
|
||||
|
||||
|
||||
def score_threshold_process(score_threshold, k, docs):
|
||||
if score_threshold is not None:
|
||||
cmp = (
|
||||
operator.le
|
||||
)
|
||||
docs = [
|
||||
(doc, similarity)
|
||||
for doc, similarity in docs
|
||||
if cmp(similarity, score_threshold)
|
||||
]
|
||||
return docs[:k]
|
||||
|
||||
@ -13,7 +13,7 @@ class DefaultKBService(KBService):
|
||||
def do_drop_kb(self):
|
||||
pass
|
||||
|
||||
def do_add_doc(self, docs: List[Document], embeddings: Embeddings):
|
||||
def do_add_doc(self, docs: List[Document]):
|
||||
pass
|
||||
|
||||
def do_clear_vs(self):
|
||||
|
||||
@ -5,7 +5,6 @@ from configs.model_config import (
|
||||
KB_ROOT_PATH,
|
||||
CACHED_VS_NUM,
|
||||
EMBEDDING_MODEL,
|
||||
EMBEDDING_DEVICE,
|
||||
SCORE_THRESHOLD
|
||||
)
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||
@ -13,40 +12,22 @@ from functools import lru_cache
|
||||
from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
|
||||
from langchain.vectorstores import FAISS
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings,HuggingFaceBgeEmbeddings
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from typing import List
|
||||
from typing import List, Dict, Optional
|
||||
from langchain.docstore.document import Document
|
||||
from server.utils import torch_gc
|
||||
|
||||
|
||||
# make HuggingFaceEmbeddings hashable
|
||||
def _embeddings_hash(self):
|
||||
if isinstance(self, HuggingFaceEmbeddings):
|
||||
return hash(self.model_name)
|
||||
elif isinstance(self, HuggingFaceBgeEmbeddings):
|
||||
return hash(self.model_name)
|
||||
elif isinstance(self, OpenAIEmbeddings):
|
||||
return hash(self.model)
|
||||
|
||||
HuggingFaceEmbeddings.__hash__ = _embeddings_hash
|
||||
OpenAIEmbeddings.__hash__ = _embeddings_hash
|
||||
HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash
|
||||
|
||||
_VECTOR_STORE_TICKS = {}
|
||||
from server.utils import torch_gc, embedding_device
|
||||
|
||||
|
||||
_VECTOR_STORE_TICKS = {}
|
||||
|
||||
|
||||
@lru_cache(CACHED_VS_NUM)
|
||||
def load_vector_store(
|
||||
def load_faiss_vector_store(
|
||||
knowledge_base_name: str,
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
embed_device: str = EMBEDDING_DEVICE,
|
||||
embed_device: str = embedding_device(),
|
||||
embeddings: Embeddings = None,
|
||||
tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed.
|
||||
):
|
||||
) -> FAISS:
|
||||
print(f"loading vector store in '{knowledge_base_name}'.")
|
||||
vs_path = get_vs_path(knowledge_base_name)
|
||||
if embeddings is None:
|
||||
@ -86,22 +67,39 @@ class FaissKBService(KBService):
|
||||
def vs_type(self) -> str:
|
||||
return SupportedVSType.FAISS
|
||||
|
||||
@staticmethod
|
||||
def get_vs_path(knowledge_base_name: str):
|
||||
return os.path.join(FaissKBService.get_kb_path(knowledge_base_name), "vector_store")
|
||||
def get_vs_path(self):
|
||||
return os.path.join(self.get_kb_path(), "vector_store")
|
||||
|
||||
@staticmethod
|
||||
def get_kb_path(knowledge_base_name: str):
|
||||
return os.path.join(KB_ROOT_PATH, knowledge_base_name)
|
||||
def get_kb_path(self):
|
||||
return os.path.join(KB_ROOT_PATH, self.kb_name)
|
||||
|
||||
def load_vector_store(self) -> FAISS:
|
||||
return load_faiss_vector_store(
|
||||
knowledge_base_name=self.kb_name,
|
||||
embed_model=self.embed_model,
|
||||
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0),
|
||||
)
|
||||
|
||||
def save_vector_store(self, vector_store: FAISS = None):
|
||||
vector_store = vector_store or self.load_vector_store()
|
||||
vector_store.save_local(self.vs_path)
|
||||
return vector_store
|
||||
|
||||
def refresh_vs_cache(self):
|
||||
refresh_vs_cache(self.kb_name)
|
||||
|
||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
|
||||
vector_store = self.load_vector_store()
|
||||
return vector_store.docstore._dict.get(id)
|
||||
|
||||
def do_init(self):
|
||||
self.kb_path = FaissKBService.get_kb_path(self.kb_name)
|
||||
self.vs_path = FaissKBService.get_vs_path(self.kb_name)
|
||||
self.kb_path = self.get_kb_path()
|
||||
self.vs_path = self.get_vs_path()
|
||||
|
||||
def do_create_kb(self):
|
||||
if not os.path.exists(self.vs_path):
|
||||
os.makedirs(self.vs_path)
|
||||
load_vector_store(self.kb_name)
|
||||
self.load_vector_store()
|
||||
|
||||
def do_drop_kb(self):
|
||||
self.clear_vs()
|
||||
@ -113,33 +111,27 @@ class FaissKBService(KBService):
|
||||
score_threshold: float = SCORE_THRESHOLD,
|
||||
embeddings: Embeddings = None,
|
||||
) -> List[Document]:
|
||||
search_index = load_vector_store(self.kb_name,
|
||||
embeddings=embeddings,
|
||||
tick=_VECTOR_STORE_TICKS.get(self.kb_name))
|
||||
search_index = self.load_vector_store()
|
||||
docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
|
||||
return docs
|
||||
|
||||
def do_add_doc(self,
|
||||
docs: List[Document],
|
||||
embeddings: Embeddings,
|
||||
**kwargs,
|
||||
):
|
||||
vector_store = load_vector_store(self.kb_name,
|
||||
embeddings=embeddings,
|
||||
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
|
||||
vector_store.add_documents(docs)
|
||||
) -> List[Dict]:
|
||||
vector_store = self.load_vector_store()
|
||||
ids = vector_store.add_documents(docs)
|
||||
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
|
||||
torch_gc()
|
||||
if not kwargs.get("not_refresh_vs_cache"):
|
||||
vector_store.save_local(self.vs_path)
|
||||
refresh_vs_cache(self.kb_name)
|
||||
self.refresh_vs_cache()
|
||||
return doc_infos
|
||||
|
||||
def do_delete_doc(self,
|
||||
kb_file: KnowledgeFile,
|
||||
**kwargs):
|
||||
embeddings = self._load_embeddings()
|
||||
vector_store = load_vector_store(self.kb_name,
|
||||
embeddings=embeddings,
|
||||
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
|
||||
vector_store = self.load_vector_store()
|
||||
|
||||
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
|
||||
if len(ids) == 0:
|
||||
@ -148,14 +140,14 @@ class FaissKBService(KBService):
|
||||
vector_store.delete(ids)
|
||||
if not kwargs.get("not_refresh_vs_cache"):
|
||||
vector_store.save_local(self.vs_path)
|
||||
refresh_vs_cache(self.kb_name)
|
||||
self.refresh_vs_cache()
|
||||
|
||||
return True
|
||||
return vector_store
|
||||
|
||||
def do_clear_vs(self):
|
||||
shutil.rmtree(self.vs_path)
|
||||
os.makedirs(self.vs_path)
|
||||
refresh_vs_cache(self.kb_name)
|
||||
self.refresh_vs_cache()
|
||||
|
||||
def exist_doc(self, file_name: str):
|
||||
if super().exist_doc(file_name):
|
||||
@ -166,3 +158,11 @@ class FaissKBService(KBService):
|
||||
return "in_folder"
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
faissService = FaissKBService("test")
|
||||
faissService.add_doc(KnowledgeFile("README.md", "test"))
|
||||
faissService.delete_doc(KnowledgeFile("README.md", "test"))
|
||||
faissService.do_drop_kb()
|
||||
print(faissService.search_docs("如何启动api服务"))
|
||||
@ -1,12 +1,16 @@
from typing import List
from typing import List, Dict, Optional

import numpy as np
from faiss import normalize_L2
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import Milvus
from sklearn.preprocessing import normalize

from configs.model_config import SCORE_THRESHOLD, kbs_config

from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
score_threshold_process
from server.knowledge_base.utils import KnowledgeFile


@ -18,6 +22,14 @@ class MilvusKBService(KBService):
from pymilvus import Collection
return Collection(milvus_name)

def get_doc_by_id(self, id: str) -> Optional[Document]:
if self.milvus.col:
data_list = self.milvus.col.query(expr=f'pk == {id}', output_fields=["*"])
if len(data_list) > 0:
data = data_list[0]
text = data.pop("text")
return Document(page_content=text, metadata=data)

@staticmethod
def search(milvus_name, content, limit=3):
search_params = {
@ -36,38 +48,31 @@ class MilvusKBService(KBService):
def _load_milvus(self, embeddings: Embeddings = None):
if embeddings is None:
embeddings = self._load_embeddings()
self.milvus = Milvus(embedding_function=embeddings,
self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(embeddings),
collection_name=self.kb_name, connection_args=kbs_config.get("milvus"))

def do_init(self):
self._load_milvus()

def do_drop_kb(self):
self.milvus.col.drop()
if self.milvus.col:
self.milvus.col.drop()

def do_search(self, query: str, top_k: int,score_threshold: float, embeddings: Embeddings):
# todo: support score threshold
self._load_milvus(embeddings=embeddings)
return self.milvus.similarity_search_with_score(query, top_k)
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings))
return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))

def add_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
向知识库添加文件
"""
docs = kb_file.file2text()
self.milvus.add_documents(docs)
from server.db.repository.knowledge_file_repository import add_doc_to_db
status = add_doc_to_db(kb_file)
return status

def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs):
pass
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
ids = self.milvus.add_documents(docs)
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
return doc_infos

def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
filepath = kb_file.filepath.replace('\\', '\\\\')
delete_list = [item.get("pk") for item in
self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
self.milvus.col.delete(expr=f'pk in {delete_list}')
if self.milvus.col:
filepath = kb_file.filepath.replace('\\', '\\\\')
delete_list = [item.get("pk") for item in
self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
self.milvus.col.delete(expr=f'pk in {delete_list}')

def do_clear_vs(self):
if self.milvus.col:
@ -80,7 +85,9 @@ if __name__ == '__main__':

Base.metadata.create_all(bind=engine)
milvusService = MilvusKBService("test")
milvusService.add_doc(KnowledgeFile("README.md", "test"))
milvusService.delete_doc(KnowledgeFile("README.md", "test"))
milvusService.do_drop_kb()
print(milvusService.search_docs("测试"))
# milvusService.add_doc(KnowledgeFile("README.md", "test"))

print(milvusService.get_doc_by_id("444022434274215486"))
# milvusService.delete_doc(KnowledgeFile("README.md", "test"))
# milvusService.do_drop_kb()
# print(milvusService.search_docs("如何启动api服务"))

@ -1,25 +1,38 @@
from typing import List
import json
from typing import List, Dict, Optional

from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import PGVector
from langchain.vectorstores.pgvector import DistanceStrategy
from sqlalchemy import text

from configs.model_config import EMBEDDING_DEVICE, kbs_config
from server.knowledge_base.kb_service.base import SupportedVSType, KBService

from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
score_threshold_process
from server.knowledge_base.utils import load_embeddings, KnowledgeFile
from server.utils import embedding_device as get_embedding_device


class PGKBService(KBService):
pg_vector: PGVector

def _load_pg_vector(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None):
def _load_pg_vector(self, embedding_device: str = get_embedding_device(), embeddings: Embeddings = None):
_embeddings = embeddings
if _embeddings is None:
_embeddings = load_embeddings(self.embed_model, embedding_device)
self.pg_vector = PGVector(embedding_function=_embeddings,
self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(_embeddings),
collection_name=self.kb_name,
distance_strategy=DistanceStrategy.EUCLIDEAN,
connection_string=kbs_config.get("pg").get("connection_uri"))
def get_doc_by_id(self, id: str) -> Optional[Document]:
with self.pg_vector.connect() as connect:
stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id=:id")
results = [Document(page_content=row[0], metadata=row[1]) for row in
connect.execute(stmt, parameters={'id': id}).fetchall()]
if len(results) > 0:
return results[0]

def do_init(self):
self._load_pg_vector()
@ -44,22 +57,14 @@ class PGKBService(KBService):
connect.commit()

def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
# todo: support score threshold
self._load_pg_vector(embeddings=embeddings)
return self.pg_vector.similarity_search_with_score(query, top_k)
return score_threshold_process(score_threshold, top_k,
self.pg_vector.similarity_search_with_score(query, top_k))

def add_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
向知识库添加文件
"""
docs = kb_file.file2text()
self.pg_vector.add_documents(docs)
from server.db.repository.knowledge_file_repository import add_doc_to_db
status = add_doc_to_db(kb_file)
return status

def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs):
pass
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
ids = self.pg_vector.add_documents(docs)
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
return doc_infos

def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
with self.pg_vector.connect() as connect:
@ -77,10 +82,11 @@ class PGKBService(KBService):
if __name__ == '__main__':
from server.db.base import Base, engine

Base.metadata.create_all(bind=engine)
# Base.metadata.create_all(bind=engine)
pGKBService = PGKBService("test")
pGKBService.create_kb()
pGKBService.add_doc(KnowledgeFile("README.md", "test"))
pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
pGKBService.drop_kb()
print(pGKBService.search_docs("测试"))
# pGKBService.create_kb()
# pGKBService.add_doc(KnowledgeFile("README.md", "test"))
# pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
# pGKBService.drop_kb()
print(pGKBService.get_doc_by_id("f1e51390-3029-4a19-90dc-7118aaa25772"))
# print(pGKBService.search_docs("如何启动api服务"))

@ -1,10 +1,17 @@
from configs.model_config import EMBEDDING_MODEL, DEFAULT_VS_TYPE
from server.knowledge_base.utils import get_file_path, list_kbs_from_folder, list_docs_from_folder, KnowledgeFile
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_file_repository import add_doc_to_db
from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
list_files_from_folder, run_in_thread_pool,
files2docs_in_thread,
KnowledgeFile,)
from server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType
from server.db.repository.knowledge_file_repository import add_file_to_db
from server.db.base import Base, engine
import os
from typing import Literal, Callable, Any
from concurrent.futures import ThreadPoolExecutor
from typing import Literal, Any, List


pool = ThreadPoolExecutor(os.cpu_count())


def create_tables():
@ -16,13 +23,22 @@ def reset_tables():
create_tables()


def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]:
kb_files = []
for file in files:
try:
kb_file = KnowledgeFile(filename=file, knowledge_base_name=kb_name)
kb_files.append(kb_file)
except Exception as e:
print(f"{e},已跳过")
return kb_files


def folder2db(
kb_name: str,
mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"],
vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
callback_before: Callable = None,
callback_after: Callable = None,
):
'''
use existed files in local folder to populate database and/or vector store.
@ -36,70 +52,62 @@ def folder2db(
kb.create_kb()

if mode == "recreate_vs":
files_count = kb.count_files()
print(f"知识库 {kb_name} 中共有 {files_count} 个文档。\n即将清除向量库。")
kb.clear_vs()
docs = list_docs_from_folder(kb_name)
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
print(e)
files_count = kb.count_files()
print(f"清理后,知识库 {kb_name} 中共有 {files_count} 个文档。")

kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name))
for success, result in files2docs_in_thread(kb_files, pool=pool):
if success:
_, filename, docs = result
print(f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档")
kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
else:
print(result)

if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache()
elif mode == "fill_info_only":
docs = list_docs_from_folder(kb_name)
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
add_doc_to_db(kb_file)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
print(e)
files = list_files_from_folder(kb_name)
kb_files = file_to_kbfile(kb_name, files)

for kb_file in kb_files:
add_file_to_db(kb_file)
print(f"已将 {kb_name}/{kb_file.filename} 添加到数据库")
elif mode == "update_in_db":
docs = kb.list_docs()
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
print(e)
files = kb.list_files()
kb_files = file_to_kbfile(kb_name, files)

for kb_file in kb_files:
kb.update_doc(kb_file, not_refresh_vs_cache=True)

if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache()
elif mode == "increament":
db_docs = kb.list_docs()
folder_docs = list_docs_from_folder(kb_name)
docs = list(set(folder_docs) - set(db_docs))
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
print(e)
db_files = kb.list_files()
folder_files = list_files_from_folder(kb_name)
files = list(set(folder_files) - set(db_files))
kb_files = file_to_kbfile(kb_name, files)

for success, result in files2docs_in_thread(kb_files, pool=pool):
if success:
_, filename, docs = result
print(f"正在将 {kb_name}/{filename} 添加到向量库")
kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
else:
print(result)

if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache()
else:
raise ValueError(f"unspported migrate mode: {mode}")
print(f"unspported migrate mode: {mode}")


def recreate_all_vs(
@ -114,30 +122,34 @@ def recreate_all_vs(
folder2db(kb_name, "recreate_vs", vs_type, embed_mode, **kwargs)


def prune_db_docs(kb_name: str):
def prune_db_files(kb_name: str):
'''
delete docs in database that not existed in local folder.
it is used to delete database docs after user deleted some doc files in file browser
delete files in database that not existed in local folder.
it is used to delete database files after user deleted some doc files in file browser
'''
kb = KBServiceFactory.get_service_by_name(kb_name)
if kb.exists():
docs_in_db = kb.list_docs()
docs_in_folder = list_docs_from_folder(kb_name)
docs = list(set(docs_in_db) - set(docs_in_folder))
for doc in docs:
kb.delete_doc(KnowledgeFile(doc, kb_name))
return docs
files_in_db = kb.list_files()
files_in_folder = list_files_from_folder(kb_name)
files = list(set(files_in_db) - set(files_in_folder))
kb_files = file_to_kbfile(kb_name, files)
for kb_file in kb_files:
kb.delete_doc(kb_file, not_refresh_vs_cache=True)
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache()
return kb_files

def prune_folder_docs(kb_name: str):
def prune_folder_files(kb_name: str):
'''
delete doc files in local folder that not existed in database.
is is used to free local disk space by delete unused doc files.
'''
kb = KBServiceFactory.get_service_by_name(kb_name)
if kb.exists():
docs_in_db = kb.list_docs()
docs_in_folder = list_docs_from_folder(kb_name)
docs = list(set(docs_in_folder) - set(docs_in_db))
for doc in docs:
os.remove(get_file_path(kb_name, doc))
return docs
files_in_db = kb.list_files()
files_in_folder = list_files_from_folder(kb_name)
files = list(set(files_in_folder) - set(files_in_db))
for file in files:
os.remove(get_file_path(kb_name, file))
return files

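以下为调用上述迁移辅助函数的一个示意性示例(假设这些函数位于 server/knowledge_base/migrate.py,知识库名 samples 仅为占位,请替换为实际知识库名):

```python
# 示意代码:基于本地文件夹重建 "samples" 知识库(知识库名与模块路径均为假设)
from server.knowledge_base.migrate import create_tables, folder2db, prune_db_files

create_tables()                            # 确保数据库表已创建
folder2db("samples", mode="recreate_vs")   # 清空并依据本地文件夹重建向量库
prune_db_files("samples")                  # 删除数据库中已不存在于本地文件夹的文件记录
```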
@ -12,6 +12,26 @@ from configs.model_config import (
|
||||
from functools import lru_cache
|
||||
import importlib
|
||||
from text_splitter import zh_title_enhance
|
||||
import langchain.document_loaders
|
||||
from langchain.docstore.document import Document
|
||||
from pathlib import Path
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
|
||||
|
||||
|
||||
# make HuggingFaceEmbeddings hashable
|
||||
def _embeddings_hash(self):
|
||||
if isinstance(self, HuggingFaceEmbeddings):
|
||||
return hash(self.model_name)
|
||||
elif isinstance(self, HuggingFaceBgeEmbeddings):
|
||||
return hash(self.model_name)
|
||||
elif isinstance(self, OpenAIEmbeddings):
|
||||
return hash(self.model)
|
||||
|
||||
HuggingFaceEmbeddings.__hash__ = _embeddings_hash
|
||||
OpenAIEmbeddings.__hash__ = _embeddings_hash
|
||||
HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash
|
||||
|
||||
|
||||
def validate_kb_name(knowledge_base_id: str) -> bool:
|
||||
@ -20,27 +40,34 @@ def validate_kb_name(knowledge_base_id: str) -> bool:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_kb_path(knowledge_base_name: str):
|
||||
return os.path.join(KB_ROOT_PATH, knowledge_base_name)
|
||||
|
||||
|
||||
def get_doc_path(knowledge_base_name: str):
|
||||
return os.path.join(get_kb_path(knowledge_base_name), "content")
|
||||
|
||||
|
||||
def get_vs_path(knowledge_base_name: str):
|
||||
return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
|
||||
|
||||
|
||||
def get_file_path(knowledge_base_name: str, doc_name: str):
|
||||
return os.path.join(get_doc_path(knowledge_base_name), doc_name)
|
||||
|
||||
|
||||
def list_kbs_from_folder():
|
||||
return [f for f in os.listdir(KB_ROOT_PATH)
|
||||
if os.path.isdir(os.path.join(KB_ROOT_PATH, f))]
|
||||
|
||||
def list_docs_from_folder(kb_name: str):
|
||||
|
||||
def list_files_from_folder(kb_name: str):
|
||||
doc_path = get_doc_path(kb_name)
|
||||
return [file for file in os.listdir(doc_path)
|
||||
if os.path.isfile(os.path.join(doc_path, file))]
|
||||
|
||||
|
||||
@lru_cache(1)
|
||||
def load_embeddings(model: str, device: str):
|
||||
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
|
||||
@ -56,16 +83,90 @@ def load_embeddings(model: str, device: str):
|
||||
return embeddings
|
||||
|
||||
|
||||
|
||||
LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst',
|
||||
'.rtf', '.txt', '.xml',
|
||||
'.doc', '.docx', '.epub', '.odt', '.pdf',
|
||||
'.ppt', '.pptx', '.tsv'], # '.pdf', '.xlsx', '.csv'
|
||||
LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
|
||||
"UnstructuredMarkdownLoader": ['.md'],
|
||||
"CustomJSONLoader": [".json"],
|
||||
"CSVLoader": [".csv"],
|
||||
"PyPDFLoader": [".pdf"],
|
||||
"RapidOCRPDFLoader": [".pdf"],
|
||||
"RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
|
||||
"UnstructuredFileLoader": ['.eml', '.msg', '.rst',
|
||||
'.rtf', '.txt', '.xml',
|
||||
'.doc', '.docx', '.epub', '.odt',
|
||||
'.ppt', '.pptx', '.tsv'], # '.xlsx'
|
||||
}
|
||||
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
|
||||
|
||||
|
||||
class CustomJSONLoader(langchain.document_loaders.JSONLoader):
|
||||
'''
|
||||
langchain的JSONLoader需要jq,在win上使用不便,进行替代。
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: Union[str, Path],
|
||||
content_key: Optional[str] = None,
|
||||
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
|
||||
text_content: bool = True,
|
||||
json_lines: bool = False,
|
||||
):
|
||||
"""Initialize the JSONLoader.
|
||||
|
||||
Args:
|
||||
file_path (Union[str, Path]): The path to the JSON or JSON Lines file.
|
||||
content_key (str): The key to use to extract the content from the JSON if
|
||||
results to a list of objects (dict).
|
||||
metadata_func (Callable[Dict, Dict]): A function that takes in the JSON
|
||||
object extracted by the jq_schema and the default metadata and returns
|
||||
a dict of the updated metadata.
|
||||
text_content (bool): Boolean flag to indicate whether the content is in
|
||||
string format, default to True.
|
||||
json_lines (bool): Boolean flag to indicate whether the input is in
|
||||
JSON Lines format.
|
||||
"""
|
||||
self.file_path = Path(file_path).resolve()
|
||||
self._content_key = content_key
|
||||
self._metadata_func = metadata_func
|
||||
self._text_content = text_content
|
||||
self._json_lines = json_lines
|
||||
|
||||
# TODO: langchain's JSONLoader.load has an encoding bug that raises a gbk encoding error on Windows.
# This is a workaround for langchain==0.0.266. I have made a PR (#9785) to langchain; it should be deleted after langchain is upgraded.
|
||||
def load(self) -> List[Document]:
|
||||
"""Load and return documents from the JSON file."""
|
||||
docs: List[Document] = []
|
||||
if self._json_lines:
|
||||
with self.file_path.open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
self._parse(line, docs)
|
||||
else:
|
||||
self._parse(self.file_path.read_text(encoding="utf-8"), docs)
|
||||
return docs
|
||||
|
||||
def _parse(self, content: str, docs: List[Document]) -> None:
|
||||
"""Convert given content to documents."""
|
||||
data = json.loads(content)
|
||||
|
||||
# Perform some validation
|
||||
# This is not a perfect validation, but it should catch most cases
|
||||
# and prevent the user from getting a cryptic error later on.
|
||||
if self._content_key is not None:
|
||||
self._validate_content_key(data)
|
||||
|
||||
for i, sample in enumerate(data, len(docs) + 1):
|
||||
metadata = dict(
|
||||
source=str(self.file_path),
|
||||
seq_num=i,
|
||||
)
|
||||
text = self._get_text(sample=sample, metadata=metadata)
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
|
||||
|
||||
langchain.document_loaders.CustomJSONLoader = CustomJSONLoader
|
||||
|
||||
|
||||
def get_LoaderClass(file_extension):
|
||||
for LoaderClass, extensions in LOADER_DICT.items():
|
||||
if file_extension in extensions:
|
||||
@ -90,10 +191,16 @@ class KnowledgeFile:
|
||||
# TODO: 增加依据文件格式匹配text_splitter
|
||||
self.text_splitter_name = None
|
||||
|
||||
def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE):
|
||||
print(self.document_loader_name)
|
||||
def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE, refresh: bool = False):
|
||||
if self.docs is not None and not refresh:
|
||||
return self.docs
|
||||
|
||||
print(f"{self.document_loader_name} used for {self.filepath}")
|
||||
try:
|
||||
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
||||
if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
|
||||
document_loaders_module = importlib.import_module('document_loaders')
|
||||
else:
|
||||
document_loaders_module = importlib.import_module('langchain.document_loaders')
|
||||
DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
@ -101,6 +208,16 @@ class KnowledgeFile:
|
||||
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
|
||||
if self.document_loader_name == "UnstructuredFileLoader":
|
||||
loader = DocumentLoader(self.filepath, autodetect_encoding=True)
|
||||
elif self.document_loader_name == "CSVLoader":
|
||||
loader = DocumentLoader(self.filepath, encoding="utf-8")
|
||||
elif self.document_loader_name == "JSONLoader":
|
||||
loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False)
|
||||
elif self.document_loader_name == "CustomJSONLoader":
|
||||
loader = DocumentLoader(self.filepath, text_content=False)
|
||||
elif self.document_loader_name == "UnstructuredMarkdownLoader":
|
||||
loader = DocumentLoader(self.filepath, mode="elements")
|
||||
elif self.document_loader_name == "UnstructuredHTMLLoader":
|
||||
loader = DocumentLoader(self.filepath, mode="elements")
|
||||
else:
|
||||
loader = DocumentLoader(self.filepath)
|
||||
|
||||
@ -136,4 +253,63 @@ class KnowledgeFile:
|
||||
print(docs[0])
|
||||
if using_zh_title_enhance:
|
||||
docs = zh_title_enhance(docs)
|
||||
self.docs = docs
|
||||
return docs
|
||||
|
||||
def get_mtime(self):
|
||||
return os.path.getmtime(self.filepath)
|
||||
|
||||
def get_size(self):
|
||||
return os.path.getsize(self.filepath)
|
||||
|
||||
|
||||
def run_in_thread_pool(
|
||||
func: Callable,
|
||||
params: List[Dict] = [],
|
||||
pool: ThreadPoolExecutor = None,
|
||||
) -> Generator:
|
||||
'''
|
||||
在线程池中批量运行任务,并将运行结果以生成器的形式返回。
|
||||
请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。
|
||||
'''
|
||||
tasks = []
|
||||
if pool is None:
|
||||
pool = ThreadPoolExecutor()
|
||||
|
||||
for kwargs in params:
|
||||
thread = pool.submit(func, **kwargs)
|
||||
tasks.append(thread)
|
||||
|
||||
for obj in as_completed(tasks):
|
||||
yield obj.result()
|
||||
|
||||
|
||||
def files2docs_in_thread(
|
||||
files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
|
||||
pool: ThreadPoolExecutor = None,
|
||||
) -> Generator:
|
||||
'''
|
||||
利用多线程批量将文件转化成langchain Document.
|
||||
生成器返回值为{(kb_name, file_name): docs}
|
||||
'''
|
||||
def task(*, file: KnowledgeFile, **kwargs) -> Dict[Tuple[str, str], List[Document]]:
|
||||
try:
|
||||
return True, (file.kb_name, file.filename, file.file2text(**kwargs))
|
||||
except Exception as e:
|
||||
return False, e
|
||||
|
||||
kwargs_list = []
|
||||
for i, file in enumerate(files):
|
||||
kwargs = {}
|
||||
if isinstance(file, tuple) and len(file) >= 2:
|
||||
files[i] = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
|
||||
elif isinstance(file, dict):
|
||||
filename = file.pop("filename")
|
||||
kb_name = file.pop("kb_name")
|
||||
files[i] = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
|
||||
kwargs = file
|
||||
kwargs["file"] = file
|
||||
kwargs_list.append(kwargs)
|
||||
|
||||
for result in run_in_thread_pool(func=task, params=kwargs_list, pool=pool):
|
||||
yield result
|
||||
|
||||
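files2docs_in_thread 的一个最小使用示意(知识库名与文件名仅为示例,需保证文件已位于对应知识库的 content 目录下):

```python
# 示意代码:在线程池中批量将文件解析为 langchain Document
from server.knowledge_base.utils import KnowledgeFile, files2docs_in_thread

kb_files = [KnowledgeFile(filename="README.md", knowledge_base_name="samples")]

for success, result in files2docs_in_thread(kb_files):
    if success:
        kb_name, filename, docs = result
        print(f"{kb_name}/{filename} 共解析出 {len(docs)} 条文档")
    else:
        print(f"解析失败:{result}")
```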
@ -4,8 +4,8 @@ import sys
import os

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
from server.utils import MakeFastAPIOffline
from configs.model_config import llm_model_dict, LLM_MODEL, LOG_PATH, logger
from server.utils import MakeFastAPIOffline, set_httpx_timeout, llm_device


host_ip = "0.0.0.0"
@ -15,13 +15,6 @@ openai_api_port = 8888
base_url = "http://127.0.0.1:{}"


def set_httpx_timeout(timeout=60.0):
import httpx
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout


def create_controller_app(
dispatch_method="shortest_queue",
):
@ -41,7 +34,7 @@ def create_model_worker_app(
worker_address=base_url.format(model_worker_port),
controller_address=base_url.format(controller_port),
model_path=llm_model_dict[LLM_MODEL].get("local_model_path"),
device=LLM_DEVICE,
device=llm_device(),
gpus=None,
max_gpu_memory="20GiB",
load_8bit=False,

1
server/model_workers/__init__.py
Normal file
@ -0,0 +1 @@
from .zhipu import ChatGLMWorker
71
server/model_workers/base.py
Normal file
@ -0,0 +1,71 @@
|
||||
from configs.model_config import LOG_PATH
|
||||
import fastchat.constants
|
||||
fastchat.constants.LOGDIR = LOG_PATH
|
||||
from fastchat.serve.model_worker import BaseModelWorker
|
||||
import uuid
|
||||
import json
|
||||
import sys
|
||||
from pydantic import BaseModel
|
||||
import fastchat
|
||||
import threading
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
# 恢复被fastchat覆盖的标准输出
|
||||
sys.stdout = sys.__stdout__
|
||||
sys.stderr = sys.__stderr__
|
||||
|
||||
|
||||
class ApiModelOutMsg(BaseModel):
|
||||
error_code: int = 0
|
||||
text: str
|
||||
|
||||
class ApiModelWorker(BaseModelWorker):
|
||||
BASE_URL: str
|
||||
SUPPORT_MODELS: List
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_names: List[str],
|
||||
controller_addr: str,
|
||||
worker_addr: str,
|
||||
context_len: int = 2048,
|
||||
**kwargs,
|
||||
):
|
||||
kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
|
||||
kwargs.setdefault("model_path", "")
|
||||
kwargs.setdefault("limit_worker_concurrency", 5)
|
||||
super().__init__(model_names=model_names,
|
||||
controller_addr=controller_addr,
|
||||
worker_addr=worker_addr,
|
||||
**kwargs)
|
||||
self.context_len = context_len
|
||||
self.init_heart_beat()
|
||||
|
||||
def count_token(self, params):
|
||||
# TODO:需要完善
|
||||
print("count token")
|
||||
print(params)
|
||||
prompt = params["prompt"]
|
||||
return {"count": len(str(prompt)), "error_code": 0}
|
||||
|
||||
def generate_stream_gate(self, params):
|
||||
self.call_ct += 1
|
||||
|
||||
def generate_gate(self, params):
|
||||
for x in self.generate_stream_gate(params):
|
||||
pass
|
||||
return json.loads(x[:-1].decode())
|
||||
|
||||
def get_embeddings(self, params):
|
||||
print("embedding")
|
||||
print(params)
|
||||
|
||||
# workaround to make program exit with Ctrl+c
|
||||
# it should be deleted after pr is merged by fastchat
|
||||
def init_heart_beat(self):
|
||||
self.register_to_controller()
|
||||
self.heart_beat_thread = threading.Thread(
|
||||
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
|
||||
)
|
||||
self.heart_beat_thread.start()
|
||||
75
server/model_workers/zhipu.py
Normal file
@ -0,0 +1,75 @@
|
||||
import zhipuai
|
||||
from server.model_workers.base import ApiModelWorker
|
||||
from fastchat import conversation as conv
|
||||
import sys
|
||||
import json
|
||||
from typing import List, Literal
|
||||
|
||||
|
||||
class ChatGLMWorker(ApiModelWorker):
|
||||
BASE_URL = "https://open.bigmodel.cn/api/paas/v3/model-api"
|
||||
SUPPORT_MODELS = ["chatglm_pro", "chatglm_std", "chatglm_lite"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model_names: List[str] = ["chatglm-api"],
|
||||
version: Literal["chatglm_pro", "chatglm_std", "chatglm_lite"] = "chatglm_std",
|
||||
controller_addr: str,
|
||||
worker_addr: str,
|
||||
**kwargs,
|
||||
):
|
||||
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
|
||||
kwargs.setdefault("context_len", 32768)
|
||||
super().__init__(**kwargs)
|
||||
self.version = version
|
||||
|
||||
# 这里的是chatglm api的模板,其它API的conv_template需要定制
|
||||
self.conv = conv.Conversation(
|
||||
name="chatglm-api",
|
||||
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
|
||||
messages=[],
|
||||
roles=["Human", "Assistant"],
|
||||
sep="\n### ",
|
||||
stop_str="###",
|
||||
)
|
||||
|
||||
def generate_stream_gate(self, params):
|
||||
# TODO: 支持stream参数,维护request_id,传过来的prompt也有问题
|
||||
from server.utils import get_model_worker_config
|
||||
|
||||
super().generate_stream_gate(params)
|
||||
zhipuai.api_key = get_model_worker_config("chatglm-api").get("api_key")
|
||||
|
||||
response = zhipuai.model_api.sse_invoke(
|
||||
model=self.version,
|
||||
prompt=[{"role": "user", "content": params["prompt"]}],
|
||||
temperature=params.get("temperature"),
|
||||
top_p=params.get("top_p"),
|
||||
incremental=False,
|
||||
)
|
||||
for e in response.events():
|
||||
if e.event == "add":
|
||||
yield json.dumps({"error_code": 0, "text": e.data}, ensure_ascii=False).encode() + b"\0"
|
||||
# TODO: 更健壮的消息处理
|
||||
# elif e.event == "finish":
|
||||
# ...
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
print("embedding")
|
||||
print(params)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
from server.utils import MakeFastAPIOffline
|
||||
from fastchat.serve.model_worker import app
|
||||
|
||||
worker = ChatGLMWorker(
|
||||
controller_addr="http://127.0.0.1:20001",
|
||||
worker_addr="http://127.0.0.1:20003",
|
||||
)
|
||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
MakeFastAPIOffline(app)
|
||||
uvicorn.run(app, port=20003)
|
||||
123
server/utils.py
@ -1,16 +1,20 @@
|
||||
import pydantic
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
import torch
|
||||
from fastapi import FastAPI
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
from typing import Any, Optional
|
||||
from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE
|
||||
from configs.server_config import FSCHAT_MODEL_WORKERS
|
||||
import os
|
||||
from server import model_workers
|
||||
from typing import Literal, Optional, Any
|
||||
|
||||
|
||||
class BaseResponse(BaseModel):
|
||||
code: int = pydantic.Field(200, description="API status code")
|
||||
msg: str = pydantic.Field("success", description="API status message")
|
||||
data: Any = pydantic.Field(None, description="API data")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
@ -68,6 +72,7 @@ class ChatMessage(BaseModel):
|
||||
}
|
||||
|
||||
def torch_gc():
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
# with torch.cuda.device(DEVICE):
|
||||
torch.cuda.empty_cache()
|
||||
@ -186,3 +191,117 @@ def MakeFastAPIOffline(
|
||||
with_google_fonts=False,
|
||||
redoc_favicon_url=favicon,
|
||||
)
|
||||
|
||||
|
||||
# 从server_config中获取服务信息
|
||||
def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
|
||||
'''
|
||||
加载model worker的配置项。
|
||||
优先级:FSCHAT_MODEL_WORKERS[model_name] > llm_model_dict[model_name] > FSCHAT_MODEL_WORKERS["default"]
|
||||
'''
|
||||
from configs.server_config import FSCHAT_MODEL_WORKERS
|
||||
from configs.model_config import llm_model_dict
|
||||
|
||||
config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
|
||||
config.update(llm_model_dict.get(model_name, {}))
|
||||
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
|
||||
|
||||
# 如果没有设置有效的local_model_path,则认为是在线模型API
|
||||
if not os.path.isdir(config.get("local_model_path", "")):
|
||||
config["online_api"] = True
|
||||
if provider := config.get("provider"):
|
||||
try:
|
||||
config["worker_class"] = getattr(model_workers, provider)
|
||||
except Exception as e:
|
||||
print(f"在线模型 ‘{model_name}’ 的provider没有正确配置")
|
||||
|
||||
config["device"] = llm_device(config.get("device") or LLM_DEVICE)
|
||||
return config
|
||||
|
||||
|
||||
def get_all_model_worker_configs() -> dict:
|
||||
result = {}
|
||||
model_names = set(llm_model_dict.keys()) | set(FSCHAT_MODEL_WORKERS.keys())
|
||||
for name in model_names:
|
||||
if name != "default":
|
||||
result[name] = get_model_worker_config(name)
|
||||
return result
|
||||
|
||||
|
||||
def fschat_controller_address() -> str:
|
||||
from configs.server_config import FSCHAT_CONTROLLER
|
||||
|
||||
host = FSCHAT_CONTROLLER["host"]
|
||||
port = FSCHAT_CONTROLLER["port"]
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
|
||||
if model := get_model_worker_config(model_name):
|
||||
host = model["host"]
|
||||
port = model["port"]
|
||||
return f"http://{host}:{port}"
|
||||
return ""
|
||||
|
||||
|
||||
def fschat_openai_api_address() -> str:
|
||||
from configs.server_config import FSCHAT_OPENAI_API
|
||||
|
||||
host = FSCHAT_OPENAI_API["host"]
|
||||
port = FSCHAT_OPENAI_API["port"]
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def api_address() -> str:
|
||||
from configs.server_config import API_SERVER
|
||||
|
||||
host = API_SERVER["host"]
|
||||
port = API_SERVER["port"]
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def webui_address() -> str:
|
||||
from configs.server_config import WEBUI_SERVER
|
||||
|
||||
host = WEBUI_SERVER["host"]
|
||||
port = WEBUI_SERVER["port"]
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
|
||||
def set_httpx_timeout(timeout: float = None):
|
||||
'''
|
||||
设置httpx默认timeout。
|
||||
httpx默认timeout是5秒,在请求LLM回答时不够用。
|
||||
'''
|
||||
import httpx
|
||||
from configs.server_config import HTTPX_DEFAULT_TIMEOUT
|
||||
|
||||
timeout = timeout or HTTPX_DEFAULT_TIMEOUT
|
||||
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
|
||||
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
|
||||
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
|
||||
|
||||
|
||||
# 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch
|
||||
def detect_device() -> Literal["cuda", "mps", "cpu"]:
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
return "cuda"
|
||||
if torch.backends.mps.is_available():
|
||||
return "mps"
|
||||
except:
|
||||
pass
|
||||
return "cpu"
|
||||
|
||||
|
||||
def llm_device(device: str = LLM_DEVICE) -> Literal["cuda", "mps", "cpu"]:
|
||||
if device not in ["cuda", "mps", "cpu"]:
|
||||
device = detect_device()
|
||||
return device
|
||||
|
||||
|
||||
def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]:
|
||||
if device not in ["cuda", "mps", "cpu"]:
|
||||
device = detect_device()
|
||||
return device
|
||||
|
||||
517
startup.py
@ -1,59 +1,60 @@
|
||||
from multiprocessing import Process, Queue
|
||||
import asyncio
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
from multiprocessing import Process, Queue
|
||||
from pprint import pprint
|
||||
|
||||
# 设置numexpr最大线程数,默认为CPU核心数
|
||||
try:
|
||||
import numexpr
|
||||
|
||||
n_cores = numexpr.utils.detect_number_of_cores()
|
||||
os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores)
|
||||
except:
|
||||
pass
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \
|
||||
from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
|
||||
logger
|
||||
from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
|
||||
FSCHAT_OPENAI_API, fschat_controller_address, fschat_model_worker_address,
|
||||
fschat_openai_api_address, )
|
||||
from server.utils import MakeFastAPIOffline, FastAPI
|
||||
from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
|
||||
FSCHAT_OPENAI_API, )
|
||||
from server.utils import (fschat_controller_address, fschat_model_worker_address,
|
||||
fschat_openai_api_address, set_httpx_timeout,
|
||||
get_model_worker_config, get_all_model_worker_configs,
|
||||
MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
|
||||
import argparse
|
||||
from typing import Tuple, List
|
||||
from typing import Tuple, List, Dict
|
||||
from configs import VERSION
|
||||
|
||||
|
||||
def set_httpx_timeout(timeout=60.0):
|
||||
import httpx
|
||||
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
|
||||
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
|
||||
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
|
||||
|
||||
|
||||
def create_controller_app(
|
||||
dispatch_method: str,
|
||||
log_level: str = "INFO",
|
||||
) -> FastAPI:
|
||||
import fastchat.constants
|
||||
fastchat.constants.LOGDIR = LOG_PATH
|
||||
from fastchat.serve.controller import app, Controller
|
||||
from fastchat.serve.controller import app, Controller, logger
|
||||
logger.setLevel(log_level)
|
||||
|
||||
controller = Controller(dispatch_method)
|
||||
sys.modules["fastchat.serve.controller"].controller = controller
|
||||
|
||||
MakeFastAPIOffline(app)
|
||||
app.title = "FastChat Controller"
|
||||
app._controller = controller
|
||||
return app
|
||||
|
||||
|
||||
def create_model_worker_app(**kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
|
||||
def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
|
||||
import fastchat.constants
|
||||
fastchat.constants.LOGDIR = LOG_PATH
|
||||
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
|
||||
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger
|
||||
import argparse
|
||||
import threading
|
||||
import fastchat.serve.model_worker
|
||||
logger.setLevel(log_level)
|
||||
|
||||
# workaround to make program exit with Ctrl+c
|
||||
# it should be deleted after pr is merged by fastchat
|
||||
@ -99,53 +100,76 @@ def create_model_worker_app(**kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]
|
||||
)
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
|
||||
|
||||
gptq_config = GptqConfig(
|
||||
ckpt=args.gptq_ckpt or args.model_path,
|
||||
wbits=args.gptq_wbits,
|
||||
groupsize=args.gptq_groupsize,
|
||||
act_order=args.gptq_act_order,
|
||||
)
|
||||
awq_config = AWQConfig(
|
||||
ckpt=args.awq_ckpt or args.model_path,
|
||||
wbits=args.awq_wbits,
|
||||
groupsize=args.awq_groupsize,
|
||||
)
|
||||
# 在线模型API
|
||||
if worker_class := kwargs.get("worker_class"):
|
||||
worker = worker_class(model_names=args.model_names,
|
||||
controller_addr=args.controller_address,
|
||||
worker_addr=args.worker_address)
|
||||
# 本地模型
|
||||
else:
|
||||
# workaround to make program exit with Ctrl+c
|
||||
# it should be deleted after pr is merged by fastchat
|
||||
def _new_init_heart_beat(self):
|
||||
self.register_to_controller()
|
||||
self.heart_beat_thread = threading.Thread(
|
||||
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
|
||||
)
|
||||
self.heart_beat_thread.start()
|
||||
|
||||
worker = ModelWorker(
|
||||
controller_addr=args.controller_address,
|
||||
worker_addr=args.worker_address,
|
||||
worker_id=worker_id,
|
||||
model_path=args.model_path,
|
||||
model_names=args.model_names,
|
||||
limit_worker_concurrency=args.limit_worker_concurrency,
|
||||
no_register=args.no_register,
|
||||
device=args.device,
|
||||
num_gpus=args.num_gpus,
|
||||
max_gpu_memory=args.max_gpu_memory,
|
||||
load_8bit=args.load_8bit,
|
||||
cpu_offloading=args.cpu_offloading,
|
||||
gptq_config=gptq_config,
|
||||
awq_config=awq_config,
|
||||
stream_interval=args.stream_interval,
|
||||
conv_template=args.conv_template,
|
||||
)
|
||||
ModelWorker.init_heart_beat = _new_init_heart_beat
|
||||
|
||||
gptq_config = GptqConfig(
|
||||
ckpt=args.gptq_ckpt or args.model_path,
|
||||
wbits=args.gptq_wbits,
|
||||
groupsize=args.gptq_groupsize,
|
||||
act_order=args.gptq_act_order,
|
||||
)
|
||||
awq_config = AWQConfig(
|
||||
ckpt=args.awq_ckpt or args.model_path,
|
||||
wbits=args.awq_wbits,
|
||||
groupsize=args.awq_groupsize,
|
||||
)
|
||||
|
||||
worker = ModelWorker(
|
||||
controller_addr=args.controller_address,
|
||||
worker_addr=args.worker_address,
|
||||
worker_id=worker_id,
|
||||
model_path=args.model_path,
|
||||
model_names=args.model_names,
|
||||
limit_worker_concurrency=args.limit_worker_concurrency,
|
||||
no_register=args.no_register,
|
||||
device=args.device,
|
||||
num_gpus=args.num_gpus,
|
||||
max_gpu_memory=args.max_gpu_memory,
|
||||
load_8bit=args.load_8bit,
|
||||
cpu_offloading=args.cpu_offloading,
|
||||
gptq_config=gptq_config,
|
||||
awq_config=awq_config,
|
||||
stream_interval=args.stream_interval,
|
||||
conv_template=args.conv_template,
|
||||
)
|
||||
sys.modules["fastchat.serve.model_worker"].args = args
|
||||
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
|
||||
|
||||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
sys.modules["fastchat.serve.model_worker"].args = args
|
||||
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
|
||||
|
||||
MakeFastAPIOffline(app)
|
||||
app.title = f"FastChat LLM Server ({LLM_MODEL})"
|
||||
app.title = f"FastChat LLM Server ({args.model_names[0]})"
|
||||
app._worker = worker
|
||||
return app
|
||||
|
||||
|
||||
def create_openai_api_app(
|
||||
controller_address: str,
|
||||
api_keys: List = [],
|
||||
log_level: str = "INFO",
|
||||
) -> FastAPI:
|
||||
import fastchat.constants
|
||||
fastchat.constants.LOGDIR = LOG_PATH
|
||||
from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
|
||||
from fastchat.utils import build_logger
|
||||
logger = build_logger("openai_api", "openai_api.log")
|
||||
logger.setLevel(log_level)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@ -155,6 +179,7 @@ def create_openai_api_app(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
sys.modules["fastchat.serve.openai_api_server"].logger = logger
|
||||
app_settings.controller_address = controller_address
|
||||
app_settings.api_keys = api_keys
|
||||
|
||||
@ -164,6 +189,9 @@ def create_openai_api_app(
|
||||
|
||||
|
||||
def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
|
||||
if q is None or not isinstance(run_seq, int):
|
||||
return
|
||||
|
||||
if run_seq == 1:
|
||||
@app.on_event("startup")
|
||||
async def on_startup():
|
||||
@ -182,15 +210,90 @@ def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
|
||||
q.put(run_seq)
|
||||
|
||||
|
||||
def run_controller(q: Queue, run_seq: int = 1):
|
||||
def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
|
||||
import uvicorn
|
||||
import httpx
|
||||
from fastapi import Body
|
||||
import time
|
||||
import sys
|
||||
|
||||
app = create_controller_app(FSCHAT_CONTROLLER.get("dispatch_method"))
|
||||
app = create_controller_app(
|
||||
dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
|
||||
log_level=log_level,
|
||||
)
|
||||
_set_app_seq(app, q, run_seq)
|
||||
|
||||
@app.on_event("startup")
|
||||
def on_startup():
|
||||
if e is not None:
|
||||
e.set()
|
||||
|
||||
# add interface to release and load model worker
|
||||
@app.post("/release_worker")
|
||||
def release_worker(
|
||||
model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
|
||||
# worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[fschat_controller_address()]),
|
||||
new_model_name: str = Body(None, description="释放后加载该模型"),
|
||||
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
|
||||
) -> Dict:
|
||||
available_models = app._controller.list_models()
|
||||
if new_model_name in available_models:
|
||||
msg = f"要切换的LLM模型 {new_model_name} 已经存在"
|
||||
logger.info(msg)
|
||||
return {"code": 500, "msg": msg}
|
||||
|
||||
if new_model_name:
|
||||
logger.info(f"开始切换LLM模型:从 {model_name} 到 {new_model_name}")
|
||||
else:
|
||||
logger.info(f"即将停止LLM模型: {model_name}")
|
||||
|
||||
if model_name not in available_models:
|
||||
msg = f"the model {model_name} is not available"
|
||||
logger.error(msg)
|
||||
return {"code": 500, "msg": msg}
|
||||
|
||||
worker_address = app._controller.get_worker_address(model_name)
|
||||
if not worker_address:
|
||||
msg = f"can not find model_worker address for {model_name}"
|
||||
logger.error(msg)
|
||||
return {"code": 500, "msg": msg}
|
||||
|
||||
r = httpx.post(worker_address + "/release",
|
||||
json={"new_model_name": new_model_name, "keep_origin": keep_origin})
|
||||
if r.status_code != 200:
|
||||
msg = f"failed to release model: {model_name}"
|
||||
logger.error(msg)
|
||||
return {"code": 500, "msg": msg}
|
||||
|
||||
if new_model_name:
|
||||
timer = 300 # wait 5 minutes for new model_worker register
|
||||
while timer > 0:
|
||||
models = app._controller.list_models()
|
||||
if new_model_name in models:
|
||||
break
|
||||
time.sleep(1)
|
||||
timer -= 1
|
||||
if timer > 0:
|
||||
msg = f"sucess change model from {model_name} to {new_model_name}"
|
||||
logger.info(msg)
|
||||
return {"code": 200, "msg": msg}
|
||||
else:
|
||||
msg = f"failed change model from {model_name} to {new_model_name}"
|
||||
logger.error(msg)
|
||||
return {"code": 500, "msg": msg}
|
||||
else:
|
||||
msg = f"sucess to release model: {model_name}"
|
||||
logger.info(msg)
|
||||
return {"code": 200, "msg": msg}
|
||||
|
||||
host = FSCHAT_CONTROLLER["host"]
|
||||
port = FSCHAT_CONTROLLER["port"]
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
if log_level == "ERROR":
|
||||
sys.stdout = sys.__stdout__
|
||||
sys.stderr = sys.__stderr__
|
||||
|
||||
uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
|
||||
|
||||
|
||||
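上面 run_controller 中新增的 /release_worker 接口可用于在不重启服务的情况下切换 LLM 模型,下面是一个示意性的调用方式(模型名仅为示例,controller 地址以 server_config.py 中 FSCHAT_CONTROLLER 的实际配置为准):

```python
# 示意代码:请求 controller 释放当前模型并加载新模型(模型名请替换为实际已配置的模型)
import httpx
from server.utils import fschat_controller_address

r = httpx.post(
    fschat_controller_address() + "/release_worker",
    json={"model_name": "chatglm-6b",       # 当前正在运行的模型(示例)
          "new_model_name": "chatglm2-6b",  # 释放后要加载的新模型(示例)
          "keep_origin": False},
    timeout=300,
)
print(r.json())
```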
def run_model_worker(
|
||||
@ -198,33 +301,59 @@ def run_model_worker(
|
||||
controller_address: str = "",
|
||||
q: Queue = None,
|
||||
run_seq: int = 2,
|
||||
log_level: str = "INFO",
|
||||
):
|
||||
import uvicorn
|
||||
from fastapi import Body
|
||||
import sys
|
||||
|
||||
kwargs = FSCHAT_MODEL_WORKERS[model_name].copy()
|
||||
kwargs = get_model_worker_config(model_name)
|
||||
host = kwargs.pop("host")
|
||||
port = kwargs.pop("port")
|
||||
model_path = llm_model_dict[model_name].get("local_model_path", "")
|
||||
kwargs["model_path"] = model_path
|
||||
kwargs["model_names"] = [model_name]
|
||||
kwargs["controller_address"] = controller_address or fschat_controller_address()
|
||||
kwargs["worker_address"] = fschat_model_worker_address()
|
||||
kwargs["worker_address"] = fschat_model_worker_address(model_name)
|
||||
model_path = kwargs.get("local_model_path", "")
|
||||
kwargs["model_path"] = model_path
|
||||
|
||||
app = create_model_worker_app(**kwargs)
|
||||
app = create_model_worker_app(log_level=log_level, **kwargs)
|
||||
_set_app_seq(app, q, run_seq)
|
||||
if log_level == "ERROR":
|
||||
sys.stdout = sys.__stdout__
|
||||
sys.stderr = sys.__stderr__
|
||||
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
# add interface to release and load model
|
||||
@app.post("/release")
|
||||
def release_model(
|
||||
new_model_name: str = Body(None, description="释放后加载该模型"),
|
||||
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
|
||||
) -> Dict:
|
||||
if keep_origin:
|
||||
if new_model_name:
|
||||
q.put(["start", new_model_name])
|
||||
else:
|
||||
if new_model_name:
|
||||
q.put(["replace", new_model_name])
|
||||
else:
|
||||
q.put(["stop"])
|
||||
return {"code": 200, "msg": "done"}
|
||||
|
||||
uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
|
||||
|
||||
|
||||
def run_openai_api(q: Queue, run_seq: int = 3):
|
||||
def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"):
|
||||
import uvicorn
|
||||
import sys
|
||||
|
||||
controller_addr = fschat_controller_address()
|
||||
app = create_openai_api_app(controller_addr) # todo: not support keys yet.
|
||||
app = create_openai_api_app(controller_addr, log_level=log_level) # TODO: not support keys yet.
|
||||
_set_app_seq(app, q, run_seq)
|
||||
|
||||
host = FSCHAT_OPENAI_API["host"]
|
||||
port = FSCHAT_OPENAI_API["port"]
|
||||
if log_level == "ERROR":
|
||||
sys.stdout = sys.__stdout__
|
||||
sys.stderr = sys.__stderr__
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
@ -244,13 +373,15 @@ def run_api_server(q: Queue, run_seq: int = 4):
|
||||
def run_webui(q: Queue, run_seq: int = 5):
|
||||
host = WEBUI_SERVER["host"]
|
||||
port = WEBUI_SERVER["port"]
|
||||
while True:
|
||||
no = q.get()
|
||||
if no != run_seq - 1:
|
||||
q.put(no)
|
||||
else:
|
||||
break
|
||||
q.put(run_seq)
|
||||
|
||||
if q is not None and isinstance(run_seq, int):
|
||||
while True:
|
||||
no = q.get()
|
||||
if no != run_seq - 1:
|
||||
q.put(no)
|
||||
else:
|
||||
break
|
||||
q.put(run_seq)
|
||||
p = subprocess.Popen(["streamlit", "run", "webui.py",
|
||||
"--server.address", host,
|
||||
"--server.port", str(port)])
|
||||
@ -313,6 +444,13 @@ def parse_args() -> argparse.ArgumentParser:
|
||||
help="run api.py server",
|
||||
dest="api",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--api-worker",
|
||||
action="store_true",
|
||||
help="run online model api such as zhipuai",
|
||||
dest="api_worker",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-w",
|
||||
"--webui",
|
||||
@ -320,26 +458,38 @@ def parse_args() -> argparse.ArgumentParser:
|
||||
help="run webui.py server",
|
||||
dest="webui",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-q",
|
||||
"--quiet",
|
||||
action="store_true",
|
||||
help="减少fastchat服务log信息",
|
||||
dest="quiet",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
return args, parser
|
||||
|
||||
|
||||
def dump_server_info(after_start=False):
|
||||
def dump_server_info(after_start=False, args=None):
|
||||
import platform
|
||||
import langchain
|
||||
import fastchat
|
||||
from configs.server_config import api_address, webui_address
|
||||
from server.utils import api_address, webui_address
|
||||
|
||||
print("\n\n")
|
||||
print("\n")
|
||||
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
|
||||
print(f"操作系统:{platform.platform()}.")
|
||||
print(f"python版本:{sys.version}")
|
||||
print(f"项目版本:{VERSION}")
|
||||
print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
|
||||
print("\n")
|
||||
print(f"当前LLM模型:{LLM_MODEL} @ {LLM_DEVICE}")
|
||||
pprint(llm_model_dict[LLM_MODEL])
|
||||
print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}")
|
||||
|
||||
model = LLM_MODEL
|
||||
if args and args.model_name:
|
||||
model = args.model_name
|
||||
print(f"当前LLM模型:{model} @ {llm_device()}")
|
||||
pprint(llm_model_dict[model])
|
||||
print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")
|
||||
|
||||
if after_start:
|
||||
print("\n")
|
||||
print(f"服务端运行信息:")
|
||||
@ -351,110 +501,231 @@ def dump_server_info(after_start=False):
|
||||
if args.webui:
|
||||
print(f" Chatchat WEBUI Server: {webui_address()}")
|
||||
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
|
||||
print("\n\n")
|
||||
print("\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
async def start_main_server():
|
||||
import time
|
||||
import signal
|
||||
|
||||
def handler(signalname):
|
||||
"""
|
||||
Python 3.9 has `signal.strsignal(signalnum)` so this closure would not be needed.
|
||||
Also, 3.8 includes `signal.valid_signals()` that can be used to create a mapping for the same purpose.
|
||||
"""
|
||||
def f(signal_received, frame):
|
||||
raise KeyboardInterrupt(f"{signalname} received")
|
||||
return f
|
||||
|
||||
# This will be inherited by the child process if it is forked (not spawned)
|
||||
signal.signal(signal.SIGINT, handler("SIGINT"))
|
||||
signal.signal(signal.SIGTERM, handler("SIGTERM"))
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
queue = Queue()
|
||||
args = parse_args()
|
||||
manager = mp.Manager()
|
||||
|
||||
queue = manager.Queue()
|
||||
args, parser = parse_args()
|
||||
|
||||
if args.all_webui:
|
||||
args.openai_api = True
|
||||
args.model_worker = True
|
||||
args.api = True
|
||||
args.api_worker = True
|
||||
args.webui = True
|
||||
|
||||
elif args.all_api:
|
||||
args.openai_api = True
|
||||
args.model_worker = True
|
||||
args.api = True
|
||||
args.api_worker = True
|
||||
args.webui = False
|
||||
|
||||
elif args.llm_api:
|
||||
args.openai_api = True
|
||||
args.model_worker = True
|
||||
args.api_worker = True
|
||||
args.api = False
|
||||
args.webui = False
|
||||
|
||||
dump_server_info()
|
||||
logger.info(f"正在启动服务:")
|
||||
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
|
||||
dump_server_info(args=args)
|
||||
|
||||
processes = {}
|
||||
if len(sys.argv) > 1:
|
||||
logger.info(f"正在启动服务:")
|
||||
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
|
||||
|
||||
processes = {"online-api": []}
|
||||
|
||||
def process_count():
|
||||
return len(processes) + len(processes["online-api"]) - 1
|
||||
|
||||
if args.quiet:
|
||||
log_level = "ERROR"
|
||||
else:
|
||||
log_level = "INFO"
|
||||
|
||||
controller_started = manager.Event()
|
||||
if args.openai_api:
|
||||
process = Process(
|
||||
target=run_controller,
|
||||
name=f"controller({os.getpid()})",
|
||||
args=(queue, len(processes) + 1),
|
||||
name=f"controller",
|
||||
args=(queue, process_count() + 1, log_level, controller_started),
|
||||
daemon=True,
|
||||
)
|
||||
process.start()
|
||||
|
||||
processes["controller"] = process
|
||||
|
||||
process = Process(
|
||||
target=run_openai_api,
|
||||
name=f"openai_api({os.getpid()})",
|
||||
args=(queue, len(processes) + 1),
|
||||
name=f"openai_api",
|
||||
args=(queue, process_count() + 1),
|
||||
daemon=True,
|
||||
)
|
||||
process.start()
|
||||
processes["openai_api"] = process
|
||||
|
||||
if args.model_worker:
|
||||
process = Process(
|
||||
target=run_model_worker,
|
||||
name=f"model_worker({os.getpid()})",
|
||||
args=(args.model_name, args.controller_address, queue, len(processes) + 1),
|
||||
daemon=True,
|
||||
)
|
||||
process.start()
|
||||
processes["model_worker"] = process
|
||||
config = get_model_worker_config(args.model_name)
|
||||
if not config.get("online_api"):
|
||||
process = Process(
|
||||
target=run_model_worker,
|
||||
name=f"model_worker - {args.model_name}",
|
||||
args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
processes["model_worker"] = process
|
||||
|
||||
if args.api_worker:
|
||||
configs = get_all_model_worker_configs()
|
||||
for model_name, config in configs.items():
|
||||
if config.get("online_api") and config.get("worker_class"):
|
||||
process = Process(
|
||||
target=run_model_worker,
|
||||
name=f"model_worker - {model_name}",
|
||||
args=(model_name, args.controller_address, queue, process_count() + 1, log_level),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
processes["online-api"].append(process)
if args.api:
process = Process(
target=run_api_server,
name=f"API Server{os.getpid()})",
args=(queue, len(processes) + 1),
name=f"API Server",
args=(queue, process_count() + 1),
daemon=True,
)
process.start()

processes["api"] = process

if args.webui:
process = Process(
target=run_webui,
name=f"WEBUI Server{os.getpid()})",
args=(queue, len(processes) + 1),
name=f"WEBUI Server",
args=(queue, process_count() + 1),
daemon=True,
)
process.start()

processes["webui"] = process

try:
# log infors
while True:
no = queue.get()
if no == len(processes):
time.sleep(0.5)
dump_server_info(True)
break
else:
queue.put(no)
if process_count() == 0:
parser.print_help()
else:
try:
# 保证任务收到SIGINT后,能够正常退出
if p:= processes.get("controller"):
p.start()
p.name = f"{p.name} ({p.pid})"
controller_started.wait()

if model_worker_process := processes.get("model_worker"):
model_worker_process.join()
for name, process in processes.items():
if name != "model_worker":
if p:= processes.get("openai_api"):
p.start()
p.name = f"{p.name} ({p.pid})"

if p:= processes.get("model_worker"):
p.start()
p.name = f"{p.name} ({p.pid})"

for p in processes.get("online-api", []):
p.start()
p.name = f"{p.name} ({p.pid})"

if p:= processes.get("api"):
p.start()
p.name = f"{p.name} ({p.pid})"

if p:= processes.get("webui"):
p.start()
p.name = f"{p.name} ({p.pid})"

while True:
no = queue.get()
if no == process_count():
time.sleep(0.5)
dump_server_info(after_start=True, args=args)
break
else:
queue.put(no)

if model_worker_process := processes.get("model_worker"):
model_worker_process.join()
for process in processes.get("online-api", []):
process.join()
except:
if model_worker_process := processes.get("model_worker"):
model_worker_process.terminate()
for name, process in processes.items():
if name != "model_worker":
process.terminate()
for name, process in processes.items():
if name not in ["model_worker", "online-api"]:
if isinstance(p, list):
for work_process in p:
work_process.join()
else:
process.join()
except Exception as e:
# if model_worker_process := processes.pop("model_worker", None):
# model_worker_process.terminate()
# for process in processes.pop("online-api", []):
# process.terminate()
# for process in processes.values():
#
# if isinstance(process, list):
# for work_process in process:
# work_process.terminate()
# else:
# process.terminate()
logger.error(e)
logger.warning("Caught KeyboardInterrupt! Setting stop event...")
finally:
# Send SIGINT if process doesn't exit quickly enough, and kill it as last resort
# .is_alive() also implicitly joins the process (good practice in linux)
# while alive_procs := [p for p in processes.values() if p.is_alive()]:

for p in processes.values():
logger.warning("Sending SIGKILL to %s", p)
# Queues and other inter-process communication primitives can break when
# process is killed, but we don't care here

if isinstance(p, list):
for process in p:
process.kill()

else:
p.kill()

for p in processes.values():
logger.info("Process status: %s", p)
if __name__ == "__main__":

if sys.version_info < (3, 10):
loop = asyncio.get_event_loop()
else:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()

asyncio.set_event_loop(loop)
# 同步调用协程代码
loop.run_until_complete(start_main_server())


# 服务启动后接口调用示例:
# import openai
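The commented hint above could be fleshed out roughly as follows — a minimal sketch assuming the pre-1.0 `openai` Python SDK and the default FastChat openai-api address (`127.0.0.1:8888`); the model name is only an example and must match a model actually loaded by `startup.py`:

```python
import openai

# Point the client at the local FastChat openai-api server started by startup.py.
openai.api_base = "http://127.0.0.1:8888/v1"
openai.api_key = "EMPTY"  # FastChat does not verify the key

resp = openai.ChatCompletion.create(
    model="chatglm2-6b",  # example model name; use one configured in this project
    messages=[{"role": "user", "content": "你好"}],
)
print(resp.choices[0].message.content)
```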
@ -1,4 +1,3 @@
from doctest import testfile
import requests
import json
import sys
@ -6,7 +5,7 @@ from pathlib import Path

root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from configs.server_config import api_address
from server.utils import api_address
from configs.model_config import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path

@ -112,7 +111,7 @@ def test_upload_doc(api="/knowledge_base/upload_doc"):
assert data["msg"] == f"成功上传文件 {name}"


def test_list_docs(api="/knowledge_base/list_docs"):
def test_list_files(api="/knowledge_base/list_files"):
url = api_base_url + api
print("\n获取知识库中文件列表:")
r = requests.get(url, params={"knowledge_base_name": kb})
74
tests/api/test_llm_api.py
Normal file
@ -0,0 +1,74 @@
import requests
import json
import sys
from pathlib import Path

root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from configs.server_config import api_address, FSCHAT_MODEL_WORKERS
from configs.model_config import LLM_MODEL, llm_model_dict

from pprint import pprint
import random


def get_configured_models():
model_workers = list(FSCHAT_MODEL_WORKERS)
if "default" in model_workers:
model_workers.remove("default")

llm_dict = list(llm_model_dict)

return model_workers, llm_dict


api_base_url = api_address()


def get_running_models(api="/llm_model/list_models"):
url = api_base_url + api
r = requests.post(url)
if r.status_code == 200:
return r.json()["data"]
return []


def test_running_models(api="/llm_model/list_models"):
url = api_base_url + api
r = requests.post(url)
assert r.status_code == 200
print("\n获取当前正在运行的模型列表:")
pprint(r.json())
assert isinstance(r.json()["data"], list)
assert len(r.json()["data"]) > 0


# 不建议使用stop_model功能。按现在的实现,停止了就只能手动再启动
# def test_stop_model(api="/llm_model/stop"):
# url = api_base_url + api
# r = requests.post(url, json={""})


def test_change_model(api="/llm_model/change"):
url = api_base_url + api

running_models = get_running_models()
assert len(running_models) > 0

model_workers, llm_dict = get_configured_models()

availabel_new_models = set(model_workers) - set(running_models)
if len(availabel_new_models) == 0:
availabel_new_models = set(llm_dict) - set(running_models)
availabel_new_models = list(availabel_new_models)
assert len(availabel_new_models) > 0
print(availabel_new_models)

model_name = random.choice(running_models)
new_model_name = random.choice(availabel_new_models)
print(f"\n尝试将模型从 {model_name} 切换到 {new_model_name}")
r = requests.post(url, json={"model_name": model_name, "new_model_name": new_model_name})
assert r.status_code == 200

running_models = get_running_models()
assert new_model_name in running_models
@ -5,7 +5,7 @@ from pathlib import Path

sys.path.append(str(Path(__file__).parent.parent.parent))
from configs.model_config import BING_SUBSCRIPTION_KEY
from configs.server_config import API_SERVER, api_address
from server.utils import api_address

from pprint import pprint

21
tests/document_loader/test_imgloader.py
Normal file
@ -0,0 +1,21 @@
import sys
from pathlib import Path

root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from pprint import pprint

test_files = {
"ocr_test.jpg": str(root_path / "tests" / "samples" / "ocr_test.jpg"),
}

def test_rapidocrloader():
img_path = test_files["ocr_test.jpg"]
from document_loaders import RapidOCRLoader

loader = RapidOCRLoader(img_path)
docs = loader.load()
pprint(docs)
assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str)


21
tests/document_loader/test_pdfloader.py
Normal file
@ -0,0 +1,21 @@
import sys
from pathlib import Path

root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from pprint import pprint

test_files = {
"ocr_test.pdf": str(root_path / "tests" / "samples" / "ocr_test.pdf"),
}

def test_rapidocrpdfloader():
pdf_path = test_files["ocr_test.pdf"]
from document_loaders import RapidOCRPDFLoader

loader = RapidOCRPDFLoader(pdf_path)
docs = loader.load()
pprint(docs)
assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str)

0
tests/kb_vector_db/__init__.py
Normal file
35
tests/kb_vector_db/test_faiss_kb.py
Normal file
@ -0,0 +1,35 @@
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
from server.knowledge_base.migrate import create_tables
from server.knowledge_base.utils import KnowledgeFile

kbService = FaissKBService("test")
test_kb_name = "test"
test_file_name = "README.md"
testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name)
search_content = "如何启动api服务"


def test_init():
create_tables()


def test_create_db():
assert kbService.create_kb()


def test_add_doc():
assert kbService.add_doc(testKnowledgeFile)


def test_search_db():
result = kbService.search_docs(search_content)
assert len(result) > 0
def test_delete_doc():
assert kbService.delete_doc(testKnowledgeFile)


def test_delete_db():
assert kbService.drop_kb()

31
tests/kb_vector_db/test_milvus_db.py
Normal file
@ -0,0 +1,31 @@
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
from server.knowledge_base.kb_service.pg_kb_service import PGKBService
from server.knowledge_base.migrate import create_tables
from server.knowledge_base.utils import KnowledgeFile

kbService = MilvusKBService("test")

test_kb_name = "test"
test_file_name = "README.md"
testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name)
search_content = "如何启动api服务"

def test_init():
create_tables()


def test_create_db():
assert kbService.create_kb()


def test_add_doc():
assert kbService.add_doc(testKnowledgeFile)


def test_search_db():
result = kbService.search_docs(search_content)
assert len(result) > 0
def test_delete_doc():
assert kbService.delete_doc(testKnowledgeFile)

31
tests/kb_vector_db/test_pg_db.py
Normal file
@ -0,0 +1,31 @@
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
from server.knowledge_base.kb_service.pg_kb_service import PGKBService
from server.knowledge_base.migrate import create_tables
from server.knowledge_base.utils import KnowledgeFile

kbService = PGKBService("test")

test_kb_name = "test"
test_file_name = "README.md"
testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name)
search_content = "如何启动api服务"


def test_init():
create_tables()


def test_create_db():
assert kbService.create_kb()


def test_add_doc():
assert kbService.add_doc(testKnowledgeFile)


def test_search_db():
result = kbService.search_docs(search_content)
assert len(result) > 0
def test_delete_doc():
assert kbService.delete_doc(testKnowledgeFile)

BIN
tests/samples/ocr_test.jpg
Normal file
BIN
tests/samples/ocr_test.pdf
Normal file
4
webui.py
@ -10,8 +10,10 @@ from streamlit_option_menu import option_menu
from webui_pages import *
import os
from configs import VERSION
from server.utils import api_address

api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)

api = ApiRequest(base_url=api_address())

if __name__ == "__main__":
st.set_page_config(

@ -1,10 +1,14 @@
import streamlit as st
from configs.server_config import FSCHAT_MODEL_WORKERS
from webui_pages.utils import *
from streamlit_chatbox import *
from datetime import datetime
from server.chat.search_engine_chat import SEARCH_ENGINES
from typing import List, Dict
import os
from configs.model_config import llm_model_dict, LLM_MODEL
from server.utils import get_model_worker_config
from typing import List, Dict


chat_box = ChatBox(
assistant_avatar=os.path.join(
@ -59,9 +63,39 @@ def dialogue_page(api: ApiRequest):
on_change=on_mode_change,
key="dialogue_mode",
)
history_len = st.number_input("历史对话轮数:", 0, 10, 3)

# todo: support history len
def on_llm_change():
st.session_state["prev_llm_model"] = llm_model

def llm_model_format_func(x):
if x in running_models:
return f"{x} (Running)"
return x

running_models = api.list_running_models()
config_models = api.list_config_models()
for x in running_models:
if x in config_models:
config_models.remove(x)
llm_models = running_models + config_models
if "prev_llm_model" not in st.session_state:
index = llm_models.index(LLM_MODEL)
else:
index = 0
llm_model = st.selectbox("选择LLM模型:",
llm_models,
index,
format_func=llm_model_format_func,
on_change=on_llm_change,
# key="llm_model",
)
if (st.session_state.get("prev_llm_model") != llm_model
and not get_model_worker_config(llm_model).get("online_api")):
with st.spinner(f"正在加载模型: {llm_model}"):
r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
st.session_state["prev_llm_model"] = llm_model

history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)

def on_kb_change():
st.toast(f"已加载知识库: {st.session_state.selected_kb}")
@ -75,7 +109,7 @@ def dialogue_page(api: ApiRequest):
on_change=on_kb_change,
key="selected_kb",
)
kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3)
kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
score_threshold = st.number_input("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
# chunk_content = st.checkbox("关联上下文", False, disabled=True)
# chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
@ -87,13 +121,13 @@ def dialogue_page(api: ApiRequest):
options=search_engine_list,
index=search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0,
)
se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, 3)
se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K)

# Display chat messages from history on app rerun

chat_box.output_messages()

chat_input_placeholder = "请输入对话内容,换行请使用Ctrl+Enter "
chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter "

if prompt := st.chat_input(chat_input_placeholder, key="prompt"):
history = get_messages_history(history_len)
@ -101,7 +135,7 @@ def dialogue_page(api: ApiRequest):
if dialogue_mode == "LLM 对话":
chat_box.ai_say("正在思考...")
text = ""
r = api.chat_chat(prompt, history)
r = api.chat_chat(prompt, history=history, model=llm_model)
for t in r:
if error_msg := check_error_msg(t): # check whether error occured
st.error(error_msg)
@ -116,7 +150,7 @@ def dialogue_page(api: ApiRequest):
Markdown("...", in_expander=True, title="知识库匹配结果"),
])
text = ""
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history):
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history, model=llm_model):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
text += d["answer"]
@ -129,8 +163,8 @@ def dialogue_page(api: ApiRequest):
Markdown("...", in_expander=True, title="网络搜索结果"),
])
text = ""
for d in api.search_engine_chat(prompt, search_engine, se_top_k):
if error_msg := check_error_msg(d): # check whether error occured
for d in api.search_engine_chat(prompt, search_engine, se_top_k, model=llm_model):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
else:
text += d["answer"]

@ -4,7 +4,7 @@ from st_aggrid import AgGrid, JsCode
from st_aggrid.grid_options_builder import GridOptionsBuilder
import pandas as pd
from server.knowledge_base.utils import get_file_path, LOADER_DICT
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_doc_details
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
from typing import Literal, Dict, Tuple
from configs.model_config import embedding_model_dict, kbs_config, EMBEDDING_MODEL, DEFAULT_VS_TYPE
import os
@ -127,7 +127,7 @@ def knowledge_base_page(api: ApiRequest):

# 上传文件
# sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
files = st.file_uploader("上传知识文件",
files = st.file_uploader("上传知识文件(暂不支持扫描PDF)",
[i for ls in LOADER_DICT.values() for i in ls],
accept_multiple_files=True,
)
@ -152,7 +152,7 @@ def knowledge_base_page(api: ApiRequest):

# 知识库详情
# st.info("请选择文件,点击按钮进行操作。")
doc_details = pd.DataFrame(get_kb_doc_details(kb))
doc_details = pd.DataFrame(get_kb_file_details(kb))
if not len(doc_details):
st.info(f"知识库 `{kb}` 中暂无文件")
else:
@ -160,7 +160,7 @@ def knowledge_base_page(api: ApiRequest):
st.info("知识库中包含源文件与向量库,请从下表中选择文件后操作")
doc_details.drop(columns=["kb_name"], inplace=True)
doc_details = doc_details[[
"No", "file_name", "document_loader", "text_splitter", "in_folder", "in_db",
"No", "file_name", "document_loader", "docs_count", "in_folder", "in_db",
]]
# doc_details["in_folder"] = doc_details["in_folder"].replace(True, "✓").replace(False, "×")
# doc_details["in_db"] = doc_details["in_db"].replace(True, "✓").replace(False, "×")
@ -172,7 +172,8 @@ def knowledge_base_page(api: ApiRequest):
# ("file_ext", "文档类型"): {},
# ("file_version", "文档版本"): {},
("document_loader", "文档加载器"): {},
("text_splitter", "分词器"): {},
("docs_count", "文档数量"): {},
# ("text_splitter", "分词器"): {},
# ("create_time", "创建时间"): {},
("in_folder", "源文件"): {"cellRenderer": cell_renderer},
("in_db", "向量库"): {"cellRenderer": cell_renderer},
@ -244,7 +245,6 @@ def knowledge_base_page(api: ApiRequest):

cols = st.columns(3)

# todo: freezed
if cols[0].button(
"依据源文件重建向量库",
# help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。",
@ -258,7 +258,7 @@ def knowledge_base_page(api: ApiRequest):
if msg := check_error_msg(d):
st.toast(msg)
else:
empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
empty.progress(d["finished"] / d["total"], d["msg"])
st.experimental_rerun()

if cols[2].button(

@ -6,11 +6,14 @@ from configs.model_config import (
DEFAULT_VS_TYPE,
KB_ROOT_PATH,
LLM_MODEL,
llm_model_dict,
HISTORY_LEN,
SCORE_THRESHOLD,
VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K,
logger,
)
from configs.server_config import HTTPX_DEFAULT_TIMEOUT
import httpx
import asyncio
from server.chat.openai_chat import OpenAiChatMsgIn
@ -20,21 +23,12 @@ import json
from io import BytesIO
from server.db.repository.knowledge_base_repository import get_kb_detail
from server.db.repository.knowledge_file_repository import get_file_detail
from server.utils import run_async, iter_over_async
from server.utils import run_async, iter_over_async, set_httpx_timeout

from configs.model_config import NLTK_DATA_PATH
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path


def set_httpx_timeout(timeout=60.0):
'''
设置httpx默认timeout到60秒。
httpx默认timeout是5秒,在请求LLM回答时不够用。
'''
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
from pprint import pprint


KB_ROOT_PATH = Path(KB_ROOT_PATH)
@ -50,7 +44,7 @@ class ApiRequest:
def __init__(
self,
base_url: str = "http://127.0.0.1:7861",
timeout: float = 60.0,
timeout: float = HTTPX_DEFAULT_TIMEOUT,
no_remote_api: bool = False, # call api view function directly
):
self.base_url = base_url
@ -224,9 +218,17 @@ class ApiRequest:
try:
with response as r:
for chunk in r.iter_text(None):
if as_json and chunk:
yield json.loads(chunk)
elif chunk.strip():
if not chunk: # fastchat api yield empty bytes on start and end
continue
if as_json:
try:
data = json.loads(chunk)
pprint(data, depth=1)
yield data
except Exception as e:
logger.error(f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。")
else:
print(chunk, end="", flush=True)
yield chunk
except httpx.ConnectError as e:
msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。"
@ -274,6 +276,9 @@ class ApiRequest:
return self._fastapi_stream2generator(response)
else:
data = msg.dict(exclude_unset=True, exclude_none=True)
print(f"received input message:")
pprint(data)

response = self.post(
"/chat/fastchat",
json=data,
@ -286,6 +291,7 @@ class ApiRequest:
query: str,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODEL,
no_remote_api: bool = None,
):
'''
@ -298,8 +304,12 @@ class ApiRequest:
"query": query,
"history": history,
"stream": stream,
"model_name": model,
}

print(f"received input message:")
pprint(data)

if no_remote_api:
from server.chat.chat import chat
response = chat(**data)
@ -316,6 +326,7 @@ class ApiRequest:
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODEL,
no_remote_api: bool = None,
):
'''
@ -331,9 +342,13 @@ class ApiRequest:
"score_threshold": score_threshold,
"history": history,
"stream": stream,
"model_name": model,
"local_doc_url": no_remote_api,
}

print(f"received input message:")
pprint(data)

if no_remote_api:
from server.chat.knowledge_base_chat import knowledge_base_chat
response = knowledge_base_chat(**data)
@ -352,6 +367,7 @@ class ApiRequest:
search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K,
stream: bool = True,
model: str = LLM_MODEL,
no_remote_api: bool = None,
):
'''
@ -365,8 +381,12 @@ class ApiRequest:
"search_engine_name": search_engine_name,
"top_k": top_k,
"stream": stream,
"model_name": model,
}

print(f"received input message:")
pprint(data)

if no_remote_api:
from server.chat.search_engine_chat import search_engine_chat
response = search_engine_chat(**data)
@ -473,18 +493,18 @@ class ApiRequest:
no_remote_api: bool = None,
):
'''
对应api.py/knowledge_base/list_docs接口
对应api.py/knowledge_base/list_files接口
'''
if no_remote_api is None:
no_remote_api = self.no_remote_api

if no_remote_api:
from server.knowledge_base.kb_doc_api import list_docs
response = run_async(list_docs(knowledge_base_name))
from server.knowledge_base.kb_doc_api import list_files
response = run_async(list_files(knowledge_base_name))
return response.data
else:
response = self.get(
"/knowledge_base/list_docs",
"/knowledge_base/list_files",
params={"knowledge_base_name": knowledge_base_name}
)
data = self._check_httpx_json_response(response)
@ -633,6 +653,84 @@ class ApiRequest:
)
return self._httpx_stream2generator(response, as_json=True)

def list_running_models(self, controller_address: str = None):
'''
获取Fastchat中正运行的模型列表
'''
r = self.post(
"/llm_model/list_models",
)
return r.json().get("data", [])

def list_config_models(self):
'''
获取configs中配置的模型列表
'''
return list(llm_model_dict.keys())

def stop_llm_model(
self,
model_name: str,
controller_address: str = None,
):
'''
停止某个LLM模型。
注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。
'''
data = {
"model_name": model_name,
"controller_address": controller_address,
}
r = self.post(
"/llm_model/stop",
json=data,
)
return r.json()

def change_llm_model(
self,
model_name: str,
new_model_name: str,
controller_address: str = None,
):
'''
向fastchat controller请求切换LLM模型。
'''
if not model_name or not new_model_name:
return

if new_model_name == model_name:
return {
"code": 200,
"msg": "什么都不用做"
}

running_models = self.list_running_models()
if model_name not in running_models:
return {
"code": 500,
"msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}"
}

config_models = self.list_config_models()
if new_model_name not in config_models:
return {
"code": 500,
"msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
}

data = {
"model_name": model_name,
"new_model_name": new_model_name,
"controller_address": controller_address,
}
r = self.post(
"/llm_model/change",
json=data,
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
)
return r.json()
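A minimal usage sketch of the two methods above, assuming the API server is running on its default address and that both model names are configured in this project (the names themselves are placeholders):

```python
from webui_pages.utils import ApiRequest

# Sketch only: query running models, then ask the controller to switch models.
api = ApiRequest(base_url="http://127.0.0.1:7861")

print(api.list_running_models())  # e.g. ["chatglm2-6b"]
result = api.change_llm_model("chatglm2-6b", "Qwen-7B-Chat")  # placeholder model names
print(result)  # expected to be a dict with "code"/"msg" fields from the server
```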

def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
'''