Merge branch 'dev' into pre-release

imClumsyPanda 2023-09-05 10:02:37 +08:00
commit 44262da4e5
56 changed files with 1999 additions and 554 deletions

.gitignore

@ -3,5 +3,7 @@
logs
.idea/
__pycache__/
knowledge_base/
/knowledge_base/
configs/*.py
.vscode/
.pytest_cache/

README.md

@ -14,8 +14,8 @@
* [2. Download models locally](README.md#2.-下载模型至本地)
* [3. Set configuration items](README.md#3.-设置配置项)
* [4. Knowledge base initialization and migration](README.md#4.-知识库初始化与迁移)
* [5. Start the API service or Web UI](README.md#5.-启动-API-服务或-Web-UI)
* [6. One-click startup](README.md#6.-一键启动)
* [5. One-click startup of the API service or WebUI service](README.md#6.-一键启动)
* [6. Step-by-step startup of the API service or Web UI](README.md#5.-启动-API-服务或-Web-UI)
* [FAQ](README.md#常见问题)
* [Roadmap](README.md#路线图)
* [Project discussion group](README.md#项目交流群)
@ -226,9 +226,93 @@ embedding_model_dict = {
$ python init_database.py --recreate-vs
```
### 5. Start the API service or Web UI
### 5. One-click startup of the API service or Web UI
#### 5.1 Start the LLM service
#### 5.1 Startup command
The one-click startup script startup.py launches all FastChat services, the API service, and the WebUI service in one step. Example:
```shell
$ python startup.py -a
```
All running services can be shut down with `Ctrl + C`. If one press does not stop everything, press it a few more times.
Available options include `-a (or --all-webui)`, `--all-api`, `--llm-api`, `-c (or --controller)`, `--openai-api`,
`-m (or --model-worker)`, `--api`, `--webui`, where:
- `--all-webui` starts all services the WebUI depends on;
- `--all-api` starts all services the API depends on;
- `--llm-api` starts all LLM services that FastChat depends on;
- `--openai-api` starts only FastChat's controller and openai-api-server services;
- the remaining options start individual services (see the sketch below).
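For instance, a minimal sketch that starts the FastChat LLM services and the API server separately, assuming the flags combine as listed above:

```shell
# start only the FastChat LLM services (controller, model worker, openai-api)
$ python startup.py --llm-api

# later, in another shell, start only the API server
$ python startup.py --api
```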
#### 5.2 Starting a non-default model
To use a model other than the default, pass the `--model-name` option, for example:
```shell
$ python startup.py --all-webui --model-name Qwen-7B-Chat
```
Run `python startup.py -h` for more information.
#### 5.3 Multi-GPU loading
The project supports multi-GPU loading; modify the following three parameters in the create_model_worker_app function in startup.py:
```python
gpus=None,
num_gpus=1,
max_gpu_memory="20GiB"
```
Here, `gpus` sets the IDs of the GPUs to use, e.g. "0,1";
`num_gpus` sets the number of GPUs to use;
`max_gpu_memory` sets the maximum GPU memory used per card.
Note 1: the FSCHAT_MODEL_WORKERS dictionary in server_config.py now also carries these settings; if needed, multi-GPU loading can likewise be enabled by changing the corresponding parameters there, as sketched below.
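A minimal sketch of such an override, assuming the `default` entry and key names introduced by this commit's server_config.py (the values are illustrative only):

```python
# configs/server_config.py (sketch)
FSCHAT_MODEL_WORKERS = {
    "default": {
        "host": "127.0.0.1",
        "port": 20002,
        "gpus": "0,1",              # GPU IDs to use, as a str
        "num_gpus": 2,              # number of GPUs to use
        "max_gpu_memory": "20GiB",  # maximum GPU memory per card
    },
}
```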
Note 2: in rare cases the gpus parameter does not take effect; in that case, set the CUDA_VISIBLE_DEVICES environment variable to specify the GPUs visible to torch, for example:
```shell
CUDA_VISIBLE_DEVICES=0,1 python startup.py -a
```
#### 5.4 PEFT loading (lora, p-tuning, prefix tuning, prompt tuning, ia3, etc.)
This project loads LLM services via FastChat, so PEFT weights must be loaded the FastChat way: the path name must contain the word "peft", the configuration file must be named adapter_config.json, and the peft path must contain PEFT weights in .bin format. The peft path is specified in args.model_names of the create_model_worker_app function in startup.py, and the environment variable PEFT_SHARE_BASE_WEIGHTS=true must be set.
If startup fails this way, fall back to starting the standard FastChat services step by step (see section 6); for detailed PEFT loading steps, see [LoRA-finetuned model fails to load](https://github.com/chatchat-space/Langchain-Chatchat/issues/1130#issuecomment-1685291822)
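A hedged launch sketch under those constraints (the directory and model names below are placeholders, not paths from this repository):

```shell
# the model directory name must contain "peft" and hold adapter_config.json
# plus PEFT weights in .bin format
$ PEFT_SHARE_BASE_WEIGHTS=true python startup.py -a --model-name chatglm2-6b-peft-lora
```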
#### **5.5 Notes:**
**1. The startup script launches each module's service as a separate process, which may scramble print order; wait until all services have started before making calls, and call each service on its default or configured port (default LLM API service port: `127.0.0.1:8888`; default API service port: `127.0.0.1:7861`; default WebUI service port: `<local IP>:8501`).**
**2. Service startup time varies by device, roughly 3-10 minutes; if services have not come up after a long wait, check the logs under the `./logs` directory to locate the problem.**
**3. On Linux, exiting with Ctrl+C may leave orphaned multiprocessing processes behind due to Linux's multi-process semantics; use shutdown_all.sh to exit.**
#### 5.6 Sample startup screens:
1. FastAPI docs page
![](img/fastapi_docs_020_0.png)
2. Sample WebUI screens
- Web UI chat page:
![img](img/webui_0813_0.png)
- Web UI knowledge base management page:
![](img/webui_0813_1.png)
### 6 Step-by-step startup of the API service or Web UI
Note: if you used one-click startup, you can skip this section.
#### 6.1 Start the LLM service
To deploy an open-source model locally, first start the LLM service; there are three ways to do so:
@ -240,7 +324,7 @@ embedding_model_dict = {
If you use an online API service (such as OpenAI's API), there is no need to start an LLM service, i.e. none of the commands in section 5.1 need to be run.
##### 5.1.1 Start the LLM service with the multi-process script llm_api.py
##### 6.1.1 Start the LLM service with the multi-process script llm_api.py
From the project root, run the [server/llm_api.py](server/llm_api.py) script to start the **LLM model** service:
@ -248,7 +332,7 @@ embedding_model_dict = {
$ python server/llm_api.py
```
The project supports multi-GPU loading; in llm_api.py, modify the create_model_worker_app function and change the following three parameters:
The project supports multi-GPU loading; modify the following three parameters in the create_model_worker_app function in llm_api.py:
```python
gpus=None,
@ -262,11 +346,11 @@ max_gpu_memory="20GiB"
`max_gpu_memory` sets the maximum GPU memory used per card.
##### 5.1.2 Start the LLM service with the command-line script llm_api_stale.py
##### 6.1.2 Start the LLM service with the command-line script llm_api_stale.py
⚠️ **Note:**
**1. The llm_api_stale.py script natively supports Linux only; macOS requires the corresponding Linux commands to be installed, and on Windows use wls;**
**1. The llm_api_stale.py script natively supports Linux only; macOS requires the corresponding Linux commands to be installed, and on Windows use WSL;**
**2. Loading a non-default model requires the --model-path-address command-line argument; the model_config.py configuration is not read;**
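An invocation sketch; the `model-path@host@port` form of the argument is an assumption here, and the model path is a placeholder:

```shell
$ python server/llm_api_stale.py --model-path-address THUDM/chatglm2-6b@127.0.0.1@20002
```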
@ -302,14 +386,14 @@ $ python server/llm_api_shutdown.py --serve all
A single FastChat service module can also be stopped on its own; choose one of [`all`, `controller`, `model_worker`, `openai_api_server`].
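For example, a sketch that stops only the model worker, using the same script shown in the hunk above:

```shell
$ python server/llm_api_shutdown.py --serve model_worker
```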
##### 5.1.3 PEFT loading (lora, p-tuning, prefix tuning, prompt tuning, ia, etc.)
##### 6.1.3 PEFT loading (lora, p-tuning, prefix tuning, prompt tuning, ia3, etc.)
This project loads LLM services via FastChat, so PEFT weights must be loaded the FastChat way: the path name must contain the word "peft", the configuration file must be named adapter_config.json, and the peft path must contain PEFT weights in model.bin format.
For detailed steps, see [LoRA-finetuned model fails to load](https://github.com/chatchat-space/Langchain-Chatchat/issues/1130#issuecomment-1685291822)
![image](https://github.com/chatchat-space/Langchain-Chatchat/assets/22924096/4e056c1c-5c4b-4865-a1af-859cd58a625d)
#### 5.2 Start the API service
#### 6.2 Start the API service
For local deployment, **after starting the LLM service** per [section 5.1](README.md#5.1-启动-LLM-服务), run the [server/api.py](server/api.py) script to start the **API** service;
@ -327,7 +411,7 @@ $ python server/api.py
![](img/fastapi_docs_020_0.png)
#### 5.3 Start the Web UI service
#### 6.3 Start the Web UI service
**After starting the API service** per [section 5.2](README.md#5.2-启动-API-服务), run [webui.py](webui.py) to start the **Web UI** service (port `8501` by default):
@ -356,41 +440,6 @@ $ streamlit run webui.py --server.port 666
---
### 6. One-click startup
The updated one-click startup script startup.py launches all FastChat services, the API service, and the WebUI service in one step. Example:
```shell
$ python startup.py -a
```
All running services can be shut down with `Ctrl + C`. If one press does not stop everything, press it a few more times.
Available options include `-a (or --all-webui)`, `--all-api`, `--llm-api`, `-c (or --controller)`, `--openai-api`,
`-m (or --model-worker)`, `--api`, `--webui`, where:
- `--all-webui` starts all services the WebUI depends on;
- `--all-api` starts all services the API depends on;
- `--llm-api` starts all LLM services that FastChat depends on;
- `--openai-api` starts only FastChat's controller and openai-api-server services;
- the remaining options start individual services.
To use a model other than the default, pass the `--model-name` option, for example:
```shell
$ python startup.py --all-webui --model-name Qwen-7B-Chat
```
Run `python startup.py -h` for more information.
**Note:**
**1. The startup script launches each module's service as a separate process, which may scramble print order; wait until all services have started before making calls, and call each service on its default or configured port (default LLM API service port: `127.0.0.1:8888`; default API service port: `127.0.0.1:7861`; default WebUI service port: `<local IP>:8501`).**
**2. Service startup time varies by device, roughly 3-10 minutes; if services have not come up after a long wait, check the logs under the `./logs` directory to locate the problem.**
**3. On Linux, exiting with Ctrl+C may leave orphaned multiprocessing processes behind due to Linux's multi-process semantics; use shutdown_all.sh to exit.**
## FAQ
See the [FAQ](docs/FAQ.md).
@ -433,6 +482,6 @@ $ python startup.py --all-webui --model-name Qwen-7B-Chat
## Project discussion group
<img src="img/qr_code_56.jpg" alt="QR code" width="300" height="300" />
<img src="img/qr_code_58.jpg" alt="QR code" width="300" height="300" />
🎉 WeChat discussion group for the langchain-ChatGLM project. If you are interested in this project too, you are welcome to join the group chat.

configs/model_config.py

@ -1,6 +1,5 @@
import os
import logging
import torch
# Log format
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
logger = logging.getLogger()
@ -32,9 +31,8 @@ embedding_model_dict = {
# Name of the selected Embedding model
EMBEDDING_MODEL = "m3e-base"
# Device on which the Embedding model runs
EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# Device on which the Embedding model runs. "auto" auto-detects; it can also be set manually to one of "cuda", "mps", "cpu".
EMBEDDING_DEVICE = "auto"
llm_model_dict = {
"chatglm-6b": {
@ -69,20 +67,32 @@ llm_model_dict = {
# If you see: WARNING: Retrying langchain.chat_models.openai.acompletion_with_retry.<locals>._completion_with_retry in
# 4.0 seconds as it raised APIConnectionError: Error communicating with OpenAI.
# you need proxy access (a proxy tool that is otherwise running may fail to intercept these requests); set the openai_proxy config entry or the OPENAI_PROXY environment variable,
# e.g.: "openai_proxy": 'http://127.0.0.1:4780'
"gpt-3.5-turbo": {
"local_model_path": "gpt-3.5-turbo",
"api_base_url": "https://api.openai.com/v1",
"api_key": os.environ.get("OPENAI_API_KEY"),
"openai_proxy": os.environ.get("OPENAI_PROXY")
},
# Online models. Zhipu AI is currently supported.
# If no valid local_model_path is set, the entry is treated as an online model API.
# Set a different port for each online API in server_config.
# To register and obtain an api key, visit http://open.bigmodel.cn
"chatglm-api": {
"api_base_url": "http://127.0.0.1:8888/v1",
"api_key": os.environ.get("ZHIPUAI_API_KEY"),
"provider": "ChatGLMWorker",
"version": "chatglm_pro", # 可选包括 "chatglm_lite", "chatglm_std", "chatglm_pro"
},
}
# LLM name
LLM_MODEL = "chatglm2-6b"
# Device on which the LLM runs
LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# Number of history dialogue turns
HISTORY_LEN = 3
# Device on which the LLM runs. "auto" auto-detects; it can also be set manually to one of "cuda", "mps", "cpu".
LLM_DEVICE = "auto"
# Log storage path
LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
@ -164,4 +174,4 @@ BING_SUBSCRIPTION_KEY = ""
# Whether to enable Chinese title enhancement, and its related options.
# Adds a title-detection pass to decide which text fragments are titles, marks them in metadata,
# and then concatenates each text with the title one level above it to enrich the text.
ZH_TITLE_ENHANCE = False
ZH_TITLE_ENHANCE = False

configs/server_config.py

@ -1,4 +1,8 @@
from .model_config import LLM_MODEL, LLM_DEVICE
from .model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE
import httpx
# Default timeout for httpx requests (seconds). If loading models or chatting is slow and timeouts occur, increase this value.
HTTPX_DEFAULT_TIMEOUT = 300.0
# Whether the API allows cross-origin requests; defaults to False; set to True if needed
# is open cross domain
@ -29,15 +33,18 @@ FSCHAT_OPENAI_API = {
# These models must be correctly configured in model_config.llm_model_dict.
# When launching startup.py, a model can be specified with `--model-worker --model-name xxxx`; if unspecified, LLM_MODEL is used.
FSCHAT_MODEL_WORKERS = {
LLM_MODEL: {
# Default settings shared by all models; they can be overridden per model or in llm_model_dict.
"default": {
"host": DEFAULT_BIND_HOST,
"port": 20002,
"device": LLM_DEVICE,
# todo: parameters required for multi-GPU loading
"gpus": None, # GPU IDs to use, as a str, e.g. "0,1"
"num_gpus": 1, # number of GPUs to use
# Less common parameters below; configure as needed
# parameters required for multi-GPU loading
# "gpus": None, # GPU IDs to use, as a str, e.g. "0,1"
# "num_gpus": 1, # number of GPUs to use
# "max_gpu_memory": "20GiB", # maximum GPU memory per card
# Less common parameters below; configure as needed
# "load_8bit": False, # enable 8-bit quantization
# "cpu_offloading": None,
# "gptq_ckpt": None,
@ -53,11 +60,17 @@ FSCHAT_MODEL_WORKERS = {
# "stream_interval": 2,
# "no_register": False,
},
"baichuan-7b": { # 使用default中的IP和端口
"device": "cpu",
},
"chatglm-api": { # 请为每个在线API设置不同的端口
"port": 20003,
},
}
# fastchat multi model worker server
FSCHAT_MULTI_MODEL_WORKERS = {
# todo
# TODO:
}
# fastchat controller server
@ -66,35 +79,3 @@ FSCHAT_CONTROLLER = {
"port": 20001,
"dispatch_method": "shortest_queue",
}
# Do not change anything below
def fschat_controller_address() -> str:
    host = FSCHAT_CONTROLLER["host"]
    port = FSCHAT_CONTROLLER["port"]
    return f"http://{host}:{port}"

def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
    if model := FSCHAT_MODEL_WORKERS.get(model_name):
        host = model["host"]
        port = model["port"]
        return f"http://{host}:{port}"

def fschat_openai_api_address() -> str:
    host = FSCHAT_OPENAI_API["host"]
    port = FSCHAT_OPENAI_API["port"]
    return f"http://{host}:{port}"

def api_address() -> str:
    host = API_SERVER["host"]
    port = API_SERVER["port"]
    return f"http://{host}:{port}"

def webui_address() -> str:
    host = WEBUI_SERVER["host"]
    port = WEBUI_SERVER["port"]
    return f"http://{host}:{port}"


@ -0,0 +1,2 @@
from .mypdfloader import RapidOCRPDFLoader
from .myimgloader import RapidOCRLoader

myimgloader.py

@ -0,0 +1,25 @@
from typing import List
from langchain.document_loaders.unstructured import UnstructuredFileLoader
class RapidOCRLoader(UnstructuredFileLoader):
    def _get_elements(self) -> List:
        def img2text(filepath):
            from rapidocr_onnxruntime import RapidOCR
            resp = ""
            ocr = RapidOCR()
            result, _ = ocr(filepath)
            if result:
                ocr_result = [line[1] for line in result]
                resp += "\n".join(ocr_result)
            return resp

        text = img2text(self.file_path)
        from unstructured.partition.text import partition_text
        return partition_text(text=text, **self.unstructured_kwargs)

if __name__ == "__main__":
    loader = RapidOCRLoader(file_path="../tests/samples/ocr_test.jpg")
    docs = loader.load()
    print(docs)

mypdfloader.py

@ -0,0 +1,37 @@
from typing import List
from langchain.document_loaders.unstructured import UnstructuredFileLoader
class RapidOCRPDFLoader(UnstructuredFileLoader):
    def _get_elements(self) -> List:
        def pdf2text(filepath):
            import fitz  # PyMuPDF
            from rapidocr_onnxruntime import RapidOCR
            import numpy as np
            ocr = RapidOCR()
            doc = fitz.open(filepath)
            resp = ""
            for page in doc:
                # TODO: adjust handling according to the order of text and images
                text = page.get_text("")
                resp += text + "\n"

                img_list = page.get_images()
                for img in img_list:
                    pix = fitz.Pixmap(doc, img[0])
                    img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1)
                    result, _ = ocr(img_array)
                    if result:
                        ocr_result = [line[1] for line in result]
                        resp += "\n".join(ocr_result)
            return resp

        text = pdf2text(self.file_path)
        from unstructured.partition.text import partition_text
        return partition_text(text=text, **self.unstructured_kwargs)

if __name__ == "__main__":
    loader = RapidOCRPDFLoader(file_path="../tests/samples/ocr_test.pdf")
    docs = loader.load()
    print(docs)

[Binary image changes not shown: seven images under img/ changed; img/qr_code_58.jpg added as a new file (249 KiB).]

init_database.py

@ -1,8 +1,9 @@
from server.knowledge_base.migrate import create_tables, folder2db, recreate_all_vs, list_kbs_from_folder
from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, recreate_all_vs, list_kbs_from_folder
from configs.model_config import NLTK_DATA_PATH
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
from startup import dump_server_info
from datetime import datetime
if __name__ == "__main__":
@ -25,13 +26,19 @@ if __name__ == "__main__":
dump_server_info()
create_tables()
print("database talbes created")
start_time = datetime.now()
if args.recreate_vs:
reset_tables()
print("database talbes reseted")
print("recreating all vector stores")
recreate_all_vs()
else:
create_tables()
print("database talbes created")
print("filling kb infos to database")
for kb in list_kbs_from_folder():
folder2db(kb, "fill_info_only")
end_time = datetime.now()
print(f"总计用时: {end_time-start_time}")


@ -15,6 +15,8 @@ SQLAlchemy==2.0.19
faiss-cpu
accelerate
spacy
PyMuPDF==1.22.5
rapidocr_onnxruntime>=1.3.1
# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3


@ -16,6 +16,8 @@ faiss-cpu
nltk
accelerate
spacy
PyMuPDF==1.22.5
rapidocr_onnxruntime>=1.3.1
# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3

server/api.py

@ -4,20 +4,22 @@ import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import NLTK_DATA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN
from configs.model_config import LLM_MODEL, NLTK_DATA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN, HTTPX_DEFAULT_TIMEOUT
from configs import VERSION
import argparse
import uvicorn
from fastapi import Body
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat,
search_engine_chat)
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc,
update_doc, download_doc, recreate_vector_store,
search_docs, DocumentWithScore)
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address
import httpx
from typing import List
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@ -84,11 +86,11 @@ def create_app():
summary="删除知识库"
)(delete_kb)
app.get("/knowledge_base/list_docs",
app.get("/knowledge_base/list_files",
tags=["Knowledge Base Management"],
response_model=ListResponse,
summary="获取知识库内的文件列表"
)(list_docs)
)(list_files)
app.post("/knowledge_base/search_docs",
tags=["Knowledge Base Management"],
@ -123,6 +125,75 @@ def create_app():
summary="根据content中文档重建向量库流式输出处理进度。"
)(recreate_vector_store)
# Endpoints related to LLM models
@app.post("/llm_model/list_models",
tags=["LLM Model Management"],
summary="列出当前已加载的模型")
def list_models(
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
'''
Fetch the list of loaded models from the fastchat controller.
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(controller_address + "/list_models")
return BaseResponse(data=r.json()["models"])
except Exception as e:
return BaseResponse(
code=500,
data=[],
msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
@app.post("/llm_model/stop",
tags=["LLM Model Management"],
summary="停止指定的LLM模型Model Worker)",
)
def stop_llm_model(
model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
'''
Ask the fastchat controller to stop an LLM model.
Note: given FastChat's implementation, this actually stops the model_worker hosting the LLM model.
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(
controller_address + "/release_worker",
json={"model_name": model_name},
)
return r.json()
except Exception as e:
return BaseResponse(
code=500,
msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}")
@app.post("/llm_model/change",
tags=["LLM Model Management"],
summary="切换指定的LLM模型Model Worker)",
)
def change_llm_model(
model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]),
new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
):
'''
Ask the fastchat controller to switch to another LLM model.
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(
controller_address + "/release_worker",
json={"model_name": model_name, "new_model_name": new_model_name},
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
)
return r.json()
except Exception as e:
return BaseResponse(
code=500,
msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
return app

server/chat/chat.py

@ -20,11 +20,13 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
{"role": "assistant", "content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
):
history = [History.from_data(h) for h in history]
async def chat_iterator(query: str,
history: List[History] = [],
model_name: str = LLM_MODEL,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
@ -32,10 +34,10 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
streaming=True,
verbose=True,
callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL,
openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
openai_api_key=llm_model_dict[model_name]["api_key"],
openai_api_base=llm_model_dict[model_name]["api_base_url"],
model_name=model_name,
openai_proxy=llm_model_dict[model_name].get("openai_proxy")
)
input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
@ -61,5 +63,5 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
await task
return StreamingResponse(chat_iterator(query, history),
return StreamingResponse(chat_iterator(query, history, model_name),
media_type="text/event-stream")

server/chat/knowledge_base_chat.py

@ -31,6 +31,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
request: Request = None,
):
@ -44,16 +45,17 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
kb: KBService,
top_k: int,
history: Optional[List[History]],
model_name: str = LLM_MODEL,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
streaming=True,
verbose=True,
callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL,
openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
openai_api_key=llm_model_dict[model_name]["api_key"],
openai_api_base=llm_model_dict[model_name]["api_base_url"],
model_name=model_name,
openai_proxy=llm_model_dict[model_name].get("openai_proxy")
)
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
context = "\n".join([doc.page_content for doc in docs])
@ -97,5 +99,5 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
await task
return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history),
return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history, model_name),
media_type="text/event-stream")

server/chat/openai_chat.py

@ -29,23 +29,22 @@ async def openai_chat(msg: OpenAiChatMsgIn):
print(f"{openai.api_base=}")
print(msg)
async def get_response(msg):
def get_response(msg):
data = msg.dict()
data["streaming"] = True
data.pop("stream")
try:
response = openai.ChatCompletion.create(**data)
if msg.stream:
for chunk in response.choices[0].message.content:
print(chunk)
yield chunk
for data in response:
if choices := data.choices:
if chunk := choices[0].get("delta", {}).get("content"):
print(chunk, end="", flush=True)
yield chunk
else:
answer = ""
for chunk in response.choices[0].message.content:
answer += chunk
print(answer)
yield(answer)
if response.choices:
answer = response.choices[0].message.content
print(answer)
yield(answer)
except Exception as e:
print(type(e))
logger.error(e)

server/chat/search_engine_chat.py

@ -69,6 +69,7 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
):
if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@ -82,16 +83,17 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
search_engine_name: str,
top_k: int,
history: Optional[List[History]],
model_name: str = LLM_MODEL,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
streaming=True,
verbose=True,
callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL,
openai_proxy=llm_model_dict[LLM_MODEL].get("openai_proxy")
openai_api_key=llm_model_dict[model_name]["api_key"],
openai_api_base=llm_model_dict[model_name]["api_base_url"],
model_name=model_name,
openai_proxy=llm_model_dict[model_name].get("openai_proxy")
)
docs = lookup_search_engine(query, search_engine_name, top_k)
@ -129,5 +131,5 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
ensure_ascii=False)
await task
return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history),
return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history, model_name),
media_type="text/event-stream")

server/db/base.py

@ -3,10 +3,14 @@ from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from configs.model_config import SQLALCHEMY_DATABASE_URI
import json
engine = create_engine(SQLALCHEMY_DATABASE_URI)
engine = create_engine(
SQLALCHEMY_DATABASE_URI,
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

server/db/models/knowledge_base_model.py

@ -9,9 +9,9 @@ class KnowledgeBaseModel(Base):
"""
__tablename__ = 'knowledge_base'
id = Column(Integer, primary_key=True, autoincrement=True, comment='知识库ID')
kb_name = Column(String, comment='知识库名称')
vs_type = Column(String, comment='嵌入模型类型')
embed_model = Column(String, comment='嵌入模型名称')
kb_name = Column(String(50), comment='知识库名称')
vs_type = Column(String(50), comment='向量库类型')
embed_model = Column(String(50), comment='嵌入模型名称')
file_count = Column(Integer, default=0, comment='文件数量')
create_time = Column(DateTime, default=func.now(), comment='创建时间')

server/db/models/knowledge_file_model.py

@ -1,4 +1,4 @@
from sqlalchemy import Column, Integer, String, DateTime, func
from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, func
from server.db.base import Base
@ -9,13 +9,32 @@ class KnowledgeFileModel(Base):
"""
__tablename__ = 'knowledge_file'
id = Column(Integer, primary_key=True, autoincrement=True, comment='知识文件ID')
file_name = Column(String, comment='文件名')
file_ext = Column(String, comment='文件扩展名')
kb_name = Column(String, comment='所属知识库名称')
document_loader_name = Column(String, comment='文档加载器名称')
text_splitter_name = Column(String, comment='文本分割器名称')
file_name = Column(String(255), comment='文件名')
file_ext = Column(String(10), comment='文件扩展名')
kb_name = Column(String(50), comment='所属知识库名称')
document_loader_name = Column(String(50), comment='文档加载器名称')
text_splitter_name = Column(String(50), comment='文本分割器名称')
file_version = Column(Integer, default=1, comment='文件版本')
file_mtime = Column(Float, default=0.0, comment="文件修改时间")
file_size = Column(Integer, default=0, comment="文件大小")
custom_docs = Column(Boolean, default=False, comment="是否自定义docs")
docs_count = Column(Integer, default=0, comment="切分文档数量")
create_time = Column(DateTime, default=func.now(), comment='创建时间')
def __repr__(self):
return f"<KnowledgeFile(id='{self.id}', file_name='{self.file_name}', file_ext='{self.file_ext}', kb_name='{self.kb_name}', document_loader_name='{self.document_loader_name}', text_splitter_name='{self.text_splitter_name}', file_version='{self.file_version}', create_time='{self.create_time}')>"
class FileDocModel(Base):
"""
文件-向量库文档模型
"""
__tablename__ = 'file_doc'
id = Column(Integer, primary_key=True, autoincrement=True, comment='ID')
kb_name = Column(String(50), comment='知识库名称')
file_name = Column(String(255), comment='文件名称')
doc_id = Column(String(50), comment="向量库文档ID")
meta_data = Column(JSON, default={})
def __repr__(self):
return f"<FileDoc(id='{self.id}', kb_name='{self.kb_name}', file_name='{self.file_name}', doc_id='{self.doc_id}', metadata='{self.metadata}')>"

server/db/repository/knowledge_file_repository.py

@ -1,24 +1,101 @@
from server.db.models.knowledge_base_model import KnowledgeBaseModel
from server.db.models.knowledge_file_model import KnowledgeFileModel
from server.db.models.knowledge_file_model import KnowledgeFileModel, FileDocModel
from server.db.session import with_session
from server.knowledge_base.utils import KnowledgeFile
from typing import List, Dict
@with_session
def list_docs_from_db(session, kb_name):
def list_docs_from_db(session,
kb_name: str,
file_name: str = None,
metadata: Dict = {},
) -> List[Dict]:
'''
List all Documents for a given file in a given knowledge base.
Return form: [{"id": str, "metadata": dict}, ...]
'''
docs = session.query(FileDocModel).filter_by(kb_name=kb_name)
if file_name:
docs = docs.filter_by(file_name=file_name)
for k, v in metadata.items():
docs = docs.filter(FileDocModel.meta_data[k].as_string()==str(v))
return [{"id": x.doc_id, "metadata": x.metadata} for x in docs.all()]
@with_session
def delete_docs_from_db(session,
kb_name: str,
file_name: str = None,
) -> List[Dict]:
'''
Delete all Documents for a given file in a given knowledge base, and return the deleted Documents.
Return form: [{"id": str, "metadata": dict}, ...]
'''
docs = list_docs_from_db(kb_name=kb_name, file_name=file_name)
query = session.query(FileDocModel).filter_by(kb_name=kb_name)
if file_name:
query = query.filter_by(file_name=file_name)
query.delete()
session.commit()
return docs
@with_session
def add_docs_to_db(session,
kb_name: str,
file_name: str,
doc_infos: List[Dict]):
'''
Add the Document info for a given file in a given knowledge base to the database.
doc_infos form: [{"id": str, "metadata": dict}, ...]
'''
for d in doc_infos:
obj = FileDocModel(
kb_name=kb_name,
file_name=file_name,
doc_id=d["id"],
meta_data=d["metadata"],
)
session.add(obj)
return True
@with_session
def count_files_from_db(session, kb_name: str) -> int:
return session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).count()
@with_session
def list_files_from_db(session, kb_name):
files = session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).all()
docs = [f.file_name for f in files]
return docs
@with_session
def add_doc_to_db(session, kb_file: KnowledgeFile):
def add_file_to_db(session,
kb_file: KnowledgeFile,
docs_count: int = 0,
custom_docs: bool = False,
doc_infos: List[Dict] = [], # form: [{"id": str, "metadata": dict}, ...]
):
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
if kb:
# if the file already exists, bump its version number
existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename,
kb_name=kb_file.kb_name).first()
# if the file already exists, update its info and bump its version number
existing_file: KnowledgeFileModel = (session.query(KnowledgeFileModel)
.filter_by(file_name=kb_file.filename,
kb_name=kb_file.kb_name)
.first())
mtime = kb_file.get_mtime()
size = kb_file.get_size()
if existing_file:
existing_file.file_mtime = mtime
existing_file.file_size = size
existing_file.docs_count = docs_count
existing_file.custom_docs = custom_docs
existing_file.file_version += 1
# otherwise, add a new file
else:
@ -28,9 +105,14 @@ def add_doc_to_db(session, kb_file: KnowledgeFile):
kb_name=kb_file.kb_name,
document_loader_name=kb_file.document_loader_name,
text_splitter_name=kb_file.text_splitter_name or "SpacyTextSplitter",
file_mtime=mtime,
file_size=size,
docs_count=docs_count,
custom_docs=custom_docs,
)
kb.file_count += 1
session.add(new_file)
add_docs_to_db(kb_name=kb_file.kb_name, file_name=kb_file.filename, doc_infos=doc_infos)
return True
@ -40,6 +122,7 @@ def delete_file_from_db(session, kb_file: KnowledgeFile):
kb_name=kb_file.kb_name).first()
if existing_file:
session.delete(existing_file)
delete_docs_from_db(kb_name=kb_file.kb_name, file_name=kb_file.filename)
session.commit()
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
@ -52,7 +135,7 @@ def delete_file_from_db(session, kb_file: KnowledgeFile):
@with_session
def delete_files_from_db(session, knowledge_base_name: str):
session.query(KnowledgeFileModel).filter_by(kb_name=knowledge_base_name).delete()
session.query(FileDocModel).filter_by(kb_name=knowledge_base_name).delete()
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=knowledge_base_name).first()
if kb:
kb.file_count = 0
@ -62,7 +145,7 @@ def delete_files_from_db(session, knowledge_base_name: str):
@with_session
def doc_exists(session, kb_file: KnowledgeFile):
def file_exists_in_db(session, kb_file: KnowledgeFile):
existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename,
kb_name=kb_file.kb_name).first()
return True if existing_file else False
@ -82,6 +165,10 @@ def get_file_detail(session, kb_name: str, filename: str) -> dict:
"document_loader": file.document_loader_name,
"text_splitter": file.text_splitter_name,
"create_time": file.create_time,
"file_mtime": file.file_mtime,
"file_size": file.file_size,
"custom_docs": file.custom_docs,
"docs_count": file.docs_count,
}
else:
return {}

server/knowledge_base/kb_doc_api.py

@ -3,7 +3,7 @@ import urllib
from fastapi import File, Form, Body, Query, UploadFile
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile
from server.knowledge_base.utils import validate_kb_name, list_files_from_folder, KnowledgeFile
from fastapi.responses import StreamingResponse, FileResponse
import json
from server.knowledge_base.kb_service.base import KBServiceFactory
@ -29,7 +29,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
return data
async def list_docs(
async def list_files(
knowledge_base_name: str
) -> ListResponse:
if not validate_kb_name(knowledge_base_name):
@ -40,7 +40,7 @@ async def list_docs(
if kb is None:
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
else:
all_doc_names = kb.list_docs()
all_doc_names = kb.list_files()
return ListResponse(data=all_doc_names)
@ -183,14 +183,14 @@ async def recreate_vector_store(
Set allow_empty_kb to True so this also applies to empty knowledge bases that are not in info.db or contain no documents.
'''
async def output():
def output():
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
if not kb.exists() and not allow_empty_kb:
yield {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"}
else:
kb.create_kb()
kb.clear_vs()
docs = list_docs_from_folder(knowledge_base_name)
docs = list_files_from_folder(knowledge_base_name)
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, knowledge_base_name)

server/knowledge_base/kb_service/base.py

@ -1,25 +1,32 @@
import operator
from abc import ABC, abstractmethod
import os
import numpy as np
from langchain.embeddings.base import Embeddings
from langchain.docstore.document import Document
from sklearn.preprocessing import normalize
from server.db.repository.knowledge_base_repository import (
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
load_kb_from_db, get_kb_detail,
)
from server.db.repository.knowledge_file_repository import (
add_doc_to_db, delete_file_from_db, delete_files_from_db, doc_exists,
list_docs_from_db, get_file_detail, delete_file_from_db
add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db,
count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
list_docs_from_db,
)
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
EMBEDDING_DEVICE, EMBEDDING_MODEL)
EMBEDDING_MODEL)
from server.knowledge_base.utils import (
get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
list_kbs_from_folder, list_docs_from_folder,
list_kbs_from_folder, list_files_from_folder,
)
from server.utils import embedding_device
from typing import List, Union, Dict
from typing import List, Union, Dict, Optional
class SupportedVSType:
@ -41,7 +48,7 @@ class KBService(ABC):
self.doc_path = get_doc_path(self.kb_name)
self.do_init()
def _load_embeddings(self, embed_device: str = EMBEDDING_DEVICE) -> Embeddings:
def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
return load_embeddings(self.embed_model, embed_device)
def create_kb(self):
@ -62,7 +69,6 @@ class KBService(ABC):
status = delete_files_from_db(self.kb_name)
return status
def drop_kb(self):
"""
Drop the knowledge base.
@ -71,16 +77,24 @@ class KBService(ABC):
status = delete_kb_from_db(self.kb_name)
return status
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
"""
Add a file to the knowledge base.
If docs is given, the file text is not vectorized from scratch, and the corresponding database entry is marked custom_docs=True.
"""
docs = kb_file.file2text()
if docs:
custom_docs = True
else:
docs = kb_file.file2text()
custom_docs = False
if docs:
self.delete_doc(kb_file)
embeddings = self._load_embeddings()
self.do_add_doc(docs, embeddings, **kwargs)
status = add_doc_to_db(kb_file)
doc_infos = self.do_add_doc(docs, **kwargs)
status = add_file_to_db(kb_file,
custom_docs=custom_docs,
docs_count=len(docs),
doc_infos=doc_infos)
else:
status = False
return status
@ -95,20 +109,24 @@ class KBService(ABC):
os.remove(kb_file.filepath)
return status
def update_doc(self, kb_file: KnowledgeFile, **kwargs):
def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
"""
Update the vector store from the file in content.
If docs is given, use the custom docs and mark the corresponding database entry custom_docs=True.
"""
if os.path.exists(kb_file.filepath):
self.delete_doc(kb_file, **kwargs)
return self.add_doc(kb_file, **kwargs)
return self.add_doc(kb_file, docs=docs, **kwargs)
def exist_doc(self, file_name: str):
return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
return file_exists_in_db(KnowledgeFile(knowledge_base_name=self.kb_name,
filename=file_name))
def list_docs(self):
return list_docs_from_db(self.kb_name)
def list_files(self):
return list_files_from_db(self.kb_name)
def count_files(self):
return count_files_from_db(self.kb_name)
def search_docs(self,
query: str,
@ -119,6 +137,18 @@ class KBService(ABC):
docs = self.do_search(query, top_k, score_threshold, embeddings)
return docs
# TODO: milvus/pg need to implement this method
def get_doc_by_id(self, id: str) -> Optional[Document]:
return None
def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[Document]:
'''
Retrieve Documents by file_name or metadata.
'''
doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
docs = [self.get_doc_by_id(x["id"]) for x in doc_infos]
return docs
@abstractmethod
def do_create_kb(self):
"""
@ -168,8 +198,7 @@ class KBService(ABC):
@abstractmethod
def do_add_doc(self,
docs: List[Document],
embeddings: Embeddings,
):
) -> List[Dict]:
"""
Add documents to the knowledge base (subclasses implement their own logic).
"""
@ -208,8 +237,9 @@ class KBServiceFactory:
return PGKBService(kb_name, embed_model=embed_model)
elif SupportedVSType.MILVUS == vector_store_type:
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
return MilvusKBService(kb_name, embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
return MilvusKBService(kb_name,
embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
return DefaultKBService(kb_name)
@ -217,7 +247,7 @@ class KBServiceFactory:
def get_service_by_name(kb_name: str
) -> KBService:
_, vs_type, embed_model = load_kb_from_db(kb_name)
if vs_type is None and os.path.isdir(get_kb_path(kb_name)): # faiss knowledge base not in db
if vs_type is None and os.path.isdir(get_kb_path(kb_name)): # faiss knowledge base not in db
vs_type = "faiss"
return KBServiceFactory.get_service(kb_name, vs_type, embed_model)
@ -256,29 +286,30 @@ def get_kb_details() -> List[Dict]:
for i, v in enumerate(result.values()):
v['No'] = i + 1
data.append(v)
return data
def get_kb_doc_details(kb_name: str) -> List[Dict]:
def get_kb_file_details(kb_name: str) -> List[Dict]:
kb = KBServiceFactory.get_service_by_name(kb_name)
docs_in_folder = list_docs_from_folder(kb_name)
docs_in_db = kb.list_docs()
files_in_folder = list_files_from_folder(kb_name)
files_in_db = kb.list_files()
result = {}
for doc in docs_in_folder:
for doc in files_in_folder:
result[doc] = {
"kb_name": kb_name,
"file_name": doc,
"file_ext": os.path.splitext(doc)[-1],
"file_version": 0,
"document_loader": "",
"docs_count": 0,
"text_splitter": "",
"create_time": None,
"in_folder": True,
"in_db": False,
}
for doc in docs_in_db:
for doc in files_in_db:
doc_detail = get_file_detail(kb_name, doc)
if doc_detail:
doc_detail["in_db"] = True
@ -292,5 +323,39 @@ def get_kb_doc_details(kb_name: str) -> List[Dict]:
for i, v in enumerate(result.values()):
v['No'] = i + 1
data.append(v)
return data
class EmbeddingsFunAdapter(Embeddings):
    def __init__(self, embeddings: Embeddings):
        self.embeddings = embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return normalize(self.embeddings.embed_documents(texts))

    def embed_query(self, text: str) -> List[float]:
        query_embed = self.embeddings.embed_query(text)
        query_embed_2d = np.reshape(query_embed, (1, -1))  # reshape the 1-D array to 2-D
        normalized_query_embed = normalize(query_embed_2d)
        return normalized_query_embed[0].tolist()  # convert back to a 1-D list and return

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        # await the coroutine first; sklearn's normalize is synchronous
        return normalize(await self.embeddings.aembed_documents(texts))

    async def aembed_query(self, text: str) -> List[float]:
        query_embed = await self.embeddings.aembed_query(text)
        query_embed_2d = np.reshape(query_embed, (1, -1))
        return normalize(query_embed_2d)[0].tolist()

def score_threshold_process(score_threshold, k, docs):
    if score_threshold is not None:
        cmp = operator.le  # keep docs whose similarity score is <= the threshold
        docs = [
            (doc, similarity)
            for doc, similarity in docs
            if cmp(similarity, score_threshold)
        ]
    return docs[:k]

server/knowledge_base/kb_service/default_kb_service.py

@ -13,7 +13,7 @@ class DefaultKBService(KBService):
def do_drop_kb(self):
pass
def do_add_doc(self, docs: List[Document], embeddings: Embeddings):
def do_add_doc(self, docs: List[Document]):
pass
def do_clear_vs(self):

server/knowledge_base/kb_service/faiss_kb_service.py

@ -5,7 +5,6 @@ from configs.model_config import (
KB_ROOT_PATH,
CACHED_VS_NUM,
EMBEDDING_MODEL,
EMBEDDING_DEVICE,
SCORE_THRESHOLD
)
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
@ -13,40 +12,22 @@ from functools import lru_cache
from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
from langchain.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings,HuggingFaceBgeEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from typing import List
from typing import List, Dict, Optional
from langchain.docstore.document import Document
from server.utils import torch_gc
# make HuggingFaceEmbeddings hashable
def _embeddings_hash(self):
if isinstance(self, HuggingFaceEmbeddings):
return hash(self.model_name)
elif isinstance(self, HuggingFaceBgeEmbeddings):
return hash(self.model_name)
elif isinstance(self, OpenAIEmbeddings):
return hash(self.model)
HuggingFaceEmbeddings.__hash__ = _embeddings_hash
OpenAIEmbeddings.__hash__ = _embeddings_hash
HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash
_VECTOR_STORE_TICKS = {}
from server.utils import torch_gc, embedding_device
_VECTOR_STORE_TICKS = {}
@lru_cache(CACHED_VS_NUM)
def load_vector_store(
def load_faiss_vector_store(
knowledge_base_name: str,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = EMBEDDING_DEVICE,
embed_device: str = embedding_device(),
embeddings: Embeddings = None,
tick: int = 0, # tick is changed by upload_doc etc., causing the cache to be refreshed.
):
) -> FAISS:
print(f"loading vector store in '{knowledge_base_name}'.")
vs_path = get_vs_path(knowledge_base_name)
if embeddings is None:
@ -86,22 +67,39 @@ class FaissKBService(KBService):
def vs_type(self) -> str:
return SupportedVSType.FAISS
@staticmethod
def get_vs_path(knowledge_base_name: str):
return os.path.join(FaissKBService.get_kb_path(knowledge_base_name), "vector_store")
def get_vs_path(self):
return os.path.join(self.get_kb_path(), "vector_store")
@staticmethod
def get_kb_path(knowledge_base_name: str):
return os.path.join(KB_ROOT_PATH, knowledge_base_name)
def get_kb_path(self):
return os.path.join(KB_ROOT_PATH, self.kb_name)
def load_vector_store(self) -> FAISS:
return load_faiss_vector_store(
knowledge_base_name=self.kb_name,
embed_model=self.embed_model,
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0),
)
def save_vector_store(self, vector_store: FAISS = None):
vector_store = vector_store or self.load_vector_store()
vector_store.save_local(self.vs_path)
return vector_store
def refresh_vs_cache(self):
refresh_vs_cache(self.kb_name)
def get_doc_by_id(self, id: str) -> Optional[Document]:
vector_store = self.load_vector_store()
return vector_store.docstore._dict.get(id)
def do_init(self):
self.kb_path = FaissKBService.get_kb_path(self.kb_name)
self.vs_path = FaissKBService.get_vs_path(self.kb_name)
self.kb_path = self.get_kb_path()
self.vs_path = self.get_vs_path()
def do_create_kb(self):
if not os.path.exists(self.vs_path):
os.makedirs(self.vs_path)
load_vector_store(self.kb_name)
self.load_vector_store()
def do_drop_kb(self):
self.clear_vs()
@ -113,33 +111,27 @@ class FaissKBService(KBService):
score_threshold: float = SCORE_THRESHOLD,
embeddings: Embeddings = None,
) -> List[Document]:
search_index = load_vector_store(self.kb_name,
embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name))
search_index = self.load_vector_store()
docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
return docs
def do_add_doc(self,
docs: List[Document],
embeddings: Embeddings,
**kwargs,
):
vector_store = load_vector_store(self.kb_name,
embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
vector_store.add_documents(docs)
) -> List[Dict]:
vector_store = self.load_vector_store()
ids = vector_store.add_documents(docs)
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
torch_gc()
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
self.refresh_vs_cache()
return doc_infos
def do_delete_doc(self,
kb_file: KnowledgeFile,
**kwargs):
embeddings = self._load_embeddings()
vector_store = load_vector_store(self.kb_name,
embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
vector_store = self.load_vector_store()
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
if len(ids) == 0:
@ -148,14 +140,14 @@ class FaissKBService(KBService):
vector_store.delete(ids)
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
self.refresh_vs_cache()
return True
return vector_store
def do_clear_vs(self):
shutil.rmtree(self.vs_path)
os.makedirs(self.vs_path)
refresh_vs_cache(self.kb_name)
self.refresh_vs_cache()
def exist_doc(self, file_name: str):
if super().exist_doc(file_name):
@ -166,3 +158,11 @@ class FaissKBService(KBService):
return "in_folder"
else:
return False
if __name__ == '__main__':
faissService = FaissKBService("test")
faissService.add_doc(KnowledgeFile("README.md", "test"))
faissService.delete_doc(KnowledgeFile("README.md", "test"))
faissService.do_drop_kb()
print(faissService.search_docs("如何启动api服务"))

server/knowledge_base/kb_service/milvus_kb_service.py

@ -1,12 +1,16 @@
from typing import List
from typing import List, Dict, Optional
import numpy as np
from faiss import normalize_L2
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import Milvus
from sklearn.preprocessing import normalize
from configs.model_config import SCORE_THRESHOLD, kbs_config
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
score_threshold_process
from server.knowledge_base.utils import KnowledgeFile
@ -18,6 +22,14 @@ class MilvusKBService(KBService):
from pymilvus import Collection
return Collection(milvus_name)
def get_doc_by_id(self, id: str) -> Optional[Document]:
if self.milvus.col:
data_list = self.milvus.col.query(expr=f'pk == {id}', output_fields=["*"])
if len(data_list) > 0:
data = data_list[0]
text = data.pop("text")
return Document(page_content=text, metadata=data)
@staticmethod
def search(milvus_name, content, limit=3):
search_params = {
@ -36,38 +48,31 @@ class MilvusKBService(KBService):
def _load_milvus(self, embeddings: Embeddings = None):
if embeddings is None:
embeddings = self._load_embeddings()
self.milvus = Milvus(embedding_function=embeddings,
self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(embeddings),
collection_name=self.kb_name, connection_args=kbs_config.get("milvus"))
def do_init(self):
self._load_milvus()
def do_drop_kb(self):
self.milvus.col.drop()
if self.milvus.col:
self.milvus.col.drop()
def do_search(self, query: str, top_k: int,score_threshold: float, embeddings: Embeddings):
# todo: support score threshold
self._load_milvus(embeddings=embeddings)
return self.milvus.similarity_search_with_score(query, top_k)
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings))
return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
Add a file to the knowledge base.
"""
docs = kb_file.file2text()
self.milvus.add_documents(docs)
from server.db.repository.knowledge_file_repository import add_doc_to_db
status = add_doc_to_db(kb_file)
return status
def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs):
pass
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
ids = self.milvus.add_documents(docs)
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
return doc_infos
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
filepath = kb_file.filepath.replace('\\', '\\\\')
delete_list = [item.get("pk") for item in
self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
self.milvus.col.delete(expr=f'pk in {delete_list}')
if self.milvus.col:
filepath = kb_file.filepath.replace('\\', '\\\\')
delete_list = [item.get("pk") for item in
self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
self.milvus.col.delete(expr=f'pk in {delete_list}')
def do_clear_vs(self):
if self.milvus.col:
@ -80,7 +85,9 @@ if __name__ == '__main__':
Base.metadata.create_all(bind=engine)
milvusService = MilvusKBService("test")
milvusService.add_doc(KnowledgeFile("README.md", "test"))
milvusService.delete_doc(KnowledgeFile("README.md", "test"))
milvusService.do_drop_kb()
print(milvusService.search_docs("测试"))
# milvusService.add_doc(KnowledgeFile("README.md", "test"))
print(milvusService.get_doc_by_id("444022434274215486"))
# milvusService.delete_doc(KnowledgeFile("README.md", "test"))
# milvusService.do_drop_kb()
# print(milvusService.search_docs("如何启动api服务"))

server/knowledge_base/kb_service/pg_kb_service.py

@ -1,25 +1,38 @@
from typing import List
import json
from typing import List, Dict, Optional
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import PGVector
from langchain.vectorstores.pgvector import DistanceStrategy
from sqlalchemy import text
from configs.model_config import EMBEDDING_DEVICE, kbs_config
from server.knowledge_base.kb_service.base import SupportedVSType, KBService
from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
score_threshold_process
from server.knowledge_base.utils import load_embeddings, KnowledgeFile
from server.utils import embedding_device as get_embedding_device
class PGKBService(KBService):
pg_vector: PGVector
def _load_pg_vector(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None):
def _load_pg_vector(self, embedding_device: str = get_embedding_device(), embeddings: Embeddings = None):
_embeddings = embeddings
if _embeddings is None:
_embeddings = load_embeddings(self.embed_model, embedding_device)
self.pg_vector = PGVector(embedding_function=_embeddings,
self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(_embeddings),
collection_name=self.kb_name,
distance_strategy=DistanceStrategy.EUCLIDEAN,
connection_string=kbs_config.get("pg").get("connection_uri"))
def get_doc_by_id(self, id: str) -> Optional[Document]:
with self.pg_vector.connect() as connect:
stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id=:id")
results = [Document(page_content=row[0], metadata=row[1]) for row in
connect.execute(stmt, parameters={'id': id}).fetchall()]
if len(results) > 0:
return results[0]
def do_init(self):
self._load_pg_vector()
@ -44,22 +57,14 @@ class PGKBService(KBService):
connect.commit()
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
# todo: support score threshold
self._load_pg_vector(embeddings=embeddings)
return self.pg_vector.similarity_search_with_score(query, top_k)
return score_threshold_process(score_threshold, top_k,
self.pg_vector.similarity_search_with_score(query, top_k))
def add_doc(self, kb_file: KnowledgeFile, **kwargs):
"""
Add a file to the knowledge base.
"""
docs = kb_file.file2text()
self.pg_vector.add_documents(docs)
from server.db.repository.knowledge_file_repository import add_doc_to_db
status = add_doc_to_db(kb_file)
return status
def do_add_doc(self, docs: List[Document], embeddings: Embeddings, **kwargs):
pass
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
ids = self.pg_vector.add_documents(docs)
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
return doc_infos
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
with self.pg_vector.connect() as connect:
@ -77,10 +82,11 @@ class PGKBService(KBService):
if __name__ == '__main__':
from server.db.base import Base, engine
Base.metadata.create_all(bind=engine)
# Base.metadata.create_all(bind=engine)
pGKBService = PGKBService("test")
pGKBService.create_kb()
pGKBService.add_doc(KnowledgeFile("README.md", "test"))
pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
pGKBService.drop_kb()
print(pGKBService.search_docs("测试"))
# pGKBService.create_kb()
# pGKBService.add_doc(KnowledgeFile("README.md", "test"))
# pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
# pGKBService.drop_kb()
print(pGKBService.get_doc_by_id("f1e51390-3029-4a19-90dc-7118aaa25772"))
# print(pGKBService.search_docs("如何启动api服务"))

server/knowledge_base/migrate.py

@ -1,10 +1,17 @@
from configs.model_config import EMBEDDING_MODEL, DEFAULT_VS_TYPE
from server.knowledge_base.utils import get_file_path, list_kbs_from_folder, list_docs_from_folder, KnowledgeFile
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_file_repository import add_doc_to_db
from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
list_files_from_folder, run_in_thread_pool,
files2docs_in_thread,
KnowledgeFile,)
from server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType
from server.db.repository.knowledge_file_repository import add_file_to_db
from server.db.base import Base, engine
import os
from typing import Literal, Callable, Any
from concurrent.futures import ThreadPoolExecutor
from typing import Literal, Any, List
pool = ThreadPoolExecutor(os.cpu_count())
def create_tables():
@ -16,13 +23,22 @@ def reset_tables():
create_tables()
def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]:
kb_files = []
for file in files:
try:
kb_file = KnowledgeFile(filename=file, knowledge_base_name=kb_name)
kb_files.append(kb_file)
except Exception as e:
print(f"{e},已跳过")
return kb_files
def folder2db(
kb_name: str,
mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"],
vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
callback_before: Callable = None,
callback_after: Callable = None,
):
'''
Use existing files in the local folder to populate the database and/or the vector store.
@ -36,70 +52,62 @@ def folder2db(
kb.create_kb()
if mode == "recreate_vs":
files_count = kb.count_files()
print(f"知识库 {kb_name} 中共有 {files_count} 个文档。\n即将清除向量库。")
kb.clear_vs()
docs = list_docs_from_folder(kb_name)
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
print(e)
files_count = kb.count_files()
print(f"清理后,知识库 {kb_name} 中共有 {files_count} 个文档。")
kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name))
for success, result in files2docs_in_thread(kb_files, pool=pool):
if success:
_, filename, docs = result
print(f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档")
kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
else:
print(result)
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache()
elif mode == "fill_info_only":
docs = list_docs_from_folder(kb_name)
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
add_doc_to_db(kb_file)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
print(e)
files = list_files_from_folder(kb_name)
kb_files = file_to_kbfile(kb_name, files)
for kb_file in kb_files:
add_file_to_db(kb_file)
print(f"已将 {kb_name}/{kb_file.filename} 添加到数据库")
elif mode == "update_in_db":
docs = kb.list_docs()
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
print(e)
files = kb.list_files()
kb_files = file_to_kbfile(kb_name, files)
for kb_file in kb_files:
kb.update_doc(kb_file, not_refresh_vs_cache=True)
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache()
elif mode == "increament":
db_docs = kb.list_docs()
folder_docs = list_docs_from_folder(kb_name)
docs = list(set(folder_docs) - set(db_docs))
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, kb_name)
if callable(callback_before):
callback_before(kb_file, i, docs)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
if callable(callback_after):
callback_after(kb_file, i, docs)
except Exception as e:
print(e)
db_files = kb.list_files()
folder_files = list_files_from_folder(kb_name)
files = list(set(folder_files) - set(db_files))
kb_files = file_to_kbfile(kb_name, files)
for success, result in files2docs_in_thread(kb_files, pool=pool):
if success:
_, filename, docs = result
print(f"正在将 {kb_name}/{filename} 添加到向量库")
kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
else:
print(result)
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache()
else:
raise ValueError(f"unsupported migrate mode: {mode}")
print(f"unsupported migrate mode: {mode}")
def recreate_all_vs(
@ -114,30 +122,34 @@ def recreate_all_vs(
folder2db(kb_name, "recreate_vs", vs_type, embed_mode, **kwargs)
def prune_db_docs(kb_name: str):
def prune_db_files(kb_name: str):
'''
delete docs in the database that no longer exist in the local folder.
it is used to delete database docs after the user deleted some doc files in the file browser
delete files in the database that no longer exist in the local folder.
it is used to delete database files after the user deleted some doc files in the file browser
'''
kb = KBServiceFactory.get_service_by_name(kb_name)
if kb.exists():
docs_in_db = kb.list_docs()
docs_in_folder = list_docs_from_folder(kb_name)
docs = list(set(docs_in_db) - set(docs_in_folder))
for doc in docs:
kb.delete_doc(KnowledgeFile(doc, kb_name))
return docs
files_in_db = kb.list_files()
files_in_folder = list_files_from_folder(kb_name)
files = list(set(files_in_db) - set(files_in_folder))
kb_files = file_to_kbfile(kb_name, files)
for kb_file in kb_files:
kb.delete_doc(kb_file, not_refresh_vs_cache=True)
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache()
return kb_files
def prune_folder_docs(kb_name: str):
def prune_folder_files(kb_name: str):
'''
delete doc files in the local folder that no longer exist in the database.
it is used to free local disk space by deleting unused doc files.
'''
kb = KBServiceFactory.get_service_by_name(kb_name)
if kb.exists():
docs_in_db = kb.list_docs()
docs_in_folder = list_docs_from_folder(kb_name)
docs = list(set(docs_in_folder) - set(docs_in_db))
for doc in docs:
os.remove(get_file_path(kb_name, doc))
return docs
files_in_db = kb.list_files()
files_in_folder = list_files_from_folder(kb_name)
files = list(set(files_in_folder) - set(files_in_db))
for file in files:
os.remove(get_file_path(kb_name, file))
return files
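The two prune helpers are complementary; a short sketch, assuming files were added or removed through the file browser ("samples" is a hypothetical knowledge base name):

```python
# Sketch: drop database records whose source files were deleted, then remove
# orphaned files that have no database record.
removed_records = prune_db_files("samples")
removed_files = prune_folder_files("samples")
print(f"pruned {len(removed_records)} records, {len(removed_files)} files")
```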

View File

@ -12,6 +12,26 @@ from configs.model_config import (
from functools import lru_cache
import importlib
from text_splitter import zh_title_enhance
import langchain.document_loaders
from langchain.docstore.document import Document
from pathlib import Path
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
# make HuggingFaceEmbeddings hashable
def _embeddings_hash(self):
if isinstance(self, HuggingFaceEmbeddings):
return hash(self.model_name)
elif isinstance(self, HuggingFaceBgeEmbeddings):
return hash(self.model_name)
elif isinstance(self, OpenAIEmbeddings):
return hash(self.model)
HuggingFaceEmbeddings.__hash__ = _embeddings_hash
OpenAIEmbeddings.__hash__ = _embeddings_hash
HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash
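These `__hash__` patches are needed because the embeddings classes are pydantic models that are not hashable by default; hashing on the model name lets instances pass through `lru_cache` (as in `load_embeddings` below) or serve as dict keys. A minimal sketch of what the patch enables (`embed_query_cached` is a hypothetical helper):

```python
from functools import lru_cache

@lru_cache(8)
def embed_query_cached(embeddings, text: str) -> tuple:
    # with __hash__ patched, identical (model, text) pairs hit the cache
    # instead of re-encoding the query
    return tuple(embeddings.embed_query(text))
```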
def validate_kb_name(knowledge_base_id: str) -> bool:
@ -20,27 +40,34 @@ def validate_kb_name(knowledge_base_id: str) -> bool:
return False
return True
def get_kb_path(knowledge_base_name: str):
return os.path.join(KB_ROOT_PATH, knowledge_base_name)
def get_doc_path(knowledge_base_name: str):
return os.path.join(get_kb_path(knowledge_base_name), "content")
def get_vs_path(knowledge_base_name: str):
return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
def get_file_path(knowledge_base_name: str, doc_name: str):
return os.path.join(get_doc_path(knowledge_base_name), doc_name)
def list_kbs_from_folder():
return [f for f in os.listdir(KB_ROOT_PATH)
if os.path.isdir(os.path.join(KB_ROOT_PATH, f))]
def list_docs_from_folder(kb_name: str):
def list_files_from_folder(kb_name: str):
doc_path = get_doc_path(kb_name)
return [file for file in os.listdir(doc_path)
if os.path.isfile(os.path.join(doc_path, file))]
@lru_cache(1)
def load_embeddings(model: str, device: str):
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
@ -56,16 +83,90 @@ def load_embeddings(model: str, device: str):
return embeddings
LOADER_DICT = {"UnstructuredFileLoader": ['.eml', '.html', '.json', '.md', '.msg', '.rst',
'.rtf', '.txt', '.xml',
'.doc', '.docx', '.epub', '.odt', '.pdf',
'.ppt', '.pptx', '.tsv'], # '.pdf', '.xlsx', '.csv'
LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
"UnstructuredMarkdownLoader": ['.md'],
"CustomJSONLoader": [".json"],
"CSVLoader": [".csv"],
"PyPDFLoader": [".pdf"],
"RapidOCRPDFLoader": [".pdf"],
"RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
"UnstructuredFileLoader": ['.eml', '.msg', '.rst',
'.rtf', '.txt', '.xml',
'.doc', '.docx', '.epub', '.odt',
'.ppt', '.pptx', '.tsv'], # '.xlsx'
}
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
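Note that `.pdf` is listed under both `PyPDFLoader` and `RapidOCRPDFLoader`; since `get_LoaderClass` below returns the first match, dict insertion order decides which loader wins. A small sketch of that resolution rule (`loader_name_for` is hypothetical):

```python
# Sketch: mirror get_LoaderClass's first-match rule for a file extension.
def loader_name_for(ext: str) -> str:
    for loader_name, extensions in LOADER_DICT.items():
        if ext.lower() in extensions:
            return loader_name
    raise ValueError(f"unsupported extension: {ext}")

assert loader_name_for(".pdf") == "PyPDFLoader"  # first match in LOADER_DICT
```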
class CustomJSONLoader(langchain.document_loaders.JSONLoader):
'''
A replacement for langchain's JSONLoader, which depends on jq and is inconvenient to use on Windows.
'''
def __init__(
self,
file_path: Union[str, Path],
content_key: Optional[str] = None,
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
text_content: bool = True,
json_lines: bool = False,
):
"""Initialize the JSONLoader.
Args:
file_path (Union[str, Path]): The path to the JSON or JSON Lines file.
content_key (str): The key used to extract the content from the JSON if
the result is a list of objects (dict).
metadata_func (Callable[Dict, Dict]): A function that takes in the JSON
object extracted by the jq_schema and the default metadata and returns
a dict of the updated metadata.
text_content (bool): Boolean flag to indicate whether the content is in
string format, default to True.
json_lines (bool): Boolean flag to indicate whether the input is in
JSON Lines format.
"""
self.file_path = Path(file_path).resolve()
self._content_key = content_key
self._metadata_func = metadata_func
self._text_content = text_content
self._json_lines = json_lines
# TODO: langchain's JSONLoader.load has an encoding bug that raises a gbk encoding error on Windows.
# This is a workaround for langchain==0.0.266. I have made a PR (#9785) to langchain; it should be removed after langchain is upgraded.
def load(self) -> List[Document]:
"""Load and return documents from the JSON file."""
docs: List[Document] = []
if self._json_lines:
with self.file_path.open(encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
self._parse(line, docs)
else:
self._parse(self.file_path.read_text(encoding="utf-8"), docs)
return docs
def _parse(self, content: str, docs: List[Document]) -> None:
"""Convert given content to documents."""
data = json.loads(content)
# Perform some validation
# This is not a perfect validation, but it should catch most cases
# and prevent the user from getting a cryptic error later on.
if self._content_key is not None:
self._validate_content_key(data)
for i, sample in enumerate(data, len(docs) + 1):
metadata = dict(
source=str(self.file_path),
seq_num=i,
)
text = self._get_text(sample=sample, metadata=metadata)
docs.append(Document(page_content=text, metadata=metadata))
langchain.document_loaders.CustomJSONLoader = CustomJSONLoader
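A minimal usage sketch for `CustomJSONLoader` (the file path is hypothetical; `text_content=False` allows non-string values):

```python
# Sketch: load a UTF-8 JSON array into Documents without a jq dependency.
loader = CustomJSONLoader("samples/data.json", text_content=False)
for doc in loader.load():
    print(doc.metadata["seq_num"], doc.page_content[:40])
```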
def get_LoaderClass(file_extension):
for LoaderClass, extensions in LOADER_DICT.items():
if file_extension in extensions:
@ -90,10 +191,16 @@ class KnowledgeFile:
# TODO: pick the text_splitter based on the file format
self.text_splitter_name = None
def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE):
print(self.document_loader_name)
def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE, refresh: bool = False):
if self.docs is not None and not refresh:
return self.docs
print(f"{self.document_loader_name} used for {self.filepath}")
try:
document_loaders_module = importlib.import_module('langchain.document_loaders')
if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
document_loaders_module = importlib.import_module('document_loaders')
else:
document_loaders_module = importlib.import_module('langchain.document_loaders')
DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
except Exception as e:
print(e)
@ -101,6 +208,16 @@ class KnowledgeFile:
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
if self.document_loader_name == "UnstructuredFileLoader":
loader = DocumentLoader(self.filepath, autodetect_encoding=True)
elif self.document_loader_name == "CSVLoader":
loader = DocumentLoader(self.filepath, encoding="utf-8")
elif self.document_loader_name == "JSONLoader":
loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False)
elif self.document_loader_name == "CustomJSONLoader":
loader = DocumentLoader(self.filepath, text_content=False)
elif self.document_loader_name == "UnstructuredMarkdownLoader":
loader = DocumentLoader(self.filepath, mode="elements")
elif self.document_loader_name == "UnstructuredHTMLLoader":
loader = DocumentLoader(self.filepath, mode="elements")
else:
loader = DocumentLoader(self.filepath)
@ -136,4 +253,63 @@ class KnowledgeFile:
print(docs[0])
if using_zh_title_enhance:
docs = zh_title_enhance(docs)
self.docs = docs
return docs
def get_mtime(self):
return os.path.getmtime(self.filepath)
def get_size(self):
return os.path.getsize(self.filepath)
def run_in_thread_pool(
func: Callable,
params: List[Dict] = [],
pool: ThreadPoolExecutor = None,
) -> Generator:
'''
Run tasks in batch in a thread pool, yielding the results as a generator.
Make sure all operations inside a task are thread-safe, and pass all task-function arguments as keywords.
'''
tasks = []
if pool is None:
pool = ThreadPoolExecutor()
for kwargs in params:
thread = pool.submit(func, **kwargs)
tasks.append(thread)
for obj in as_completed(tasks):
yield obj.result()
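A self-contained sketch of the calling convention (task functions take keyword arguments only; results arrive in completion order):

```python
# Sketch: run a trivial task concurrently through run_in_thread_pool.
def square(*, x: int) -> int:
    return x * x

for result in run_in_thread_pool(square, params=[{"x": i} for i in range(4)]):
    print(result)  # completion order, not submission order
```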
def files2docs_in_thread(
files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
pool: ThreadPoolExecutor = None,
) -> Generator:
'''
Convert files into langchain Documents in batch using multiple threads.
The generator yields tuples: (True, (kb_name, file_name, docs)) on success, or (False, error) on failure.
'''
def task(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Union[Tuple[str, str, List[Document]], Exception]]:
try:
return True, (file.kb_name, file.filename, file.file2text(**kwargs))
except Exception as e:
return False, e
kwargs_list = []
for i, file in enumerate(files):
kwargs = {}
if isinstance(file, tuple) and len(file) >= 2:
files[i] = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
elif isinstance(file, dict):
filename = file.pop("filename")
kb_name = file.pop("kb_name")
files[i] = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
kwargs = file
kwargs["file"] = file
kwargs_list.append(kwargs)
for result in run_in_thread_pool(func=task, params=kwargs_list, pool=pool):
yield result
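And a sketch of consuming `files2docs_in_thread`, assuming a knowledge base folder named `samples` exists:

```python
# Sketch: convert every file of a knowledge base into Documents in parallel,
# passing files as (filename, kb_name) tuples.
files = [(name, "samples") for name in list_files_from_folder("samples")]
for success, result in files2docs_in_thread(files):
    if success:
        kb_name, filename, docs = result
        print(f"{kb_name}/{filename}: {len(docs)} docs")
    else:
        print(f"conversion failed: {result}")
```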

View File

@ -4,8 +4,8 @@ import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
from server.utils import MakeFastAPIOffline
from configs.model_config import llm_model_dict, LLM_MODEL, LOG_PATH, logger
from server.utils import MakeFastAPIOffline, set_httpx_timeout, llm_device
host_ip = "0.0.0.0"
@ -15,13 +15,6 @@ openai_api_port = 8888
base_url = "http://127.0.0.1:{}"
def set_httpx_timeout(timeout=60.0):
import httpx
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
def create_controller_app(
dispatch_method="shortest_queue",
):
@ -41,7 +34,7 @@ def create_model_worker_app(
worker_address=base_url.format(model_worker_port),
controller_address=base_url.format(controller_port),
model_path=llm_model_dict[LLM_MODEL].get("local_model_path"),
device=LLM_DEVICE,
device=llm_device(),
gpus=None,
max_gpu_memory="20GiB",
load_8bit=False,

View File

@ -0,0 +1 @@
from .zhipu import ChatGLMWorker

View File

@ -0,0 +1,71 @@
from configs.model_config import LOG_PATH
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.model_worker import BaseModelWorker
import uuid
import json
import sys
from pydantic import BaseModel
import fastchat
import threading
from typing import Dict, List
# restore the stdout/stderr overridden by fastchat
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
class ApiModelOutMsg(BaseModel):
error_code: int = 0
text: str
class ApiModelWorker(BaseModelWorker):
BASE_URL: str
SUPPORT_MODELS: List
def __init__(
self,
model_names: List[str],
controller_addr: str,
worker_addr: str,
context_len: int = 2048,
**kwargs,
):
kwargs.setdefault("worker_id", uuid.uuid4().hex[:8])
kwargs.setdefault("model_path", "")
kwargs.setdefault("limit_worker_concurrency", 5)
super().__init__(model_names=model_names,
controller_addr=controller_addr,
worker_addr=worker_addr,
**kwargs)
self.context_len = context_len
self.init_heart_beat()
def count_token(self, params):
# TODO: needs to be completed
print("count token")
print(params)
prompt = params["prompt"]
return {"count": len(str(prompt)), "error_code": 0}
def generate_stream_gate(self, params):
self.call_ct += 1
def generate_gate(self, params):
for x in self.generate_stream_gate(params):
pass
return json.loads(x[:-1].decode())
def get_embeddings(self, params):
print("embedding")
print(params)
# workaround to make program exit with Ctrl+c
# it should be deleted after the PR is merged by fastchat
def init_heart_beat(self):
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
)
self.heart_beat_thread.start()

View File

@ -0,0 +1,75 @@
import zhipuai
from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv
import sys
import json
from typing import List, Literal
class ChatGLMWorker(ApiModelWorker):
BASE_URL = "https://open.bigmodel.cn/api/paas/v3/model-api"
SUPPORT_MODELS = ["chatglm_pro", "chatglm_std", "chatglm_lite"]
def __init__(
self,
*,
model_names: List[str] = ["chatglm-api"],
version: Literal["chatglm_pro", "chatglm_std", "chatglm_lite"] = "chatglm_std",
controller_addr: str,
worker_addr: str,
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 32768)
super().__init__(**kwargs)
self.version = version
# This is the conversation template for the chatglm api; other APIs need their own customized conv_template
self.conv = conv.Conversation(
name="chatglm-api",
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
messages=[],
roles=["Human", "Assistant"],
sep="\n### ",
stop_str="###",
)
def generate_stream_gate(self, params):
# TODO: support the stream parameter and maintain request_id; the prompt passed in is also problematic
from server.utils import get_model_worker_config
super().generate_stream_gate(params)
zhipuai.api_key = get_model_worker_config("chatglm-api").get("api_key")
response = zhipuai.model_api.sse_invoke(
model=self.version,
prompt=[{"role": "user", "content": params["prompt"]}],
temperature=params.get("temperature"),
top_p=params.get("top_p"),
incremental=False,
)
for e in response.events():
if e.event == "add":
yield json.dumps({"error_code": 0, "text": e.data}, ensure_ascii=False).encode() + b"\0"
# TODO: more robust message handling
# elif e.event == "finish":
# ...
def get_embeddings(self, params):
# TODO: support embeddings
print("embedding")
print(params)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = ChatGLMWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:20003",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=20003)

View File

@ -1,16 +1,20 @@
import pydantic
from pydantic import BaseModel
from typing import List
import torch
from fastapi import FastAPI
from pathlib import Path
import asyncio
from typing import Any, Optional
from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE
from configs.server_config import FSCHAT_MODEL_WORKERS
import os
from server import model_workers
from typing import Literal, Optional, Any
class BaseResponse(BaseModel):
code: int = pydantic.Field(200, description="API status code")
msg: str = pydantic.Field("success", description="API status message")
data: Any = pydantic.Field(None, description="API data")
class Config:
schema_extra = {
@ -68,6 +72,7 @@ class ChatMessage(BaseModel):
}
def torch_gc():
import torch
if torch.cuda.is_available():
# with torch.cuda.device(DEVICE):
torch.cuda.empty_cache()
@ -186,3 +191,117 @@ def MakeFastAPIOffline(
with_google_fonts=False,
redoc_favicon_url=favicon,
)
# fetch service info from server_config
def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
'''
Load the configuration items for a model worker.
Priority: FSCHAT_MODEL_WORKERS[model_name] > llm_model_dict[model_name] > FSCHAT_MODEL_WORKERS["default"]
'''
from configs.server_config import FSCHAT_MODEL_WORKERS
from configs.model_config import llm_model_dict
config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
config.update(llm_model_dict.get(model_name, {}))
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
# if no valid local_model_path is set, treat the model as an online API
if not os.path.isdir(config.get("local_model_path", "")):
config["online_api"] = True
if provider := config.get("provider"):
try:
config["worker_class"] = getattr(model_workers, provider)
except Exception as e:
print(f"在线模型 {model_name} 的provider没有正确配置")
config["device"] = llm_device(config.get("device") or LLM_DEVICE)
return config
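A brief sketch of how the merged config is typically inspected (the model name is hypothetical):

```python
# Sketch: the merged config tells local models and online APIs apart.
cfg = get_model_worker_config("chatglm2-6b")
if cfg.get("online_api"):
    print("served through an API worker:", cfg.get("worker_class"))
else:
    print("local model on device:", cfg["device"])
```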
def get_all_model_worker_configs() -> dict:
result = {}
model_names = set(llm_model_dict.keys()) | set(FSCHAT_MODEL_WORKERS.keys())
for name in model_names:
if name != "default":
result[name] = get_model_worker_config(name)
return result
def fschat_controller_address() -> str:
from configs.server_config import FSCHAT_CONTROLLER
host = FSCHAT_CONTROLLER["host"]
port = FSCHAT_CONTROLLER["port"]
return f"http://{host}:{port}"
def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
if model := get_model_worker_config(model_name):
host = model["host"]
port = model["port"]
return f"http://{host}:{port}"
return ""
def fschat_openai_api_address() -> str:
from configs.server_config import FSCHAT_OPENAI_API
host = FSCHAT_OPENAI_API["host"]
port = FSCHAT_OPENAI_API["port"]
return f"http://{host}:{port}"
def api_address() -> str:
from configs.server_config import API_SERVER
host = API_SERVER["host"]
port = API_SERVER["port"]
return f"http://{host}:{port}"
def webui_address() -> str:
from configs.server_config import WEBUI_SERVER
host = WEBUI_SERVER["host"]
port = WEBUI_SERVER["port"]
return f"http://{host}:{port}"
def set_httpx_timeout(timeout: float = None):
'''
Set the default httpx timeout.
httpx's default timeout is 5 seconds, which is not enough when waiting for an LLM answer.
'''
import httpx
from configs.server_config import HTTPX_DEFAULT_TIMEOUT
timeout = timeout or HTTPX_DEFAULT_TIMEOUT
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
# Automatically detect the device available to torch. In a distributed deployment, machines that do not run the LLM need no torch install.
def detect_device() -> Literal["cuda", "mps", "cpu"]:
try:
import torch
if torch.cuda.is_available():
return "cuda"
if torch.backends.mps.is_available():
return "mps"
except:
pass
return "cpu"
def llm_device(device: str = LLM_DEVICE) -> Literal["cuda", "mps", "cpu"]:
if device not in ["cuda", "mps", "cpu"]:
device = detect_device()
return device
def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]:
if device not in ["cuda", "mps", "cpu"]:
device = detect_device()
return device
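Both helpers treat anything other than an explicit `"cuda"`, `"mps"` or `"cpu"` as a request for auto-detection:

```python
# Sketch: unknown values fall back to detect_device().
print(llm_device("auto"))       # -> "cuda", "mps" or "cpu", depending on hardware
print(embedding_device("cpu"))  # -> "cpu" (explicit values pass through)
```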

View File

@ -1,59 +1,60 @@
from multiprocessing import Process, Queue
import asyncio
import multiprocessing as mp
import os
import subprocess
import sys
import os
from multiprocessing import Process, Queue
from pprint import pprint
# Set numexpr's maximum thread count; defaults to the number of CPU cores
try:
import numexpr
n_cores = numexpr.utils.detect_number_of_cores()
os.environ["NUMEXPR_MAX_THREADS"] = str(n_cores)
except:
pass
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \
from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
logger
from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
FSCHAT_OPENAI_API, fschat_controller_address, fschat_model_worker_address,
fschat_openai_api_address, )
from server.utils import MakeFastAPIOffline, FastAPI
from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
FSCHAT_OPENAI_API, )
from server.utils import (fschat_controller_address, fschat_model_worker_address,
fschat_openai_api_address, set_httpx_timeout,
get_model_worker_config, get_all_model_worker_configs,
MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
import argparse
from typing import Tuple, List
from typing import Tuple, List, Dict
from configs import VERSION
def set_httpx_timeout(timeout=60.0):
import httpx
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
def create_controller_app(
dispatch_method: str,
log_level: str = "INFO",
) -> FastAPI:
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.controller import app, Controller
from fastchat.serve.controller import app, Controller, logger
logger.setLevel(log_level)
controller = Controller(dispatch_method)
sys.modules["fastchat.serve.controller"].controller = controller
MakeFastAPIOffline(app)
app.title = "FastChat Controller"
app._controller = controller
return app
def create_model_worker_app(**kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger
import argparse
import threading
import fastchat.serve.model_worker
logger.setLevel(log_level)
# workaround to make program exit with Ctrl+c
# it should be deleted after the PR is merged by fastchat
@ -99,53 +100,76 @@ def create_model_worker_app(**kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
gptq_config = GptqConfig(
ckpt=args.gptq_ckpt or args.model_path,
wbits=args.gptq_wbits,
groupsize=args.gptq_groupsize,
act_order=args.gptq_act_order,
)
awq_config = AWQConfig(
ckpt=args.awq_ckpt or args.model_path,
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
)
# online model API
if worker_class := kwargs.get("worker_class"):
worker = worker_class(model_names=args.model_names,
controller_addr=args.controller_address,
worker_addr=args.worker_address)
# local model
else:
# workaround to make program exit with Ctrl+c
# it should be deleted after the PR is merged by fastchat
def _new_init_heart_beat(self):
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
)
self.heart_beat_thread.start()
worker = ModelWorker(
controller_addr=args.controller_address,
worker_addr=args.worker_address,
worker_id=worker_id,
model_path=args.model_path,
model_names=args.model_names,
limit_worker_concurrency=args.limit_worker_concurrency,
no_register=args.no_register,
device=args.device,
num_gpus=args.num_gpus,
max_gpu_memory=args.max_gpu_memory,
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
)
ModelWorker.init_heart_beat = _new_init_heart_beat
gptq_config = GptqConfig(
ckpt=args.gptq_ckpt or args.model_path,
wbits=args.gptq_wbits,
groupsize=args.gptq_groupsize,
act_order=args.gptq_act_order,
)
awq_config = AWQConfig(
ckpt=args.awq_ckpt or args.model_path,
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
)
worker = ModelWorker(
controller_addr=args.controller_address,
worker_addr=args.worker_address,
worker_id=worker_id,
model_path=args.model_path,
model_names=args.model_names,
limit_worker_concurrency=args.limit_worker_concurrency,
no_register=args.no_register,
device=args.device,
num_gpus=args.num_gpus,
max_gpu_memory=args.max_gpu_memory,
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
)
sys.modules["fastchat.serve.model_worker"].args = args
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
sys.modules["fastchat.serve.model_worker"].worker = worker
sys.modules["fastchat.serve.model_worker"].args = args
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
MakeFastAPIOffline(app)
app.title = f"FastChat LLM Server ({LLM_MODEL})"
app.title = f"FastChat LLM Server ({args.model_names[0]})"
app._worker = worker
return app
def create_openai_api_app(
controller_address: str,
api_keys: List = [],
log_level: str = "INFO",
) -> FastAPI:
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
from fastchat.utils import build_logger
logger = build_logger("openai_api", "openai_api.log")
logger.setLevel(log_level)
app.add_middleware(
CORSMiddleware,
@ -155,6 +179,7 @@ def create_openai_api_app(
allow_headers=["*"],
)
sys.modules["fastchat.serve.openai_api_server"].logger = logger
app_settings.controller_address = controller_address
app_settings.api_keys = api_keys
@ -164,6 +189,9 @@ def create_openai_api_app(
def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
if q is None or not isinstance(run_seq, int):
return
if run_seq == 1:
@app.on_event("startup")
async def on_startup():
@ -182,15 +210,90 @@ def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
q.put(run_seq)
def run_controller(q: Queue, run_seq: int = 1):
def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
import uvicorn
import httpx
from fastapi import Body
import time
import sys
app = create_controller_app(FSCHAT_CONTROLLER.get("dispatch_method"))
app = create_controller_app(
dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
log_level=log_level,
)
_set_app_seq(app, q, run_seq)
@app.on_event("startup")
def on_startup():
if e is not None:
e.set()
# add interface to release and load model worker
@app.post("/release_worker")
def release_worker(
model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
# worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[fschat_controller_address()]),
new_model_name: str = Body(None, description="释放后加载该模型"),
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
) -> Dict:
available_models = app._controller.list_models()
if new_model_name in available_models:
msg = f"要切换的LLM模型 {new_model_name} 已经存在"
logger.info(msg)
return {"code": 500, "msg": msg}
if new_model_name:
logger.info(f"开始切换LLM模型{model_name}{new_model_name}")
else:
logger.info(f"即将停止LLM模型 {model_name}")
if model_name not in available_models:
msg = f"the model {model_name} is not available"
logger.error(msg)
return {"code": 500, "msg": msg}
worker_address = app._controller.get_worker_address(model_name)
if not worker_address:
msg = f"can not find model_worker address for {model_name}"
logger.error(msg)
return {"code": 500, "msg": msg}
r = httpx.post(worker_address + "/release",
json={"new_model_name": new_model_name, "keep_origin": keep_origin})
if r.status_code != 200:
msg = f"failed to release model: {model_name}"
logger.error(msg)
return {"code": 500, "msg": msg}
if new_model_name:
timer = 300 # wait up to 5 minutes for the new model_worker to register
while timer > 0:
models = app._controller.list_models()
if new_model_name in models:
break
time.sleep(1)
timer -= 1
if timer > 0:
msg = f"sucess change model from {model_name} to {new_model_name}"
logger.info(msg)
return {"code": 200, "msg": msg}
else:
msg = f"failed change model from {model_name} to {new_model_name}"
logger.error(msg)
return {"code": 500, "msg": msg}
else:
msg = f"sucess to release model: {model_name}"
logger.info(msg)
return {"code": 200, "msg": msg}
host = FSCHAT_CONTROLLER["host"]
port = FSCHAT_CONTROLLER["port"]
uvicorn.run(app, host=host, port=port)
if log_level == "ERROR":
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
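For reference, a hedged sketch of calling the `/release_worker` endpoint added above, assuming the default controller address `http://127.0.0.1:20001` and hypothetical model names:

```python
# Sketch: ask the controller to swap the running LLM for another one.
import httpx

r = httpx.post(
    "http://127.0.0.1:20001/release_worker",
    json={"model_name": "chatglm2-6b", "new_model_name": "Qwen-7B-Chat"},
    timeout=310,  # the controller may wait up to 5 minutes for the new worker
)
print(r.json())  # {"code": 200, "msg": "..."} on success
```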
def run_model_worker(
@ -198,33 +301,59 @@ def run_model_worker(
controller_address: str = "",
q: Queue = None,
run_seq: int = 2,
log_level: str = "INFO",
):
import uvicorn
from fastapi import Body
import sys
kwargs = FSCHAT_MODEL_WORKERS[model_name].copy()
kwargs = get_model_worker_config(model_name)
host = kwargs.pop("host")
port = kwargs.pop("port")
model_path = llm_model_dict[model_name].get("local_model_path", "")
kwargs["model_path"] = model_path
kwargs["model_names"] = [model_name]
kwargs["controller_address"] = controller_address or fschat_controller_address()
kwargs["worker_address"] = fschat_model_worker_address()
kwargs["worker_address"] = fschat_model_worker_address(model_name)
model_path = kwargs.get("local_model_path", "")
kwargs["model_path"] = model_path
app = create_model_worker_app(**kwargs)
app = create_model_worker_app(log_level=log_level, **kwargs)
_set_app_seq(app, q, run_seq)
if log_level == "ERROR":
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
uvicorn.run(app, host=host, port=port)
# add interface to release and load model
@app.post("/release")
def release_model(
new_model_name: str = Body(None, description="释放后加载该模型"),
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
) -> Dict:
if keep_origin:
if new_model_name:
q.put(["start", new_model_name])
else:
if new_model_name:
q.put(["replace", new_model_name])
else:
q.put(["stop"])
return {"code": 200, "msg": "done"}
uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
def run_openai_api(q: Queue, run_seq: int = 3):
def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"):
import uvicorn
import sys
controller_addr = fschat_controller_address()
app = create_openai_api_app(controller_addr) # todo: not support keys yet.
app = create_openai_api_app(controller_addr, log_level=log_level) # TODO: not support keys yet.
_set_app_seq(app, q, run_seq)
host = FSCHAT_OPENAI_API["host"]
port = FSCHAT_OPENAI_API["port"]
if log_level == "ERROR":
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
uvicorn.run(app, host=host, port=port)
@ -244,13 +373,15 @@ def run_api_server(q: Queue, run_seq: int = 4):
def run_webui(q: Queue, run_seq: int = 5):
host = WEBUI_SERVER["host"]
port = WEBUI_SERVER["port"]
while True:
no = q.get()
if no != run_seq - 1:
q.put(no)
else:
break
q.put(run_seq)
if q is not None and isinstance(run_seq, int):
while True:
no = q.get()
if no != run_seq - 1:
q.put(no)
else:
break
q.put(run_seq)
p = subprocess.Popen(["streamlit", "run", "webui.py",
"--server.address", host,
"--server.port", str(port)])
@ -313,6 +444,13 @@ def parse_args() -> argparse.ArgumentParser:
help="run api.py server",
dest="api",
)
parser.add_argument(
"-p",
"--api-worker",
action="store_true",
help="run online model api such as zhipuai",
dest="api_worker",
)
parser.add_argument(
"-w",
"--webui",
@ -320,26 +458,38 @@ def parse_args() -> argparse.ArgumentParser:
help="run webui.py server",
dest="webui",
)
parser.add_argument(
"-q",
"--quiet",
action="store_true",
help="减少fastchat服务log信息",
dest="quiet",
)
args = parser.parse_args()
return args
return args, parser
def dump_server_info(after_start=False):
def dump_server_info(after_start=False, args=None):
import platform
import langchain
import fastchat
from configs.server_config import api_address, webui_address
from server.utils import api_address, webui_address
print("\n\n")
print("\n")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
print(f"操作系统:{platform.platform()}.")
print(f"python版本{sys.version}")
print(f"项目版本:{VERSION}")
print(f"langchain版本{langchain.__version__}. fastchat版本{fastchat.__version__}")
print("\n")
print(f"当前LLM模型{LLM_MODEL} @ {LLM_DEVICE}")
pprint(llm_model_dict[LLM_MODEL])
print(f"当前Embbedings模型 {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}")
model = LLM_MODEL
if args and args.model_name:
model = args.model_name
print(f"当前LLM模型{model} @ {llm_device()}")
pprint(llm_model_dict[model])
print(f"当前Embbedings模型 {EMBEDDING_MODEL} @ {embedding_device()}")
if after_start:
print("\n")
print(f"服务端运行信息:")
@ -351,110 +501,231 @@ def dump_server_info(after_start=False):
if args.webui:
print(f" Chatchat WEBUI Server: {webui_address()}")
print("=" * 30 + "Langchain-Chatchat Configuration" + "=" * 30)
print("\n\n")
print("\n")
if __name__ == "__main__":
async def start_main_server():
import time
import signal
def handler(signalname):
"""
Python 3.9 has `signal.strsignal(signalnum)` so this closure would not be needed.
Also, 3.8 includes `signal.valid_signals()` that can be used to create a mapping for the same purpose.
"""
def f(signal_received, frame):
raise KeyboardInterrupt(f"{signalname} received")
return f
# This will be inherited by the child process if it is forked (not spawned)
signal.signal(signal.SIGINT, handler("SIGINT"))
signal.signal(signal.SIGTERM, handler("SIGTERM"))
mp.set_start_method("spawn")
queue = Queue()
args = parse_args()
manager = mp.Manager()
queue = manager.Queue()
args, parser = parse_args()
if args.all_webui:
args.openai_api = True
args.model_worker = True
args.api = True
args.api_worker = True
args.webui = True
elif args.all_api:
args.openai_api = True
args.model_worker = True
args.api = True
args.api_worker = True
args.webui = False
elif args.llm_api:
args.openai_api = True
args.model_worker = True
args.api_worker = True
args.api = False
args.webui = False
dump_server_info()
logger.info(f"正在启动服务:")
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
dump_server_info(args=args)
processes = {}
if len(sys.argv) > 1:
logger.info(f"正在启动服务:")
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
processes = {"online-api": []}
def process_count():
return len(processes) + len(processes["online-api"]) - 1
if args.quiet:
log_level = "ERROR"
else:
log_level = "INFO"
controller_started = manager.Event()
if args.openai_api:
process = Process(
target=run_controller,
name=f"controller({os.getpid()})",
args=(queue, len(processes) + 1),
name=f"controller",
args=(queue, process_count() + 1, log_level, controller_started),
daemon=True,
)
process.start()
processes["controller"] = process
process = Process(
target=run_openai_api,
name=f"openai_api({os.getpid()})",
args=(queue, len(processes) + 1),
name=f"openai_api",
args=(queue, process_count() + 1),
daemon=True,
)
process.start()
processes["openai_api"] = process
if args.model_worker:
process = Process(
target=run_model_worker,
name=f"model_worker({os.getpid()})",
args=(args.model_name, args.controller_address, queue, len(processes) + 1),
daemon=True,
)
process.start()
processes["model_worker"] = process
config = get_model_worker_config(args.model_name)
if not config.get("online_api"):
process = Process(
target=run_model_worker,
name=f"model_worker - {args.model_name}",
args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level),
daemon=True,
)
processes["model_worker"] = process
if args.api_worker:
configs = get_all_model_worker_configs()
for model_name, config in configs.items():
if config.get("online_api") and config.get("worker_class"):
process = Process(
target=run_model_worker,
name=f"model_worker - {model_name}",
args=(model_name, args.controller_address, queue, process_count() + 1, log_level),
daemon=True,
)
processes["online-api"].append(process)
if args.api:
process = Process(
target=run_api_server,
name=f"API Server{os.getpid()})",
args=(queue, len(processes) + 1),
name=f"API Server",
args=(queue, process_count() + 1),
daemon=True,
)
process.start()
processes["api"] = process
if args.webui:
process = Process(
target=run_webui,
name=f"WEBUI Server{os.getpid()})",
args=(queue, len(processes) + 1),
name=f"WEBUI Server",
args=(queue, process_count() + 1),
daemon=True,
)
process.start()
processes["webui"] = process
try:
# log info
while True:
no = queue.get()
if no == len(processes):
time.sleep(0.5)
dump_server_info(True)
break
else:
queue.put(no)
if process_count() == 0:
parser.print_help()
else:
try:
# make sure tasks can exit cleanly after receiving SIGINT
if p:= processes.get("controller"):
p.start()
p.name = f"{p.name} ({p.pid})"
controller_started.wait()
if model_worker_process := processes.get("model_worker"):
model_worker_process.join()
for name, process in processes.items():
if name != "model_worker":
if p:= processes.get("openai_api"):
p.start()
p.name = f"{p.name} ({p.pid})"
if p:= processes.get("model_worker"):
p.start()
p.name = f"{p.name} ({p.pid})"
for p in processes.get("online-api", []):
p.start()
p.name = f"{p.name} ({p.pid})"
if p:= processes.get("api"):
p.start()
p.name = f"{p.name} ({p.pid})"
if p:= processes.get("webui"):
p.start()
p.name = f"{p.name} ({p.pid})"
while True:
no = queue.get()
if no == process_count():
time.sleep(0.5)
dump_server_info(after_start=True, args=args)
break
else:
queue.put(no)
if model_worker_process := processes.get("model_worker"):
model_worker_process.join()
for process in processes.get("online-api", []):
process.join()
except:
if model_worker_process := processes.get("model_worker"):
model_worker_process.terminate()
for name, process in processes.items():
if name != "model_worker":
process.terminate()
for name, process in processes.items():
if name not in ["model_worker", "online-api"]:
if isinstance(process, list):
for work_process in process:
work_process.join()
else:
process.join()
except Exception as e:
# if model_worker_process := processes.pop("model_worker", None):
# model_worker_process.terminate()
# for process in processes.pop("online-api", []):
# process.terminate()
# for process in processes.values():
#
# if isinstance(process, list):
# for work_process in process:
# work_process.terminate()
# else:
# process.terminate()
logger.error(e)
logger.warning("Caught KeyboardInterrupt! Setting stop event...")
finally:
# Send SIGINT if process doesn't exit quickly enough, and kill it as last resort
# .is_alive() also implicitly joins the process (good practice in linux)
# while alive_procs := [p for p in processes.values() if p.is_alive()]:
for p in processes.values():
logger.warning("Sending SIGKILL to %s", p)
# Queues and other inter-process communication primitives can break when
# process is killed, but we don't care here
if isinstance(p, list):
for process in p:
process.kill()
else:
p.kill()
for p in processes.values():
logger.info("Process status: %s", p)
if __name__ == "__main__":
if sys.version_info < (3, 10):
loop = asyncio.get_event_loop()
else:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# invoke the coroutine synchronously
loop.run_until_complete(start_main_server())
# Example API call after the services have started:
# import openai

View File

@ -1,4 +1,3 @@
from doctest import testfile
import requests
import json
import sys
@ -6,7 +5,7 @@ from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from configs.server_config import api_address
from server.utils import api_address
from configs.model_config import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path
@ -112,7 +111,7 @@ def test_upload_doc(api="/knowledge_base/upload_doc"):
assert data["msg"] == f"成功上传文件 {name}"
def test_list_docs(api="/knowledge_base/list_docs"):
def test_list_files(api="/knowledge_base/list_files"):
url = api_base_url + api
print("\n获取知识库中文件列表:")
r = requests.get(url, params={"knowledge_base_name": kb})

74
tests/api/test_llm_api.py Normal file
View File

@ -0,0 +1,74 @@
import requests
import json
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from configs.server_config import api_address, FSCHAT_MODEL_WORKERS
from configs.model_config import LLM_MODEL, llm_model_dict
from pprint import pprint
import random
def get_configured_models():
model_workers = list(FSCHAT_MODEL_WORKERS)
if "default" in model_workers:
model_workers.remove("default")
llm_dict = list(llm_model_dict)
return model_workers, llm_dict
api_base_url = api_address()
def get_running_models(api="/llm_model/list_models"):
url = api_base_url + api
r = requests.post(url)
if r.status_code == 200:
return r.json()["data"]
return []
def test_running_models(api="/llm_model/list_models"):
url = api_base_url + api
r = requests.post(url)
assert r.status_code == 200
print("\n获取当前正在运行的模型列表:")
pprint(r.json())
assert isinstance(r.json()["data"], list)
assert len(r.json()["data"]) > 0
# The stop_model feature is not recommended: with the current implementation, a stopped model can only be restarted manually
# def test_stop_model(api="/llm_model/stop"):
# url = api_base_url + api
# r = requests.post(url, json={""})
def test_change_model(api="/llm_model/change"):
url = api_base_url + api
running_models = get_running_models()
assert len(running_models) > 0
model_workers, llm_dict = get_configured_models()
available_new_models = set(model_workers) - set(running_models)
if len(available_new_models) == 0:
available_new_models = set(llm_dict) - set(running_models)
available_new_models = list(available_new_models)
assert len(available_new_models) > 0
print(available_new_models)
model_name = random.choice(running_models)
new_model_name = random.choice(available_new_models)
print(f"\n尝试将模型从 {model_name} 切换到 {new_model_name}")
r = requests.post(url, json={"model_name": model_name, "new_model_name": new_model_name})
assert r.status_code == 200
running_models = get_running_models()
assert new_model_name in running_models

View File

@ -5,7 +5,7 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent))
from configs.model_config import BING_SUBSCRIPTION_KEY
from configs.server_config import API_SERVER, api_address
from server.utils import api_address
from pprint import pprint

View File

@ -0,0 +1,21 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from pprint import pprint
test_files = {
"ocr_test.pdf": str(root_path / "tests" / "samples" / "ocr_test.pdf"),
}
def test_rapidocrpdfloader():
pdf_path = test_files["ocr_test.pdf"]
from document_loaders import RapidOCRPDFLoader
loader = RapidOCRPDFLoader(pdf_path)
docs = loader.load()
pprint(docs)
assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str)

View File

@ -0,0 +1,21 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from pprint import pprint
test_files = {
"ocr_test.jpg": str(root_path / "tests" / "samples" / "ocr_test.jpg"),
}
def test_rapidocrloader():
img_path = test_files["ocr_test.jpg"]
from document_loaders import RapidOCRLoader
loader = RapidOCRLoader(img_path)
docs = loader.load()
pprint(docs)
assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str)

View File

View File

@ -0,0 +1,35 @@
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
from server.knowledge_base.migrate import create_tables
from server.knowledge_base.utils import KnowledgeFile
kbService = FaissKBService("test")
test_kb_name = "test"
test_file_name = "README.md"
testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name)
search_content = "如何启动api服务"
def test_init():
create_tables()
def test_create_db():
assert kbService.create_kb()
def test_add_doc():
assert kbService.add_doc(testKnowledgeFile)
def test_search_db():
result = kbService.search_docs(search_content)
assert len(result) > 0
def test_delete_doc():
assert kbService.delete_doc(testKnowledgeFile)
def test_delete_db():
assert kbService.drop_kb()

View File

@ -0,0 +1,31 @@
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
from server.knowledge_base.kb_service.pg_kb_service import PGKBService
from server.knowledge_base.migrate import create_tables
from server.knowledge_base.utils import KnowledgeFile
kbService = MilvusKBService("test")
test_kb_name = "test"
test_file_name = "README.md"
testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name)
search_content = "如何启动api服务"
def test_init():
create_tables()
def test_create_db():
assert kbService.create_kb()
def test_add_doc():
assert kbService.add_doc(testKnowledgeFile)
def test_search_db():
result = kbService.search_docs(search_content)
assert len(result) > 0
def test_delete_doc():
assert kbService.delete_doc(testKnowledgeFile)

View File

@ -0,0 +1,31 @@
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
from server.knowledge_base.kb_service.pg_kb_service import PGKBService
from server.knowledge_base.migrate import create_tables
from server.knowledge_base.utils import KnowledgeFile
kbService = PGKBService("test")
test_kb_name = "test"
test_file_name = "README.md"
testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name)
search_content = "如何启动api服务"
def test_init():
create_tables()
def test_create_db():
assert kbService.create_kb()
def test_add_doc():
assert kbService.add_doc(testKnowledgeFile)
def test_search_db():
result = kbService.search_docs(search_content)
assert len(result) > 0
def test_delete_doc():
assert kbService.delete_doc(testKnowledgeFile)

BIN
tests/samples/ocr_test.jpg Normal file

Binary file not shown.


BIN
tests/samples/ocr_test.pdf Normal file

Binary file not shown.

View File

@ -10,8 +10,10 @@ from streamlit_option_menu import option_menu
from webui_pages import *
import os
from configs import VERSION
from server.utils import api_address
api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)
api = ApiRequest(base_url=api_address())
if __name__ == "__main__":
st.set_page_config(

View File

@ -1,10 +1,14 @@
import streamlit as st
from configs.server_config import FSCHAT_MODEL_WORKERS
from webui_pages.utils import *
from streamlit_chatbox import *
from datetime import datetime
from server.chat.search_engine_chat import SEARCH_ENGINES
from typing import List, Dict
import os
from configs.model_config import llm_model_dict, LLM_MODEL
from server.utils import get_model_worker_config
from typing import List, Dict
chat_box = ChatBox(
assistant_avatar=os.path.join(
@ -59,9 +63,39 @@ def dialogue_page(api: ApiRequest):
on_change=on_mode_change,
key="dialogue_mode",
)
history_len = st.number_input("历史对话轮数:", 0, 10, 3)
# todo: support history len
def on_llm_change():
st.session_state["prev_llm_model"] = llm_model
def llm_model_format_func(x):
if x in running_models:
return f"{x} (Running)"
return x
running_models = api.list_running_models()
config_models = api.list_config_models()
for x in running_models:
if x in config_models:
config_models.remove(x)
llm_models = running_models + config_models
if "prev_llm_model" not in st.session_state:
index = llm_models.index(LLM_MODEL)
else:
index = 0
llm_model = st.selectbox("选择LLM模型",
llm_models,
index,
format_func=llm_model_format_func,
on_change=on_llm_change,
# key="llm_model",
)
if (st.session_state.get("prev_llm_model") != llm_model
and not get_model_worker_config(llm_model).get("online_api")):
with st.spinner(f"正在加载模型: {llm_model}"):
r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
st.session_state["prev_llm_model"] = llm_model
history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)
def on_kb_change():
st.toast(f"已加载知识库: {st.session_state.selected_kb}")
@ -75,7 +109,7 @@ def dialogue_page(api: ApiRequest):
on_change=on_kb_change,
key="selected_kb",
)
kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3)
kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
score_threshold = st.number_input("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
# chunk_content = st.checkbox("关联上下文", False, disabled=True)
# chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
@ -87,13 +121,13 @@ def dialogue_page(api: ApiRequest):
options=search_engine_list,
index=search_engine_list.index("duckduckgo") if "duckduckgo" in search_engine_list else 0,
)
se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, 3)
se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K)
# Display chat messages from history on app rerun
chat_box.output_messages()
chat_input_placeholder = "请输入对话内容,换行请使用Ctrl+Enter "
chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter "
if prompt := st.chat_input(chat_input_placeholder, key="prompt"):
history = get_messages_history(history_len)
@ -101,7 +135,7 @@ def dialogue_page(api: ApiRequest):
if dialogue_mode == "LLM 对话":
chat_box.ai_say("正在思考...")
text = ""
r = api.chat_chat(prompt, history)
r = api.chat_chat(prompt, history=history, model=llm_model)
for t in r:
if error_msg := check_error_msg(t): # check whether error occurred
st.error(error_msg)
@ -116,7 +150,7 @@ def dialogue_page(api: ApiRequest):
Markdown("...", in_expander=True, title="知识库匹配结果"),
])
text = ""
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history):
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history, model=llm_model):
if error_msg := check_error_msg(d): # check whether error occurred
st.error(error_msg)
text += d["answer"]
@ -129,8 +163,8 @@ def dialogue_page(api: ApiRequest):
Markdown("...", in_expander=True, title="网络搜索结果"),
])
text = ""
for d in api.search_engine_chat(prompt, search_engine, se_top_k):
if error_msg := check_error_msg(d): # check whether error occurred
for d in api.search_engine_chat(prompt, search_engine, se_top_k, model=llm_model):
if error_msg := check_error_msg(d):  # check whether error occurred
st.error(error_msg)
else:
text += d["answer"]

View File

@ -4,7 +4,7 @@ from st_aggrid import AgGrid, JsCode
from st_aggrid.grid_options_builder import GridOptionsBuilder
import pandas as pd
from server.knowledge_base.utils import get_file_path, LOADER_DICT
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_doc_details
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
from typing import Literal, Dict, Tuple
from configs.model_config import embedding_model_dict, kbs_config, EMBEDDING_MODEL, DEFAULT_VS_TYPE
import os
@ -127,7 +127,7 @@ def knowledge_base_page(api: ApiRequest):
# 上传文件
# sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
files = st.file_uploader("上传知识文件",
files = st.file_uploader("上传知识文件(暂不支持扫描PDF)",
[i for ls in LOADER_DICT.values() for i in ls],
accept_multiple_files=True,
)
@ -152,7 +152,7 @@ def knowledge_base_page(api: ApiRequest):
# 知识库详情
# st.info("请选择文件,点击按钮进行操作。")
doc_details = pd.DataFrame(get_kb_doc_details(kb))
doc_details = pd.DataFrame(get_kb_file_details(kb))
if not len(doc_details):
st.info(f"知识库 `{kb}` 中暂无文件")
else:
@ -160,7 +160,7 @@ def knowledge_base_page(api: ApiRequest):
st.info("知识库中包含源文件与向量库,请从下表中选择文件后操作")
doc_details.drop(columns=["kb_name"], inplace=True)
doc_details = doc_details[[
"No", "file_name", "document_loader", "text_splitter", "in_folder", "in_db",
"No", "file_name", "document_loader", "docs_count", "in_folder", "in_db",
]]
# doc_details["in_folder"] = doc_details["in_folder"].replace(True, "✓").replace(False, "×")
# doc_details["in_db"] = doc_details["in_db"].replace(True, "✓").replace(False, "×")
@ -172,7 +172,8 @@ def knowledge_base_page(api: ApiRequest):
# ("file_ext", "文档类型"): {},
# ("file_version", "文档版本"): {},
("document_loader", "文档加载器"): {},
("text_splitter", "分词器"): {},
("docs_count", "文档数量"): {},
# ("text_splitter", "分词器"): {},
# ("create_time", "创建时间"): {},
("in_folder", "源文件"): {"cellRenderer": cell_renderer},
("in_db", "向量库"): {"cellRenderer": cell_renderer},
@ -244,7 +245,6 @@ def knowledge_base_page(api: ApiRequest):
cols = st.columns(3)
# todo: freezed
if cols[0].button(
"依据源文件重建向量库",
# help="无需上传文件通过其它方式将文档拷贝到对应知识库content目录下点击本按钮即可重建知识库。",
@ -258,7 +258,7 @@ def knowledge_base_page(api: ApiRequest):
if msg := check_error_msg(d):
st.toast(msg)
else:
empty.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
empty.progress(d["finished"] / d["total"], d["msg"])
st.experimental_rerun()
if cols[2].button(

View File

@ -6,11 +6,14 @@ from configs.model_config import (
DEFAULT_VS_TYPE,
KB_ROOT_PATH,
LLM_MODEL,
llm_model_dict,
HISTORY_LEN,
SCORE_THRESHOLD,
VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K,
logger,
)
from configs.server_config import HTTPX_DEFAULT_TIMEOUT
import httpx
import asyncio
from server.chat.openai_chat import OpenAiChatMsgIn
@ -20,21 +23,12 @@ import json
from io import BytesIO
from server.db.repository.knowledge_base_repository import get_kb_detail
from server.db.repository.knowledge_file_repository import get_file_detail
from server.utils import run_async, iter_over_async
from server.utils import run_async, iter_over_async, set_httpx_timeout
from configs.model_config import NLTK_DATA_PATH
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
def set_httpx_timeout(timeout=60.0):
'''
Set the default httpx timeout to 60 seconds.
httpx's default timeout is 5 seconds, which is not enough when waiting for an LLM answer.
'''
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
from pprint import pprint
KB_ROOT_PATH = Path(KB_ROOT_PATH)
@ -50,7 +44,7 @@ class ApiRequest:
def __init__(
self,
base_url: str = "http://127.0.0.1:7861",
timeout: float = 60.0,
timeout: float = HTTPX_DEFAULT_TIMEOUT,
no_remote_api: bool = False, # call api view function directly
):
self.base_url = base_url
@ -224,9 +218,17 @@ class ApiRequest:
try:
with response as r:
for chunk in r.iter_text(None):
if as_json and chunk:
yield json.loads(chunk)
elif chunk.strip():
if not chunk: # fastchat api yields empty bytes at start and end
continue
if as_json:
try:
data = json.loads(chunk)
pprint(data, depth=1)
yield data
except Exception as e:
logger.error(f"接口返回json错误 {chunk}’。错误信息是:{e}")
else:
print(chunk, end="", flush=True)
yield chunk
except httpx.ConnectError as e:
msg = f"无法连接API服务器请确认 api.py 已正常启动。"
@ -274,6 +276,9 @@ class ApiRequest:
return self._fastapi_stream2generator(response)
else:
data = msg.dict(exclude_unset=True, exclude_none=True)
print(f"received input message:")
pprint(data)
response = self.post(
"/chat/fastchat",
json=data,
@ -286,6 +291,7 @@ class ApiRequest:
query: str,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODEL,
no_remote_api: bool = None,
):
'''
@ -298,8 +304,12 @@ class ApiRequest:
"query": query,
"history": history,
"stream": stream,
"model_name": model,
}
print(f"received input message:")
pprint(data)
if no_remote_api:
from server.chat.chat import chat
response = chat(**data)
@ -316,6 +326,7 @@ class ApiRequest:
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODEL,
no_remote_api: bool = None,
):
'''
@ -331,9 +342,13 @@ class ApiRequest:
"score_threshold": score_threshold,
"history": history,
"stream": stream,
"model_name": model,
"local_doc_url": no_remote_api,
}
print(f"received input message:")
pprint(data)
if no_remote_api:
from server.chat.knowledge_base_chat import knowledge_base_chat
response = knowledge_base_chat(**data)
@ -352,6 +367,7 @@ class ApiRequest:
search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K,
stream: bool = True,
model: str = LLM_MODEL,
no_remote_api: bool = None,
):
'''
@ -365,8 +381,12 @@ class ApiRequest:
"search_engine_name": search_engine_name,
"top_k": top_k,
"stream": stream,
"model_name": model,
}
print(f"received input message:")
pprint(data)
if no_remote_api:
from server.chat.search_engine_chat import search_engine_chat
response = search_engine_chat(**data)
@ -473,18 +493,18 @@ class ApiRequest:
no_remote_api: bool = None,
):
'''
对应api.py/knowledge_base/list_docs接口
对应api.py/knowledge_base/list_files接口
'''
if no_remote_api is None:
no_remote_api = self.no_remote_api
if no_remote_api:
from server.knowledge_base.kb_doc_api import list_docs
response = run_async(list_docs(knowledge_base_name))
from server.knowledge_base.kb_doc_api import list_files
response = run_async(list_files(knowledge_base_name))
return response.data
else:
response = self.get(
"/knowledge_base/list_docs",
"/knowledge_base/list_files",
params={"knowledge_base_name": knowledge_base_name}
)
data = self._check_httpx_json_response(response)
@ -633,6 +653,84 @@ class ApiRequest:
)
return self._httpx_stream2generator(response, as_json=True)
def list_running_models(self, controller_address: str = None):
'''
Get the list of models currently running in Fastchat
'''
r = self.post(
"/llm_model/list_models",
)
return r.json().get("data", [])
def list_config_models(self):
'''
Get the list of models configured in configs
'''
return list(llm_model_dict.keys())
def stop_llm_model(
self,
model_name: str,
controller_address: str = None,
):
'''
Stop an LLM model.
Note: with Fastchat's implementation, this actually stops the model_worker hosting the model.
'''
data = {
"model_name": model_name,
"controller_address": controller_address,
}
r = self.post(
"/llm_model/stop",
json=data,
)
return r.json()
def change_llm_model(
self,
model_name: str,
new_model_name: str,
controller_address: str = None,
):
'''
Ask the fastchat controller to switch the LLM model
'''
if not model_name or not new_model_name:
return
if new_model_name == model_name:
return {
"code": 200,
"msg": "什么都不用做"
}
running_models = self.list_running_models()
if model_name not in running_models:
return {
"code": 500,
"msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}"
}
config_models = self.list_config_models()
if new_model_name not in config_models:
return {
"code": 500,
"msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
}
data = {
"model_name": model_name,
"new_model_name": new_model_name,
"controller_address": controller_address,
}
r = self.post(
"/llm_model/change",
json=data,
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
)
return r.json()
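A short usage sketch of the model-management wrappers above (model names are hypothetical; `api_address` comes from `server.utils`):

```python
# Sketch: the same flow the WebUI uses to switch models through the API server.
api = ApiRequest(base_url=api_address())
print(api.list_running_models())
print(api.change_llm_model("chatglm2-6b", "Qwen-7B-Chat"))
```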
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
'''