Merge branch 'dev' into pre-release

imClumsyPanda 2023-09-15 13:47:25 +08:00
commit 42fba7ef90
64 changed files with 2970 additions and 1271 deletions

1
.gitignore vendored

@ -7,3 +7,4 @@ __pycache__/
/configs/*.py
.vscode/
.pytest_cache/
*.bak

194
README.md

@ -14,8 +14,7 @@
* [2. Download models to local disk](README.md#2.-下载模型至本地)
* [3. Set configuration items](README.md#3.-设置配置项)
* [4. Knowledge base initialization and migration](README.md#4.-知识库初始化与迁移)
* [5. Launch the API service or WebUI with one command](README.md#6.-一键启动)
* [6. Launch the API service or Web UI step by step](README.md#5.-启动-API-服务或-Web-UI)
* [5. Launch the API Service or Web UI with One Command](README.md#5.-一键启动-API-服务或-Web-UI)
* [FAQ](README.md#常见问题)
* [Roadmap](README.md#路线图)
* [Project discussion group](README.md#项目交流群)
@ -78,7 +77,9 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
### Supported LLM Models
In its latest version, this project uses [FastChat](https://github.com/lm-sys/FastChat) for local LLM model access; supported models are listed below:
The latest version of this project supports both **local models** and **online LLM APIs**.
Local LLM model access is implemented with [FastChat](https://github.com/lm-sys/FastChat); supported models include:
- [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
- Vicuna, Alpaca, LLaMA, Koala
@ -109,12 +110,27 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
- [internlm/internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b)
- [Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat)
- [HuggingFaceH4/starchat-beta](https://huggingface.co/HuggingFaceH4/starchat-beta)
- [FlagAlpha/Llama2-Chinese-13b-Chat](https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat) and others
- [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B)
- [all models of OpenOrca](https://huggingface.co/Open-Orca)
- [Spicyboros](https://huggingface.co/jondurbin/spicyboros-7b-2.2?not-for-all-audiences=true) + [airoboros 2.2](https://huggingface.co/jondurbin/airoboros-l2-13b-2.2)
- [VMware's OpenLLaMa OpenInstruct](https://huggingface.co/VMware/open-llama-7b-open-instruct)
- Any [EleutherAI](https://huggingface.co/EleutherAI) pythia model, such as [pythia-6.9b](https://huggingface.co/EleutherAI/pythia-6.9b)
- Any [Peft](https://github.com/huggingface/peft) adapter trained on top of the models above. To activate it, the model path must contain `peft`. Note: if you load multiple PEFT models, you can make them share the base model's weights by setting the environment variable `PEFT_SHARE_BASE_WEIGHTS=true` in any model worker.
The list of supported models above may keep growing as [FastChat](https://github.com/lm-sys/FastChat) is updated; see the [FastChat supported model list](https://github.com/lm-sys/FastChat/blob/main/docs/model_support.md) for details.
Besides local models, this project also supports connecting directly to the OpenAI API; for details see the `openai-chatgpt-3.5` entry of `llm_model_dict` in `configs/model_config.py.example`.
Besides local models, this project also supports connecting directly to online models such as the OpenAI API and Zhipu AI; for details see the `llm_model_dict` configuration in `configs/model_config.py.example`.
Online LLM models currently supported:
- [ChatGPT](https://api.openai.com)
- [Zhipu AI](http://open.bigmodel.cn)
- [MiniMax](https://api.minimax.chat)
- [iFlytek Spark (Xinghuo)](https://xinghuo.xfyun.cn)
- [Baidu Qianfan](https://cloud.baidu.com/product/wenxinworkshop?track=dingbutonglan)
The default LLM used by the project is `THUDM/chatglm2-6b`. To use a different LLM, edit `llm_model_dict` and `LLM_MODEL` in [configs/model_config.py](configs/model_config.py).
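For illustration, a minimal sketch of the relevant entries in `configs/model_config.py`; the local path is hypothetical, and the values follow the examples shipped in `model_config.py.example`:
```python
llm_model_dict = {
    # local model served through FastChat
    "chatglm2-6b": {
        "local_model_path": "/data/models/chatglm2-6b",  # hypothetical local path
        "api_base_url": "http://localhost:8888/v1",      # api_base_url of the FastChat service
        "api_key": "EMPTY",
    },
    # online model API (register at http://open.bigmodel.cn to obtain a key)
    "zhipu-api": {
        "api_base_url": "http://127.0.0.1:8888/v1",
        "api_key": "your-zhipuai-api-key",               # placeholder
        "provider": "ChatGLMWorker",
        "version": "chatglm_pro",                        # "chatglm_lite", "chatglm_std" or "chatglm_pro"
    },
}
LLM_MODEL = "chatglm2-6b"
```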
### Supported Embedding Models
@ -139,8 +155,34 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
- [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh)
- [OpenAI/text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings)
The default Embedding model used by the project is `moka-ai/m3e-base`. To use a different Embedding model, edit `embedding_model_dict` and `EMBEDDING_MODEL` in [configs/model_config.py](configs/model_config.py).
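For illustration, a minimal sketch of the corresponding entries (the local path is hypothetical):
```python
embedding_model_dict = {
    "m3e-base": "/data/models/m3e-base",  # hypothetical local path
}
EMBEDDING_MODEL = "m3e-base"
```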
---
### Custom Text Splitter Support
This project supports the Text Splitters provided by [Langchain](https://api.python.langchain.com/en/latest/api_reference.html#module-langchain.text_splitter) as well as custom splitters built on top of them. Supported Text Splitter types:
- CharacterTextSplitter
- LatexTextSplitter
- MarkdownHeaderTextSplitter
- MarkdownTextSplitter
- NLTKTextSplitter
- PythonCodeTextSplitter
- RecursiveCharacterTextSplitter
- SentenceTransformersTokenTextSplitter
- SpacyTextSplitter
Custom splitters already supported:
- [AliTextSplitter](text_splitter/ali_text_splitter.py)
- [ChineseRecursiveTextSplitter](text_splitter/chinese_recursive_text_splitter.py)
- [ChineseTextSplitter](text_splitter/chinese_text_splitter.py)
The default Text Splitter used by the project is `ChineseRecursiveTextSplitter`. To use a different Text Splitter, edit `text_splitter_dict` and `TEXT_SPLITTER` in [configs/model_config.py](configs/model_config.py).
For how to use a custom splitter or contribute your own, see the [Text Splitter contribution guide](docs/splitter.md).
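For example, a minimal sketch (values taken from `model_config.py.example`) of switching to the built-in `SpacyTextSplitter` entry:
```python
text_splitter_dict = {
    "SpacyTextSplitter": {
        "source": "huggingface",
        "tokenizer_name_or_path": "gpt2",
    },
}
TEXT_SPLITTER = "SpacyTextSplitter"
```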
## Docker Deployment
🐳 Docker image: `registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.3`
@ -212,6 +254,17 @@ embedding_model_dict = {
}
```
- Please confirm that the local tokenizer path has been filled in, for example:
```python
text_splitter_dict = {
    "ChineseRecursiveTextSplitter": {
        "source": "huggingface",  ## choose "tiktoken" to use OpenAI's method; if left empty, splitting falls back to character-length chunking.
        "tokenizer_name_or_path": "",  ## if left empty, the LLM's own tokenizer is used by default.
    }
}
```
If you choose to use OpenAI's Embedding model, write the model's `key` into `embedding_model_dict`. To use this model you need access to the official OpenAI API, or a proxy.
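For reference, a minimal sketch with a placeholder key, following how the project reads the OpenAI key for `text-embedding-ada-002` from `embedding_model_dict`:
```python
embedding_model_dict = {
    "text-embedding-ada-002": "sk-your-openai-api-key",  # placeholder API key
}
EMBEDDING_MODEL = "text-embedding-ada-002"
```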
### 4. Knowledge Base Initialization and Migration
@ -229,7 +282,7 @@ embedding_model_dict = {
$ python init_database.py --recreate-vs
```
### 5. Launch the API Service or Web UI with One Command
#### 5.1 Launch Command
@ -307,139 +360,12 @@ CUDA_VISIBLE_DEVICES=0,1 python startup.py -a
2. Example WebUI screens
- Web UI chat interface:
![img](img/webui_0813_0.png)
- Web UI knowledge base management page:
![](img/webui_0813_1.png)
### 6. Launch the API Service or Web UI Step by Step
![img](img/webui_0915_0.png)
Note: if you used the one-command launch above, you can skip this section.
#### 6.1 Launch the LLM Service
For local deployment with open-source models, start the LLM service first. There are three ways to do so:
- [Launch the LLM service with the multi-process script llm_api.py](README.md#5.1.1-基于多进程脚本-llm_api.py-启动-LLM-服务)
- [Launch the LLM service with the command-line script llm_api_stale.py](README.md#5.1.2-基于命令行脚本-llm_api_stale.py-启动-LLM-服务)
- [PEFT loading](README.md#5.1.3-PEFT-加载)
Choose only one of the three; see 6.1.1 - 6.1.3 below for details.
If you use an online API service (such as OpenAI's API), you do not need to start the LLM service, i.e. none of the commands in section 6.1 need to be run.
##### 6.1.1 Launch the LLM service with the multi-process script llm_api.py
From the project root, run the [server/llm_api.py](server/llm_api.py) script to start the **LLM model** service:
```shell
$ python server/llm_api.py
```
The project supports multi-GPU loading; modify the following three parameters in the create_model_worker_app function of llm_api.py:
```python
gpus=None,
num_gpus=1,
max_gpu_memory="20GiB"
```
Here, `gpus` controls the IDs of the GPUs to use, e.g. "0,1";
`num_gpus` controls the number of GPUs to use;
`max_gpu_memory` controls the amount of GPU memory used on each card.
##### 6.1.2 Launch the LLM service with the command-line script llm_api_stale.py
⚠️ **Note:**
**1. The llm_api_stale.py script natively works only on Linux; macOS users need to install the corresponding Linux commands, and Windows users should use WSL;**
**2. To load a non-default model you must specify it with the command-line argument --model-path-address; the model_config.py configuration is not read;**
From the project root, run the [server/llm_api_stale.py](server/llm_api_stale.py) script to start the **LLM model** service:
```shell
$ python server/llm_api_stale.py
```
This method supports launching multiple workers; example:
```shell
$ python server/llm_api_stale.py --model-path-address model1@host1@port1 model2@host2@port2
```
If the server port is already occupied, specify the server port manually and update base_api_url of the corresponding model in model_config.py to the same port:
```shell
$ python server/llm_api_stale.py --server-port 8887
```
To launch with multiple GPUs, an example command is:
```shell
$ python server/llm_api_stale.py --gpus 0,1 --num-gpus 2 --max-gpu-memory 10GiB
```
Launching the LLM service this way runs the FastChat services in the background via nohup; to stop them, run:
```shell
$ python server/llm_api_shutdown.py --serve all
```
You can also stop a single FastChat service module; the options are [`all`, `controller`, `model_worker`, `openai_api_server`].
##### 6.1.3 PEFT loading (including LoRA, P-Tuning, prefix tuning, prompt tuning, IA3, etc.)
This project loads the LLM service through FastChat, so PEFT paths must be loaded the FastChat way: the path name must contain the word peft, the configuration file must be named adapter_config.json, and the peft path must contain the PEFT weights in model.bin format.
For detailed steps see [Model invalid after loading LoRA fine-tuning](https://github.com/chatchat-space/Langchain-Chatchat/issues/1130#issuecomment-1685291822).
![image](https://github.com/chatchat-space/Langchain-Chatchat/assets/22924096/4e056c1c-5c4b-4865-a1af-859cd58a625d)
#### 6.2 Launch the API Service
For local deployment, **after starting the LLM service** as in [section 6.1](README.md#5.1-启动-LLM-服务), run the [server/api.py](server/api.py) script to start the **API** service;
When calling an online API service, run the [server/api.py](server/api.py) script directly to start the **API** service;
Example command:
```shell
$ python server/api.py
```
After the API service starts, you can open the FastAPI-generated docs at `localhost:7861` or `{IP of the API host}:7861` to inspect and test the endpoints.
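The endpoints can also be called programmatically; a minimal sketch using `httpx` (assuming the API service runs on the default port 7861) that calls the `/llm_model/list_models` endpoint registered in server/api.py:
```python
import httpx

# List the LLM models currently loaded by the FastChat controller.
resp = httpx.post("http://localhost:7861/llm_model/list_models")
print(resp.json())  # a BaseResponse-style payload, e.g. {"code": ..., "msg": ..., "data": [...]}
```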
- FastAPI docs interface
![](img/fastapi_docs_020_0.png)
#### 6.3 Launch the Web UI Service
**After starting the API service** as in [section 6.2](README.md#5.2-启动-API-服务), run [webui.py](webui.py) to start the **Web UI** service (default port `8501`):
```shell
$ streamlit run webui.py
```
To start the **Web UI** service with the Langchain-Chatchat theme colors (default port `8501`):
```shell
$ streamlit run webui.py --theme.base "light" --theme.primaryColor "#165dff" --theme.secondaryBackgroundColor "#f5f5f5" --theme.textColor "#000000"
```
Or start the **Web UI** service on a specified port:
```shell
$ streamlit run webui.py --server.port 666
```
- Web UI chat interface:
![](img/webui_0813_0.png)
- Web UI knowledge base management page:
![](img/webui_0813_1.png)
![](img/webui_0915_1.png)
---

363
README_en.md Normal file

@ -0,0 +1,363 @@
![](img/logo-long-chatchat-trans-v2.png)
**LangChain-Chatchat** (formerly Langchain-ChatGLM): an LLM application that implements knowledge-based and search-engine-based Q&A using Langchain and open-source or remote LLM APIs.
## Content
* Introduction
* Change Log
* Docker Deployment
* Deployment
* Environment Prerequisites
* Preparing the Deployment Environment
* Downloading Models to Local Disk (for offline deployment only)
* Setting Configuration
* Knowledge Base Migration
* Launching the API Service or WebUI with One Command
* Launching the API Service or WebUI Step by Step
* FAQ
* Roadmap
* Wechat Group
---
## Introduction
🤖️ A Q&A application over a local knowledge base, built on the ideas of [langchain](https://github.com/hwchase17/langchain). The goal is a KBQA (knowledge-based Q&A) solution that is friendly to Chinese scenarios and open-source models and can run both offline and online.
💡 Inspired by [document.ai](https://github.com/GanymedeNil/document.ai) and [ChatGLM-6B Pull Request](https://github.com/THUDM/ChatGLM-6B/pull/216), we built a local knowledge base question-answering application whose full pipeline can run on open-source models or remote LLM APIs. In the latest version of this project, [FastChat](https://github.com/lm-sys/FastChat) is used to access Vicuna, Alpaca, LLaMA, Koala, RWKV and many other models. Relying on [langchain](https://github.com/langchain-ai/langchain), this project supports calling services through an API based on [FastAPI](https://github.com/tiangolo/fastapi), or using a WebUI based on [Streamlit](https://github.com/streamlit/streamlit).
✅ Relying on open-source LLM and Embedding models, this project can run as a full-process **offline private deployment**. It also supports calling the OpenAI GPT API and Zhipu API, and will continue to expand access to various models and remote APIs in the future.
⛓️ The implementation principle of this project is shown in the graph below. The main process is: load files -> read text -> split text -> vectorize text -> vectorize the question -> match the top-k text vectors most similar to the question vector -> add the matched text to the prompt as context together with the question -> submit to the LLM to generate an answer.
📺 [Video introduction](https://www.bilibili.com/video/BV13M4y1e7cN/?share_source=copy_web&vd_source=e6c5aafe684f30fbe41925d61ca6d514)
![实现原理图](img/langchain+chatglm.png)
The main process, viewed from the document-processing perspective:
![实现原理图2](img/langchain+chatglm2.png)
🚩 Training and fine-tuning are not part of this project, but performance can always be improved by doing them.
🌐 An [AutoDL image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.0) is available; in v7 the code has been updated to v0.2.3.
🐳 [Docker image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.0)
💻 Run Docker with one command:
```shell
docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.0
```
---
## Change Log
Please refer to the [version change log](https://github.com/imClumsyPanda/langchain-ChatGLM/releases).
### Current Features
* **Consistent LLM service based on FastChat**. The project uses [FastChat](https://github.com/lm-sys/FastChat) to serve open-source LLM models through an OpenAI-compatible API interface, improving the LLM loading experience;
* **Chains and Agents based on Langchain**. Reuses the existing Chain implementations in [langchain](https://github.com/langchain-ai/langchain) to ease access to different Chain types later; Agent access will be tested;
* **Full-function API service based on FastAPI**. All interfaces can be tested in the docs automatically generated by [FastAPI](https://github.com/tiangolo/fastapi), and all dialogue interfaces support streaming or non-streaming output via parameters;
* **WebUI service based on Streamlit**. With [Streamlit](https://github.com/streamlit/streamlit), you can choose whether to start the WebUI on top of the API service, add session management, and customize and switch session themes; more output display formats will be supported in the future;
* **Abundant open-source LLM and Embedding models**. The default LLM model has been changed to [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) and the default Embedding model to [moka-ai/m3e-base](https://huggingface.co/moka-ai/m3e-base); the file loading and paragraph splitting methods have also been adjusted. Context expansion will be re-implemented with optional settings in the future;
* **Multiple vector stores**. The project has expanded support for different types of vector stores, including [FAISS](https://github.com/facebookresearch/faiss), [Milvus](https://github.com/milvus-io/milvus), and [PGVector](https://github.com/pgvector/pgvector);
* **Varied search engines**. Two search engines are provided: Bing and DuckDuckGo. DuckDuckGo does not require an API Key and can be used directly in environments with access to foreign services.
## Supported Models
The default LLM model in the project has been changed to [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b), and the default Embedding model to [moka-ai/m3e-base](https://huggingface.co/moka-ai/m3e-base).
### Supported LLM models
The project uses [FastChat](https://github.com/lm-sys/FastChat) to provide the API service for open-source LLM models; supported models include:
- [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
- Vicuna, Alpaca, LLaMA, Koala
- [BlinkDL/RWKV-4-Raven](https://huggingface.co/BlinkDL/rwkv-4-raven)
- [camel-ai/CAMEL-13B-Combined-Data](https://huggingface.co/camel-ai/CAMEL-13B-Combined-Data)
- [databricks/dolly-v2-12b](https://huggingface.co/databricks/dolly-v2-12b)
- [FreedomIntelligence/phoenix-inst-chat-7b](https://huggingface.co/FreedomIntelligence/phoenix-inst-chat-7b)
- [h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b](https://huggingface.co/h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b)
- [lcw99/polyglot-ko-12.8b-chang-instruct-chat](https://huggingface.co/lcw99/polyglot-ko-12.8b-chang-instruct-chat)
- [lmsys/fastchat-t5-3b-v1.0](https://huggingface.co/lmsys/fastchat-t5)
- [mosaicml/mpt-7b-chat](https://huggingface.co/mosaicml/mpt-7b-chat)
- [Neutralzz/BiLLa-7B-SFT](https://huggingface.co/Neutralzz/BiLLa-7B-SFT)
- [nomic-ai/gpt4all-13b-snoozy](https://huggingface.co/nomic-ai/gpt4all-13b-snoozy)
- [NousResearch/Nous-Hermes-13b](https://huggingface.co/NousResearch/Nous-Hermes-13b)
- [openaccess-ai-collective/manticore-13b-chat-pyg](https://huggingface.co/openaccess-ai-collective/manticore-13b-chat-pyg)
- [OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5](https://huggingface.co/OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5)
- [project-baize/baize-v2-7b](https://huggingface.co/project-baize/baize-v2-7b)
- [Salesforce/codet5p-6b](https://huggingface.co/Salesforce/codet5p-6b)
- [StabilityAI/stablelm-tuned-alpha-7b](https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b)
- [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)
- [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b)
- [tiiuae/falcon-40b](https://huggingface.co/tiiuae/falcon-40b)
- [timdettmers/guanaco-33b-merged](https://huggingface.co/timdettmers/guanaco-33b-merged)
- [togethercomputer/RedPajama-INCITE-7B-Chat](https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Chat)
- [WizardLM/WizardLM-13B-V1.0](https://huggingface.co/WizardLM/WizardLM-13B-V1.0)
- [WizardLM/WizardCoder-15B-V1.0](https://huggingface.co/WizardLM/WizardCoder-15B-V1.0)
- [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
- [internlm/internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b)
- [Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat)
- [HuggingFaceH4/starchat-beta](https://huggingface.co/HuggingFaceH4/starchat-beta)
- [FlagAlpha/Llama2-Chinese-13b-Chat](https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat) and other models of FlagAlpha
- [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B)
- [all models of OpenOrca](https://huggingface.co/Open-Orca)
- [Spicyboros](https://huggingface.co/jondurbin/spicyboros-7b-2.2?not-for-all-audiences=true) + [airoboros 2.2](https://huggingface.co/jondurbin/airoboros-l2-13b-2.2)
- [VMware's OpenLLaMa OpenInstruct](https://huggingface.co/VMware/open-llama-7b-open-instruct)
- Any [EleutherAI](https://huggingface.co/EleutherAI) pythia model, such as [pythia-6.9b](https://huggingface.co/EleutherAI/pythia-6.9b)
- Any [Peft](https://github.com/huggingface/peft) adapter trained on top of a model above. To activate it, the model path must contain `peft`. Note: if loading multiple PEFT models, you can have them share the base model weights by setting the environment variable `PEFT_SHARE_BASE_WEIGHTS=true` in any model worker.
Please refer to `llm_model_dict` in `configs/model_config.py.example` to invoke the OpenAI API.
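For illustration, a minimal sketch of the OpenAI entry in `llm_model_dict` (the key and proxy values are placeholders):
```python
llm_model_dict = {
    "gpt-3.5-turbo": {
        "api_base_url": "https://api.openai.com/v1",
        "api_key": "sk-your-openai-api-key",  # placeholder
        "openai_proxy": "",                   # e.g. "http://127.0.0.1:4780" if you need a proxy
    },
}
```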
### Supported Embedding models
The following models have been tested by the developers with the Embedding class of [HuggingFace](https://huggingface.co/models?pipeline_tag=sentence-similarity):
- [moka-ai/m3e-small](https://huggingface.co/moka-ai/m3e-small)
- [moka-ai/m3e-base](https://huggingface.co/moka-ai/m3e-base)
- [moka-ai/m3e-large](https://huggingface.co/moka-ai/m3e-large)
- [BAAI/bge-small-zh](https://huggingface.co/BAAI/bge-small-zh)
- [BAAI/bge-base-zh](https://huggingface.co/BAAI/bge-base-zh)
- [BAAI/bge-large-zh](https://huggingface.co/BAAI/bge-large-zh)
- [BAAI/bge-large-zh-noinstruct](https://huggingface.co/BAAI/bge-large-zh-noinstruct)
- [shibing624/text2vec-base-chinese-sentence](https://huggingface.co/shibing624/text2vec-base-chinese-sentence)
- [shibing624/text2vec-base-chinese-paraphrase](https://huggingface.co/shibing624/text2vec-base-chinese-paraphrase)
- [shibing624/text2vec-base-multilingual](https://huggingface.co/shibing624/text2vec-base-multilingual)
- [shibing624/text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese)
- [shibing624/text2vec-bge-large-chinese](https://huggingface.co/shibing624/text2vec-bge-large-chinese)
- [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese)
- [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh)
- [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh)
- [OpenAI/text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings)
---
## Docker image
🐳 Docker image path: `registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.0`
```shell
docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.0
```
- The image size of this version is `33.9GB`, using `v0.2.0`, with `nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04` as the base image
- This version ships with a built-in `embedding` model (`m3e-large`) and a built-in `chatglm2-6b-32k`
- This version is designed to facilitate one-click deployment. Please make sure you have installed the NVIDIA driver on your Linux distribution.
- Please note that you do not need to install the CUDA toolkit on the host system, but you do need to install the `NVIDIA Driver` and the `NVIDIA Container Toolkit`; please refer to the [Installation Guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
- Pulling and starting the image takes some time on the first run. On first start, refer to the figure below and use `docker logs -f <container id>` to view the log.
- If the startup process gets stuck at the `Waiting..` step, it is recommended to use `docker exec -it <container id> bash` to enter the container and check the stage logs under the `/logs/` directory
---
## Deployment
### Environment Prerequisites
The project has been tested with Python 3.8 - 3.10 and CUDA 11.0 - 11.7 on Windows, ARM-based macOS, and Linux.
### 1. Preparing the Deployment Environment
Please refer to [install.md](docs/INSTALL.md)
### 2. Downloading Models to Local Disk
**For offline deployment only!**
If you want to run this project in a local or offline environment, you need to first download the models required for the project to your local computer. Usually the open source LLM and Embedding models can be downloaded from [HuggingFace](https://huggingface.co/models).
Take the LLM model [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) and the Embedding model [moka-ai/m3e-base](https://huggingface.co/moka-ai/m3e-base) as examples:
To download the model, you need to [install Git LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage), and then run:
```Shell
$ git clone https://huggingface.co/THUDM/chatglm2-6b
$ git clone https://huggingface.co/moka-ai/m3e-base
```
### 3. Setting Configuration
Copy the model-related parameter configuration template file [configs/model_config.py.example](configs/model_config.py.example) into the `./configs` directory under the project path and rename it to `model_config.py`.
Copy the service-related parameter configuration template file [configs/server_config.py.example](configs/server_config.py.example) into the `./configs` directory under the project path and rename it to `server_config.py`.
Before starting the Web UI or command-line interaction, please check that each model parameter in `configs/model_config.py` and `configs/server_config.py` meets your requirements.
* Please confirm that the paths to the local LLM model and Embedding model have been written into `llm_model_dict` and `embedding_model_dict` of `configs/model_config.py`; here is an example:
* If you choose to use OpenAI's Embedding model, please write the model's `key` into `embedding_model_dict`. To use this model, you need to be able to access the official OpenAI API, or set up a proxy.
```python
llm_model_dict = {
    "chatglm2-6b": {
        "local_model_path": "/Users/xxx/Downloads/chatglm2-6b",
        "api_base_url": "http://localhost:8888/v1",  # set to the "api_base_url" of the FastChat service
        "api_key": "EMPTY"
    },
}
```
```python
embedding_model_dict = {
    "m3e-base": "/Users/xxx/Downloads/m3e-base",
}
```
### 4. Knowledge Base Migration
The knowledge base information is stored in a database; please initialize the database before running the project (we strongly recommend backing up your knowledge files before performing any operations).
- If you are migrating from `0.1.x`, for an existing knowledge base please confirm that its vector store type and Embedding model match the default settings in `configs/model_config.py`. If nothing has changed, simply add the existing repository information to the database with the following command:
```shell
$ python init_database.py
```
- If you are starting fresh and have not created a knowledge base yet, or the knowledge base type and Embedding model in the configuration file have changed, or the previous vector store did not enable `normalize_L2`, you need the following command to initialize or rebuild the knowledge base:
```shell
$ python init_database.py --recreate-vs
```
### 5. Launching the API Service or WebUI with One Command
#### 5.1 Command
The script is `startup.py`; you can launch all FastChat-related, API, and WebUI services with it. Here is an example:
```shell
$ python startup.py -a
```
Optional args include: `-a (or --all-webui), --all-api, --llm-api, -c (or --controller), --openai-api, -m (or --model-worker), --api, --webui`, where:
* `--all-webui` means to launch all related services of WEBUI
* `--all-api` means to launch all related services of API
* `--llm-api` means to launch all related services of FastChat
* `--openai-api` means to launch controller and openai-api-server of FastChat only
* `--model-worker` means to launch only the model worker of FastChat
* any other optional arg is to launch one particular function only
#### 5.2 Launching a Non-default Model
If you want to specify a non-default model, use the `--model-name` arg; here is an example:
```shell
$ python startup.py --all-webui --model-name Qwen-7B-Chat
```
#### 5.3 Loading a Model across Multiple GPUs
If you want to load a model across multiple GPUs, change the following three parameters in `startup.create_model_worker_app`:
```python
gpus=None,
num_gpus=1,
max_gpu_memory="20GiB"
```
where:
* `gpus` specifies the GPU IDs to use, such as '0,1';
* `num_gpus` specifies how many of those GPUs to use;
* `max_gpu_memory` specifies the GPU memory to use on each GPU.
Note:
* These parameters can now also be specified via `server_config.FSCHAT_MODEL_WORKERS`.
* In some edge cases `gpus` does not take effect; in that case specify the GPUs to use with the environment variable `CUDA_VISIBLE_DEVICES`, for example:
```shell
CUDA_VISIBLE_DEVICES=0,1 python startup.py -a
```
#### 5.4 Loading PEFT Adapters
Including LoRA, P-Tuning, prefix tuning, prompt tuning, IA3, etc.
This project loads the LLM service through FastChat, so PEFT adapters must be loaded the FastChat way: the path name must contain the word `peft`, the configuration file must be named `adapter_config.json`, and the path must contain PEFT weights in `.bin` format. The peft path is specified in `args.model_names` of the `create_model_worker_app` function in `startup.py`, and the environment variable `PEFT_SHARE_BASE_WEIGHTS=true` must be enabled.
If the above method fails, start the standard FastChat services step by step as described in Section 6. For further steps, please refer to [Model invalid after loading lora fine-tuning](https://github.com/chatchat-space/Langchain-Chatchat/issues/1130#issuecomment-1685291822).
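For reference, a minimal sketch of the setup described above; the adapter path and model name are hypothetical:
```python
import os
import subprocess

# Hypothetical adapter directory: FastChat requires "peft" in the path name,
# an adapter_config.json inside it, and the PEFT weights in .bin format.
peft_adapter_path = "/data/models/chatglm2-6b-peft-lora"
assert "peft" in peft_adapter_path

# Let multiple PEFT adapters share the base model weights in the model worker.
os.environ["PEFT_SHARE_BASE_WEIGHTS"] = "true"

# Launch the services; the adapter is referenced through args.model_names of
# create_model_worker_app, analogous to the --model-name usage in section 5.2.
subprocess.run(["python", "startup.py", "-a", "--model-name", "chatglm2-6b-peft-lora"])
```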
#### **5.5 Some Notes**
1. **The `startup.py` script uses multi-process mode to start the services of each module, which may cause printing-order problems. Please wait until all services have started before making calls, and call each service on its default or specified port (default LLM API service port: `127.0.0.1:8888`, default API service port: `127.0.0.1:7861`, default WebUI service port: `127.0.0.1:8501`)**
2. **The startup time of the service differs across devices, usually it takes 3-10 minutes. If it does not start for a long time, please go to the `./logs` directory to monitor the logs and locate the problem.**
3. **Using ctrl+C to exit on Linux may cause orphan processes due to the multi-process mechanism of Linux. You can exit through `shutdown_all.sh`**
#### 5.6 Interface Examples
The API, the WebUI chat interface, and the WebUI knowledge management interface are shown below, respectively.
1. FastAPI docs
![](img/fastapi_docs_020_0.png)
2. Chat Interface of WebUI
- Dialogue interface of WebUI
![img](img/webui_0915_0.png)
- Knowledge management interface of WebUI
![img](img/webui_0915_1.png)
### 6. Launching the API Service or WebUI Step by Step
**The developers will deprecate the step-by-step procedure within the next one or two releases; feel free to ignore this part.**
## FAQ
Please refer to [FAQ](docs/FAQ.md)
---
## Roadmap
- [X] Langchain applications
- [X] Load local documents
- [X] Unstructured documents
- [X] .md
- [X] .txt
- [X] .docx
- [ ] Structured documents
- [X] .csv
- [ ] .xlsx
- [ ] TextSplitter and Retriever
- [ ] Multiple TextSplitters
- [ ] ChineseTextSplitter
- [ ] Reconstructed Context Retriever
- [ ] Webpage
- [ ] SQL
- [ ] Knowledge Database
- [X] Search Engines
- [X] Bing
- [X] DuckDuckGo
- [ ] Agent
- [X] LLM Models
- [X] [FastChat](https://github.com/lm-sys/fastchat)-based LLM models
- [ ] Multiple remote LLM APIs
- [X] Embedding Models
- [X] HuggingFace-based Embedding models
- [ ] Multiple remote Embedding APIs
- [X] FastAPI-based API
- [X] Web UI
- [X] Streamlit-based Web UI
---
## WeChat Group QR Code
<img src="img/qr_code_59.jpg" alt="二维码" width="300" height="300" />
**WeChat Group**


@ -1,4 +1,4 @@
from .model_config import *
from .server_config import *
VERSION = "v0.2.4-preview"
VERSION = "v0.2.4"


@ -5,6 +5,8 @@ LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(mes
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT)
# Whether to print verbose logs
log_verbose = False
# Modify the values in the following dict to specify where local embedding models are stored
@ -73,19 +75,43 @@ llm_model_dict = {
# e.g.: "openai_proxy": 'http://127.0.0.1:4780'
"gpt-3.5-turbo": {
"api_base_url": "https://api.openai.com/v1",
"api_key": os.environ.get("OPENAI_API_KEY"),
"openai_proxy": os.environ.get("OPENAI_PROXY")
"api_key": "",
"openai_proxy": ""
},
# Online models. Zhipu AI is currently supported.
# If no valid local_model_path is set, the entry is treated as an online model API.
# Please set a different port for each online API in server_config
# For registration and API key, see http://open.bigmodel.cn
"chatglm-api": {
"zhipu-api": {
"api_base_url": "http://127.0.0.1:8888/v1",
"api_key": os.environ.get("ZHIPUAI_API_KEY"),
"api_key": "",
"provider": "ChatGLMWorker",
"version": "chatglm_pro", # 可选包括 "chatglm_lite", "chatglm_std", "chatglm_pro"
},
"minimax-api": {
"api_base_url": "http://127.0.0.1:8888/v1",
"group_id": "",
"api_key": "",
"is_pro": False,
"provider": "MiniMaxWorker",
},
"xinghuo-api": {
"api_base_url": "http://127.0.0.1:8888/v1",
"APPID": "",
"APISecret": "",
"api_key": "",
"is_v2": False,
"provider": "XingHuoWorker",
},
# Baidu Qianfan API; for how to apply, see https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf
"qianfan-api": {
"version": "ernie-bot",  # currently supports "ernie-bot" or "ernie-bot-turbo"; see the Qianfan part of the supported-model list in the docs for more
"version_url": "",  # instead of version, you may directly fill in the API URL of the model you published on Qianfan
"api_base_url": "http://127.0.0.1:8888/v1",
"api_key": "",
"secret_key": "",
"provider": "ErnieWorker",
}
}
# LLM name
@ -94,9 +120,51 @@ LLM_MODEL = "chatglm2-6b"
# Number of history dialogue turns
HISTORY_LEN = 3
# General LLM chat parameters
TEMPERATURE = 0.7
# TOP_P = 0.95  # not yet supported by ChatOpenAI
# Device to run the LLM on. "auto" detects automatically; it can also be set to "cuda", "mps" or "cpu".
LLM_DEVICE = "auto"
# TextSplitter
text_splitter_dict = {
"ChineseRecursiveTextSplitter": {
"source": "",
"tokenizer_name_or_path": "",
},
"SpacyTextSplitter": {
"source": "huggingface",
"tokenizer_name_or_path": "gpt2",
},
"RecursiveCharacterTextSplitter": {
"source": "tiktoken",
"tokenizer_name_or_path": "cl100k_base",
},
"MarkdownHeaderTextSplitter": {
"headers_to_split_on":
[
("#", "head1"),
("##", "head2"),
("###", "head3"),
("####", "head4"),
]
},
}
# TEXT_SPLITTER name
TEXT_SPLITTER = "ChineseRecursiveTextSplitter"
# Length of a single text chunk in the knowledge base (does not apply to MarkdownHeaderTextSplitter)
CHUNK_SIZE = 250
# Overlap length between adjacent chunks (does not apply to MarkdownHeaderTextSplitter)
OVERLAP_SIZE = 0
# Log storage path
LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
if not os.path.exists(LOG_PATH):
@ -104,12 +172,14 @@ if not os.path.exists(LOG_PATH):
# Default knowledge base storage path
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
if not os.path.exists(KB_ROOT_PATH):
os.mkdir(KB_ROOT_PATH)
# Default database storage path.
# If using sqlite, just modify DB_ROOT_PATH; if using another database, modify SQLALCHEMY_DATABASE_URI directly.
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"
# Available vector store types and their configuration
kbs_config = {
"faiss": {
@ -132,20 +202,14 @@ DEFAULT_VS_TYPE = "faiss"
# Number of cached vector stores
CACHED_VS_NUM = 1
# Length of a single text chunk in the knowledge base
CHUNK_SIZE = 250
# Overlap length between adjacent chunks
OVERLAP_SIZE = 50
# Number of matched vectors from the knowledge base
VECTOR_SEARCH_TOP_K = 5
VECTOR_SEARCH_TOP_K = 3
# Knowledge base relevance threshold, range 0-1; the smaller the SCORE the higher the relevance; 1 means no filtering; around 0.5 is recommended
SCORE_THRESHOLD = 1
# Number of matched search engine results
SEARCH_ENGINE_TOP_K = 5
SEARCH_ENGINE_TOP_K = 3
# nltk model storage path
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")


@ -59,13 +59,23 @@ FSCHAT_MODEL_WORKERS = {
# "limit_worker_concurrency": 5,
# "stream_interval": 2,
# "no_register": False,
# "embed_in_truncate": False,
},
"baichuan-7b": { # 使用default中的IP和端口
"device": "cpu",
},
"chatglm-api": { # 请为每个在线API设置不同的端口
"zhipu-api": { # 请为每个在线API设置不同的端口
"port": 20003,
},
"minimax-api": { # 请为每个在线API设置不同的端口
"port": 20004,
},
"xinghuo-api": { # 请为每个在线API设置不同的端口
"port": 20005,
},
"qianfan-api": {
"port": 20006,
},
}
# fastchat multi model worker server

24
docs/splitter.md Normal file

@ -0,0 +1,24 @@
## How to Customize a Text Splitter
### Where to write it and which files to change
1. Create a new file under the `text_splitter` folder named after your splitter, e.g. `my_splitter.py`, then import your splitter in `__init__.py`, as shown below:
```python
from .my_splitter import MySplitter
```
2. Modify `configs/model_config.py` and add your splitter's name to `text_splitter_dict`, as shown below:
```python
"MySplitter": {
    "source": "huggingface",  ## choose "tiktoken" to use OpenAI's method
    "tokenizer_name_or_path": "your tokenizer",  # with "huggingface", the HuggingFace method is used; some tokenizers need to be downloaded from HuggingFace
}
TEXT_SPLITTER = "MySplitter"
```
After completing the steps above, you can use your own splitter.
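For illustration, a minimal sketch of what `text_splitter/my_splitter.py` might contain (the class name and splitting rule are hypothetical); like the built-in custom splitters, it extends a Langchain splitter class:
```python
from typing import List
from langchain.text_splitter import CharacterTextSplitter


class MySplitter(CharacterTextSplitter):
    """A hypothetical splitter that cuts text on Chinese full stops."""

    def split_text(self, text: str) -> List[str]:
        # Split on the Chinese full stop and drop empty fragments.
        return [chunk.strip() for chunk in text.split("。") if chunk.strip()]
```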
### How to contribute your splitter
1. Put your splitter's code file under the `text_splitter` folder, named after your splitter, e.g. `my_splitter.py`, then import it in `__init__.py`.
2. Open a PR explaining the scenario your splitter targets or the improvement it makes; a concrete example use case is very welcome.
3. Add usage and support notes for your splitter to README.md.


@ -1,5 +1,6 @@
from typing import List
from langchain.document_loaders.unstructured import UnstructuredFileLoader
import tqdm
class RapidOCRPDFLoader(UnstructuredFileLoader):
@ -11,7 +12,14 @@ class RapidOCRPDFLoader(UnstructuredFileLoader):
ocr = RapidOCR()
doc = fitz.open(filepath)
resp = ""
for page in doc:
b_unit = tqdm.tqdm(total=doc.page_count, desc="RapidOCRPDFLoader context page index: 0")
for i, page in enumerate(doc):
# update the progress bar description
b_unit.set_description("RapidOCRPDFLoader context page index: {}".format(i))
# refresh so the progress bar update is shown immediately
b_unit.refresh()
# TODO: adjust processing according to the order of text and images
text = page.get_text("")
resp += text + "\n"
@ -24,6 +32,9 @@ class RapidOCRPDFLoader(UnstructuredFileLoader):
if result:
ocr_result = [line[1] for line in result]
resp += "\n".join(ocr_result)
# advance the progress bar
b_unit.update(1)
return resp
text = pdf2text(self.file_path)

BIN img/chatchat-qrcode.jpg (new file, 27 KiB)
BIN img/webui_0915_0.png (new file, 237 KiB)
BIN img/webui_0915_1.png (new file, 170 KiB)
(Several other binary image files were added or removed; binary files not shown.)


@ -1,8 +1,7 @@
langchain==0.0.266
langchain==0.0.287
fschat[model_worker]==0.2.28
openai
zhipuai
sentence_transformers
fschat==0.2.24
transformers>=4.31.0
torch~=2.0.0
fastapi~=0.99.1
@ -19,6 +18,12 @@ spacy
PyMuPDF==1.22.5
rapidocr_onnxruntime>=1.3.2
requests
pathlib
pytest
scikit-learn
numexpr
# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3
# psycopg2
@ -33,3 +38,5 @@ streamlit-chatbox>=1.1.6
streamlit-aggrid>=0.3.4.post3
httpx~=0.24.1
watchdog
tqdm
websockets


@ -1,7 +1,7 @@
langchain==0.0.266
langchain==0.0.287
fschat[model_worker]==0.2.28
openai
sentence_transformers
fschat==0.2.24
transformers>=4.31.0
torch~=2.0.0
fastapi~=0.99.1
@ -9,17 +9,22 @@ nltk~=3.8.1
uvicorn~=0.23.1
starlette~=0.27.0
pydantic~=1.10.11
unstructured[all-docs]
unstructured[all-docs]>=0.10.4
python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19
faiss-cpu
nltk
accelerate
spacy
PyMuPDF==1.22.5
rapidocr_onnxruntime>=1.3.1
rapidocr_onnxruntime>=1.3.2
requests
pathlib
pytest
scikit-learn
numexpr
# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3
# psycopg2
# pgvector
# pgvector


@ -7,4 +7,5 @@ streamlit-chatbox>=1.1.6
streamlit-aggrid>=0.3.4.post3
httpx~=0.24.1
nltk
watchdog
watchdog
websockets


@ -4,22 +4,21 @@ import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import LLM_MODEL, NLTK_DATA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN, HTTPX_DEFAULT_TIMEOUT
from configs import VERSION
from configs.model_config import NLTK_DATA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN
import argparse
import uvicorn
from fastapi import Body
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat,
search_engine_chat)
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc,
update_doc, download_doc, recreate_vector_store,
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
update_docs, download_doc, recreate_vector_store,
search_docs, DocumentWithScore)
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address
import httpx
from server.llm_api import list_llm_models, change_llm_model, stop_llm_model
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
from typing import List
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@ -98,23 +97,23 @@ def create_app():
summary="搜索知识库"
)(search_docs)
app.post("/knowledge_base/upload_doc",
app.post("/knowledge_base/upload_docs",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="上传文件到知识库"
)(upload_doc)
summary="上传文件到知识库,并/或进行向量化"
)(upload_docs)
app.post("/knowledge_base/delete_doc",
app.post("/knowledge_base/delete_docs",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="删除知识库内指定文件"
)(delete_doc)
)(delete_docs)
app.post("/knowledge_base/update_doc",
app.post("/knowledge_base/update_docs",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="更新现有文件到知识库"
)(update_doc)
)(update_docs)
app.get("/knowledge_base/download_doc",
tags=["Knowledge Base Management"],
@ -126,73 +125,20 @@ def create_app():
)(recreate_vector_store)
# LLM model management endpoints
@app.post("/llm_model/list_models",
app.post("/llm_model/list_models",
tags=["LLM Model Management"],
summary="列出当前已加载的模型")
def list_models(
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
'''
从fastchat controller获取已加载模型列表
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(controller_address + "/list_models")
return BaseResponse(data=r.json()["models"])
except Exception as e:
return BaseResponse(
code=500,
data=[],
msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
summary="列出当前已加载的模型",
)(list_llm_models)
@app.post("/llm_model/stop",
app.post("/llm_model/stop",
tags=["LLM Model Management"],
summary="停止指定的LLM模型Model Worker)",
)
def stop_llm_model(
model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
'''
向fastchat controller请求停止某个LLM模型
注意由于Fastchat的实现方式实际上是把LLM模型所在的model_worker停掉
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(
controller_address + "/release_worker",
json={"model_name": model_name},
)
return r.json()
except Exception as e:
return BaseResponse(
code=500,
msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}")
)(stop_llm_model)
@app.post("/llm_model/change",
app.post("/llm_model/change",
tags=["LLM Model Management"],
summary="切换指定的LLM模型Model Worker)",
)
def change_llm_model(
model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]),
new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
):
'''
向fastchat controller请求切换LLM模型
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(
controller_address + "/release_worker",
json={"model_name": model_name, "new_model_name": new_model_name},
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
)
return r.json()
except Exception as e:
return BaseResponse(
code=500,
msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")
)(change_llm_model)
return app


@ -1,6 +1,6 @@
from fastapi import Body
from fastapi.responses import StreamingResponse
from configs.model_config import llm_model_dict, LLM_MODEL
from configs.model_config import llm_model_dict, LLM_MODEL, TEMPERATURE
from server.chat.utils import wrap_done
from langchain.chat_models import ChatOpenAI
from langchain import LLMChain
@ -12,15 +12,17 @@ from typing import List
from server.chat.utils import History
def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
history: List[History] = Body([],
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
history: List[History] = Body([],
description="历史对话",
examples=[[
{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant", "content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
):
history = [History.from_data(h) for h in history]
@ -37,6 +39,7 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成
openai_api_key=llm_model_dict[model_name]["api_key"],
openai_api_base=llm_model_dict[model_name]["api_base_url"],
model_name=model_name,
temperature=temperature,
openai_proxy=llm_model_dict[model_name].get("openai_proxy")
)


@ -1,7 +1,8 @@
from fastapi import Body, Request
from fastapi.responses import StreamingResponse
from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
TEMPERATURE)
from server.chat.utils import wrap_done
from server.utils import BaseResponse
from langchain.chat_models import ChatOpenAI
@ -18,11 +19,11 @@ from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs
def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右", ge=0, le=1),
history: List[History] = Body([],
async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右", ge=0, le=1),
history: List[History] = Body([],
description="历史对话",
examples=[[
{"role": "user",
@ -30,10 +31,11 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
{"role": "assistant",
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
request: Request = None,
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
request: Request = None,
):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
@ -48,6 +50,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
model_name: str = LLM_MODEL,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
streaming=True,
verbose=True,
@ -55,6 +58,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
openai_api_key=llm_model_dict[model_name]["api_key"],
openai_api_base=llm_model_dict[model_name]["api_base_url"],
model_name=model_name,
temperature=temperature,
openai_proxy=llm_model_dict[model_name].get("openai_proxy")
)
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
@ -86,9 +90,8 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
if stream:
async for token in callback.aiter():
# Use server-sent-events to stream the response
yield json.dumps({"answer": token,
"docs": source_documents},
ensure_ascii=False)
yield json.dumps({"answer": token}, ensure_ascii=False)
yield json.dumps({"docs": source_documents}, ensure_ascii=False)
else:
answer = ""
async for token in callback.aiter():


@ -1,7 +1,7 @@
from fastapi.responses import StreamingResponse
from typing import List
import openai
from configs.model_config import llm_model_dict, LLM_MODEL, logger
from configs.model_config import llm_model_dict, LLM_MODEL, logger, log_verbose
from pydantic import BaseModel
@ -29,13 +29,13 @@ async def openai_chat(msg: OpenAiChatMsgIn):
print(f"{openai.api_base=}")
print(msg)
def get_response(msg):
async def get_response(msg):
data = msg.dict()
try:
response = openai.ChatCompletion.create(**data)
response = await openai.ChatCompletion.acreate(**data)
if msg.stream:
for data in response:
async for data in response:
if choices := data.choices:
if chunk := choices[0].get("delta", {}).get("content"):
print(chunk, end="", flush=True)
@ -46,8 +46,9 @@ async def openai_chat(msg: OpenAiChatMsgIn):
print(answer)
yield(answer)
except Exception as e:
print(type(e))
logger.error(e)
msg = f"获取ChatCompletion时出错{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
return StreamingResponse(
get_response(msg),


@ -2,7 +2,9 @@ from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
from fastapi import Body
from fastapi.responses import StreamingResponse
from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
from fastapi.concurrency import run_in_threadpool
from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K,
PROMPT_TEMPLATE, TEMPERATURE)
from server.chat.utils import wrap_done
from server.utils import BaseResponse
from langchain.chat_models import ChatOpenAI
@ -47,29 +49,31 @@ def search_result2docs(search_results):
return docs
def lookup_search_engine(
async def lookup_search_engine(
query: str,
search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K,
):
results = SEARCH_ENGINES[search_engine_name](query, result_len=top_k)
search_engine = SEARCH_ENGINES[search_engine_name]
results = await run_in_threadpool(search_engine, query, result_len=top_k)
docs = search_result2docs(results)
return docs
def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
history: List[History] = Body([],
description="历史对话",
examples=[[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant",
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
async def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
history: List[History] = Body([],
description="历史对话",
examples=[[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant",
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
):
if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@ -93,10 +97,11 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
openai_api_key=llm_model_dict[model_name]["api_key"],
openai_api_base=llm_model_dict[model_name]["api_base_url"],
model_name=model_name,
temperature=temperature,
openai_proxy=llm_model_dict[model_name].get("openai_proxy")
)
docs = lookup_search_engine(query, search_engine_name, top_k)
docs = await lookup_search_engine(query, search_engine_name, top_k)
context = "\n".join([doc.page_content for doc in docs])
input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
@ -119,9 +124,8 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
if stream:
async for token in callback.aiter():
# Use server-sent-events to stream the response
yield json.dumps({"answer": token,
"docs": source_documents},
ensure_ascii=False)
yield json.dumps({"answer": token}, ensure_ascii=False)
yield json.dumps({"docs": source_documents}, ensure_ascii=False)
else:
answer = ""
async for token in callback.aiter():


@ -2,6 +2,7 @@ import asyncio
from typing import Awaitable, List, Tuple, Dict, Union
from pydantic import BaseModel, Field
from langchain.prompts.chat import ChatMessagePromptTemplate
from configs import logger, log_verbose
async def wrap_done(fn: Awaitable, event: asyncio.Event):
@ -10,7 +11,9 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
await fn
except Exception as e:
# TODO: handle exception
print(f"Caught exception: {e}")
msg = f"Caught exception: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
finally:
# Signal the aiter to stop.
event.set()


@ -3,19 +3,19 @@ from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_base_repository import list_kbs_from_db
from configs.model_config import EMBEDDING_MODEL
from configs.model_config import EMBEDDING_MODEL, logger, log_verbose
from fastapi import Body
async def list_kbs():
def list_kbs():
# Get List of Knowledge Base
return ListResponse(data=list_kbs_from_db())
async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
vector_store_type: str = Body("faiss"),
embed_model: str = Body(EMBEDDING_MODEL),
) -> BaseResponse:
def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
vector_store_type: str = Body("faiss"),
embed_model: str = Body(EMBEDDING_MODEL),
) -> BaseResponse:
# Create selected knowledge base
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -30,14 +30,16 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
try:
kb.create_kb()
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"创建知识库出错: {e}")
msg = f"创建知识库出错: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
return BaseResponse(code=500, msg=msg)
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
async def delete_kb(
knowledge_base_name: str = Body(..., examples=["samples"])
def delete_kb(
knowledge_base_name: str = Body(..., examples=["samples"])
) -> BaseResponse:
# Delete selected knowledge base
if not validate_kb_name(knowledge_base_name):
@ -55,7 +57,9 @@ async def delete_kb(
if status:
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"删除知识库时出现意外: {e}")
msg = f"删除知识库时出现意外: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
return BaseResponse(code=500, msg=msg)
return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")


@ -0,0 +1,137 @@
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
import threading
from configs.model_config import (CACHED_VS_NUM, EMBEDDING_MODEL, CHUNK_SIZE,
embedding_model_dict, logger, log_verbose)
from server.utils import embedding_device
from contextlib import contextmanager
from collections import OrderedDict
from typing import List, Any, Union, Tuple
class ThreadSafeObject:
def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None):
self._obj = obj
self._key = key
self._pool = pool
self._lock = threading.RLock()
self._loaded = threading.Event()
def __repr__(self) -> str:
cls = type(self).__name__
return f"<{cls}: key: {self._key}, obj: {self._obj}>"
@contextmanager
def acquire(self, owner: str = "", msg: str = ""):
owner = owner or f"thread {threading.get_native_id()}"
try:
self._lock.acquire()
if self._pool is not None:
self._pool._cache.move_to_end(self._key)
if log_verbose:
logger.info(f"{owner} 开始操作:{self._key}{msg}")
yield self._obj
finally:
if log_verbose:
logger.info(f"{owner} 结束操作:{self._key}{msg}")
self._lock.release()
def start_loading(self):
self._loaded.clear()
def finish_loading(self):
self._loaded.set()
def wait_for_loading(self):
self._loaded.wait()
@property
def obj(self):
return self._obj
@obj.setter
def obj(self, val: Any):
self._obj = val
class CachePool:
def __init__(self, cache_num: int = -1):
self._cache_num = cache_num
self._cache = OrderedDict()
self.atomic = threading.RLock()
def keys(self) -> List[str]:
return list(self._cache.keys())
def _check_count(self):
if isinstance(self._cache_num, int) and self._cache_num > 0:
while len(self._cache) > self._cache_num:
self._cache.popitem(last=False)
def get(self, key: str) -> ThreadSafeObject:
if cache := self._cache.get(key):
cache.wait_for_loading()
return cache
def set(self, key: str, obj: ThreadSafeObject) -> ThreadSafeObject:
self._cache[key] = obj
self._check_count()
return obj
def pop(self, key: str = None) -> ThreadSafeObject:
if key is None:
return self._cache.popitem(last=False)
else:
return self._cache.pop(key, None)
def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""):
cache = self.get(key)
if cache is None:
raise RuntimeError(f"请求的资源 {key} 不存在")
elif isinstance(cache, ThreadSafeObject):
self._cache.move_to_end(key)
return cache.acquire(owner=owner, msg=msg)
else:
return cache
def load_kb_embeddings(self, kb_name: str=None, embed_device: str = embedding_device()) -> Embeddings:
from server.db.repository.knowledge_base_repository import get_kb_detail
kb_detail = get_kb_detail(kb_name=kb_name)
print(kb_detail)
embed_model = kb_detail.get("embed_model", EMBEDDING_MODEL)
return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)
class EmbeddingsPool(CachePool):
def load_embeddings(self, model: str, device: str) -> Embeddings:
self.atomic.acquire()
model = model or EMBEDDING_MODEL
device = device or embedding_device()
key = (model, device)
if not self.get(key):
item = ThreadSafeObject(key, pool=self)
self.set(key, item)
with item.acquire(msg="初始化"):
self.atomic.release()
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
elif 'bge-' in model:
embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
model_kwargs={'device': device},
query_instruction="为这个句子生成表示以用于检索相关文章:")
if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding
embeddings.query_instruction = ""
else:
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
item.obj = embeddings
item.finish_loading()
else:
self.atomic.release()
return self.get(key).obj
embeddings_pool = EmbeddingsPool(cache_num=1)


@ -0,0 +1,157 @@
from server.knowledge_base.kb_cache.base import *
from server.knowledge_base.utils import get_vs_path
from langchain.vectorstores import FAISS
import os
class ThreadSafeFaiss(ThreadSafeObject):
def __repr__(self) -> str:
cls = type(self).__name__
return f"<{cls}: key: {self._key}, obj: {self._obj}, docs_count: {self.docs_count()}>"
def docs_count(self) -> int:
return len(self._obj.docstore._dict)
def save(self, path: str, create_path: bool = True):
with self.acquire():
if not os.path.isdir(path) and create_path:
os.makedirs(path)
ret = self._obj.save_local(path)
logger.info(f"已将向量库 {self._key} 保存到磁盘")
return ret
def clear(self):
ret = []
with self.acquire():
ids = list(self._obj.docstore._dict.keys())
if ids:
ret = self._obj.delete(ids)
assert len(self._obj.docstore._dict) == 0
logger.info(f"已将向量库 {self._key} 清空")
return ret
class _FaissPool(CachePool):
def new_vector_store(
self,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
) -> FAISS:
embeddings = embeddings_pool.load_embeddings(embed_model, embed_device)
# create an empty vector store
doc = Document(page_content="init", metadata={})
vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True)
ids = list(vector_store.docstore._dict.keys())
vector_store.delete(ids)
return vector_store
def save_vector_store(self, kb_name: str, path: str=None):
if cache := self.get(kb_name):
return cache.save(path)
def unload_vector_store(self, kb_name: str):
if cache := self.get(kb_name):
self.pop(kb_name)
logger.info(f"成功释放向量库:{kb_name}")
class KBFaissPool(_FaissPool):
def load_vector_store(
self,
kb_name: str,
create: bool = True,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
) -> ThreadSafeFaiss:
self.atomic.acquire()
cache = self.get(kb_name)
if cache is None:
item = ThreadSafeFaiss(kb_name, pool=self)
self.set(kb_name, item)
with item.acquire(msg="初始化"):
self.atomic.release()
logger.info(f"loading vector store in '{kb_name}' from disk.")
vs_path = get_vs_path(kb_name)
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device)
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
elif create:
# create an empty vector store
if not os.path.exists(vs_path):
os.makedirs(vs_path)
vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
vector_store.save_local(vs_path)
else:
raise RuntimeError(f"knowledge base {kb_name} not exist.")
item.obj = vector_store
item.finish_loading()
else:
self.atomic.release()
return self.get(kb_name)
class MemoFaissPool(_FaissPool):
def load_vector_store(
self,
kb_name: str,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
) -> ThreadSafeFaiss:
self.atomic.acquire()
cache = self.get(kb_name)
if cache is None:
item = ThreadSafeFaiss(kb_name, pool=self)
self.set(kb_name, item)
with item.acquire(msg="初始化"):
self.atomic.release()
logger.info(f"loading vector store in '{kb_name}' to memory.")
# create an empty vector store
vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
item.obj = vector_store
item.finish_loading()
else:
self.atomic.release()
return self.get(kb_name)
kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM)
memo_faiss_pool = MemoFaissPool()
if __name__ == "__main__":
import time, random
from pprint import pprint
kb_names = ["vs1", "vs2", "vs3"]
# for name in kb_names:
# memo_faiss_pool.load_vector_store(name)
def worker(vs_name: str, name: str):
vs_name = "samples"
time.sleep(random.randint(1, 5))
embeddings = embeddings_pool.load_embeddings()
r = random.randint(1, 3)
with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs:
if r == 1: # add docs
ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
pprint(ids)
elif r == 2: # search docs
docs = vs.similarity_search_with_score(f"{name}", top_k=3, score_threshold=1.0)
pprint(docs)
if r == 3: # delete docs
logger.warning(f"清除 {vs_name} by {name}")
kb_faiss_pool.get(vs_name).clear()
threads = []
for n in range(1, 30):
t = threading.Thread(target=worker,
kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"},
daemon=True)
t.start()
threads.append(t)
for t in threads:
t.join()

View File

@ -1,12 +1,18 @@
import os
import urllib
from fastapi import File, Form, Body, Query, UploadFile
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name, list_files_from_folder, KnowledgeFile
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
logger, log_verbose,)
from server.utils import BaseResponse, ListResponse, run_in_thread_pool
from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder,get_file_path,
files2docs_in_thread, KnowledgeFile)
from fastapi.responses import StreamingResponse, FileResponse
from pydantic import Json
import json
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_file_repository import get_file_detail
from typing import List, Dict
from langchain.docstore.document import Document
@ -29,7 +35,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
return data
async def list_files(
def list_files(
knowledge_base_name: str
) -> ListResponse:
if not validate_kb_name(knowledge_base_name):
@ -44,11 +50,87 @@ async def list_files(
return ListResponse(data=all_doc_names)
async def upload_doc(file: UploadFile = File(..., description="上传文件"),
knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]),
override: bool = Form(False, description="覆盖已有文件"),
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
def _save_files_in_thread(files: List[UploadFile],
knowledge_base_name: str,
override: bool):
'''
通过多线程将上传的文件保存到对应知识库目录内
生成器返回保存结果{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
'''
def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dict:
'''
保存单个文件
'''
try:
filename = file.filename
file_path = get_file_path(knowledge_base_name=knowledge_base_name, doc_name=filename)
data = {"knowledge_base_name": knowledge_base_name, "file_name": filename}
file_content = file.file.read() # 读取上传文件的内容
if (os.path.isfile(file_path)
and not override
and os.path.getsize(file_path) == len(file_content)
):
# TODO: filesize 不同后的处理
file_status = f"文件 {filename} 已存在。"
logger.warning(file_status)
return dict(code=404, msg=file_status, data=data)
with open(file_path, "wb") as f:
f.write(file_content)
return dict(code=200, msg=f"成功上传文件 {filename}", data=data)
except Exception as e:
msg = f"{filename} 文件上传失败,报错信息为: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
return dict(code=500, msg=msg, data=data)
params = [{"file": file, "knowledge_base_name": knowledge_base_name, "override": override} for file in files]
for result in run_in_thread_pool(save_file, params=params):
yield result
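
# 使用示意(编辑者补充的草图):消费 _save_files_in_thread 生成器,按文件收集保存结果。
# 这里假设 files 是已经从请求中取得的 UploadFile 列表,知识库名 "samples" 仅为示例。
def _save_files_example(files: List[UploadFile]):
    succeeded, failed = [], {}
    for result in _save_files_in_thread(files, knowledge_base_name="samples", override=False):
        filename = result["data"]["file_name"]
        if result["code"] == 200:
            succeeded.append(filename)
        else:
            failed[filename] = result["msg"]
    return succeeded, failed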
# 似乎没有单独增加一个文件上传API接口的必要
# def upload_files(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
# knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
# override: bool = Form(False, description="覆盖已有文件")):
# '''
# API接口上传文件。流式返回保存结果{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
# '''
# def generate(files, knowledge_base_name, override):
# for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
# yield json.dumps(result, ensure_ascii=False)
# return StreamingResponse(generate(files, knowledge_base_name=knowledge_base_name, override=override), media_type="text/event-stream")
# TODO: 等langchain.document_loaders支持内存文件的时候再开通
# def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
# knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
# override: bool = Form(False, description="覆盖已有文件"),
# save: bool = Form(True, description="是否将文件保存到知识库目录")):
# def save_files(files, knowledge_base_name, override):
# for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
# yield json.dumps(result, ensure_ascii=False)
# def files_to_docs(files):
# for result in files2docs_in_thread(files):
# yield json.dumps(result, ensure_ascii=False)
def upload_docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
override: bool = Form(False, description="覆盖已有文件"),
to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
docs: Json = Form({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
'''
API接口:上传文件,并/或进行向量化
'''
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -56,40 +138,42 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"),
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
file_content = await file.read() # 读取上传文件的内容
failed_files = {}
file_names = list(docs.keys())
try:
kb_file = KnowledgeFile(filename=file.filename,
knowledge_base_name=knowledge_base_name)
# 先将上传的文件保存到磁盘
for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
filename = result["data"]["file_name"]
if result["code"] != 200:
failed_files[filename] = result["msg"]
if (os.path.exists(kb_file.filepath)
and not override
and os.path.getsize(kb_file.filepath) == len(file_content)
):
# TODO: filesize 不同后的处理
file_status = f"文件 {kb_file.filename} 已存在。"
return BaseResponse(code=404, msg=file_status)
if filename not in file_names:
file_names.append(filename)
with open(kb_file.filepath, "wb") as f:
f.write(file_content)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}")
# 对保存的文件进行向量化
if to_vector_store:
result = update_docs(
knowledge_base_name=knowledge_base_name,
file_names=file_names,
override_custom_docs=True,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
zh_title_enhance=zh_title_enhance,
docs=docs,
not_refresh_vs_cache=True,
)
failed_files.update(result.data["failed_files"])
if not not_refresh_vs_cache:
kb.save_vector_store()
try:
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}")
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files})
async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
doc_name: str = Body(..., examples=["file_name.md"]),
delete_content: bool = Body(False),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]),
file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
delete_content: bool = Body(False),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
@ -98,24 +182,36 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
if not kb.exist_doc(doc_name):
return BaseResponse(code=404, msg=f"未找到文件 {doc_name}")
failed_files = {}
for file_name in file_names:
if not kb.exist_doc(file_name):
failed_files[file_name] = f"未找到文件 {file_name}"
try:
kb_file = KnowledgeFile(filename=doc_name,
knowledge_base_name=knowledge_base_name)
kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=not_refresh_vs_cache)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}")
try:
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=True)
except Exception as e:
msg = f"{file_name} 文件删除失败,错误信息:{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
failed_files[file_name] = msg
return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功")
if not not_refresh_vs_cache:
kb.save_vector_store()
return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failed_files})
async def update_doc(
knowledge_base_name: str = Body(..., examples=["samples"]),
file_name: str = Body(..., examples=["file_name"]),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
def update_docs(
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=["file_name"]),
chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
docs: Json = Body({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
'''
更新知识库文档
@ -127,22 +223,62 @@ async def update_doc(
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
try:
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
if os.path.exists(kb_file.filepath):
kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}")
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败,错误信息是:{e}")
failed_files = {}
kb_files = []
return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败")
# 生成需要加载docs的文件列表
for file_name in file_names:
file_detail= get_file_detail(kb_name=knowledge_base_name, filename=file_name)
# 如果该文件之前使用了自定义docs则根据参数决定略过或覆盖
if file_detail.get("custom_docs") and not override_custom_docs:
continue
if file_name not in docs:
try:
kb_files.append(KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name))
except Exception as e:
msg = f"加载文档 {file_name} 时出错:{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
failed_files[file_name] = msg
# 从文件生成docs并进行向量化。
# 这里利用了KnowledgeFile的缓存功能在多线程中加载Document然后传给KnowledgeFile
for status, result in files2docs_in_thread(kb_files,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
zh_title_enhance=zh_title_enhance):
if status:
kb_name, file_name, new_docs = result
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
kb_file.splited_docs = new_docs
kb.update_doc(kb_file, not_refresh_vs_cache=True)
else:
kb_name, file_name, error = result
failed_files[file_name] = error
# 将自定义的docs进行向量化
for file_name, v in docs.items():
try:
v = [x if isinstance(x, Document) else Document(**x) for x in v]
kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name)
kb.update_doc(kb_file, docs=v, not_refresh_vs_cache=True)
except Exception as e:
msg = f"{file_name} 添加自定义docs时出错{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
failed_files[file_name] = msg
if not not_refresh_vs_cache:
kb.save_vector_store()
return BaseResponse(code=200, msg=f"更新文档完成", data={"failed_files": failed_files})
async def download_doc(
knowledge_base_name: str = Query(..., examples=["samples"]),
file_name: str = Query(..., examples=["test.txt"]),
def download_doc(
knowledge_base_name: str = Query(...,description="知识库名称", examples=["samples"]),
file_name: str = Query(...,description="文件名称", examples=["test.txt"]),
preview: bool = Query(False, description="是:浏览器内预览;否:下载"),
):
'''
下载知识库文档
@ -154,6 +290,11 @@ async def download_doc(
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
if preview:
content_disposition_type = "inline"
else:
content_disposition_type = None
try:
kb_file = KnowledgeFile(filename=file_name,
knowledge_base_name=knowledge_base_name)
@ -162,20 +303,27 @@ async def download_doc(
return FileResponse(
path=kb_file.filepath,
filename=kb_file.filename,
media_type="multipart/form-data")
media_type="multipart/form-data",
content_disposition_type=content_disposition_type,
)
except Exception as e:
print(e)
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败,错误信息是:{e}")
msg = f"{kb_file.filename} 读取文件失败,错误信息是:{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
return BaseResponse(code=500, msg=msg)
return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败")
async def recreate_vector_store(
knowledge_base_name: str = Body(..., examples=["samples"]),
allow_empty_kb: bool = Body(True),
vs_type: str = Body(DEFAULT_VS_TYPE),
embed_model: str = Body(EMBEDDING_MODEL),
):
def recreate_vector_store(
knowledge_base_name: str = Body(..., examples=["samples"]),
allow_empty_kb: bool = Body(True),
vs_type: str = Body(DEFAULT_VS_TYPE),
embed_model: str = Body(EMBEDDING_MODEL),
chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
):
'''
recreate vector store from the content.
this is useful when users copy files into the content folder directly instead of uploading them through the network.
@ -190,27 +338,33 @@ async def recreate_vector_store(
else:
kb.create_kb()
kb.clear_vs()
docs = list_files_from_folder(knowledge_base_name)
for i, doc in enumerate(docs):
try:
kb_file = KnowledgeFile(doc, knowledge_base_name)
files = list_files_from_folder(knowledge_base_name)
kb_files = [(file, knowledge_base_name) for file in files]
i = 0
for status, result in files2docs_in_thread(kb_files,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
zh_title_enhance=zh_title_enhance):
if status:
kb_name, file_name, docs = result
kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
kb_file.splited_docs = docs
yield json.dumps({
"code": 200,
"msg": f"({i + 1} / {len(docs)}): {doc}",
"total": len(docs),
"msg": f"({i + 1} / {len(files)}): {file_name}",
"total": len(files),
"finished": i,
"doc": doc,
"doc": file_name,
}, ensure_ascii=False)
if i == len(docs) - 1:
not_refresh_vs_cache = False
else:
not_refresh_vs_cache = True
kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache)
except Exception as e:
print(e)
kb.add_doc(kb_file, not_refresh_vs_cache=True)
else:
kb_name, file_name, error = result
msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
logger.error(msg)
yield json.dumps({
"code": 500,
"msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。",
"msg": msg,
})
i += 1
return StreamingResponse(output(), media_type="text/event-stream")
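
def _recreate_vector_store_client_example():
    # 调用示意(编辑者补充的草图):该接口以流式方式逐条返回重建进度(JSON 字符串)。
    # 假设注册为 POST /knowledge_base/recreate_vector_store,路径与端口以实际注册为准。
    import requests
    resp = requests.post(
        "http://127.0.0.1:7861/knowledge_base/recreate_vector_store",
        json={"knowledge_base_name": "samples", "allow_empty_kb": True},
        stream=True,
    )
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        if chunk:
            print(chunk)  # 通常每个 chunk 对应一条进度 JSON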

View File

@ -25,7 +25,6 @@ from server.knowledge_base.utils import (
list_kbs_from_folder, list_files_from_folder,
)
from server.utils import embedding_device
from typing import List, Union, Dict
from typing import List, Union, Dict, Optional
@ -51,6 +50,12 @@ class KBService(ABC):
def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
return load_embeddings(self.embed_model, embed_device)
def save_vector_store(self):
'''
保存向量库:FAISS保存到磁盘milvus保存到数据库PGVector暂未支持
'''
pass
def create_kb(self):
"""
创建知识库
@ -84,6 +89,8 @@ class KBService(ABC):
"""
if docs:
custom_docs = True
for doc in docs:
doc.metadata.setdefault("source", kb_file.filepath)
else:
docs = kb_file.file2text()
custom_docs = False
@ -137,7 +144,6 @@ class KBService(ABC):
docs = self.do_search(query, top_k, score_threshold, embeddings)
return docs
# TODO: milvus/pg需要实现该方法
def get_doc_by_id(self, id: str) -> Optional[Document]:
return None

View File

@ -3,61 +3,16 @@ import shutil
from configs.model_config import (
KB_ROOT_PATH,
CACHED_VS_NUM,
EMBEDDING_MODEL,
SCORE_THRESHOLD
SCORE_THRESHOLD,
logger, log_verbose,
)
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from functools import lru_cache
from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
from langchain.vectorstores import FAISS
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
from server.knowledge_base.utils import KnowledgeFile
from langchain.embeddings.base import Embeddings
from typing import List, Dict, Optional
from langchain.docstore.document import Document
from server.utils import torch_gc, embedding_device
_VECTOR_STORE_TICKS = {}
@lru_cache(CACHED_VS_NUM)
def load_faiss_vector_store(
knowledge_base_name: str,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
embeddings: Embeddings = None,
tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed.
) -> FAISS:
print(f"loading vector store in '{knowledge_base_name}'.")
vs_path = get_vs_path(knowledge_base_name)
if embeddings is None:
embeddings = load_embeddings(embed_model, embed_device)
if not os.path.exists(vs_path):
os.makedirs(vs_path)
if "index.faiss" in os.listdir(vs_path):
search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
else:
# create an empty vector store
doc = Document(page_content="init", metadata={})
search_index = FAISS.from_documents([doc], embeddings, normalize_L2=True)
ids = [k for k, v in search_index.docstore._dict.items()]
search_index.delete(ids)
search_index.save_local(vs_path)
if tick == 0: # vector store is loaded first time
_VECTOR_STORE_TICKS[knowledge_base_name] = 0
return search_index
def refresh_vs_cache(kb_name: str):
"""
make vector store cache refreshed when next loading
"""
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
print(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
from server.utils import torch_gc
class FaissKBService(KBService):
@ -73,24 +28,15 @@ class FaissKBService(KBService):
def get_kb_path(self):
return os.path.join(KB_ROOT_PATH, self.kb_name)
def load_vector_store(self) -> FAISS:
return load_faiss_vector_store(
knowledge_base_name=self.kb_name,
embed_model=self.embed_model,
tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0),
)
def load_vector_store(self) -> ThreadSafeFaiss:
return kb_faiss_pool.load_vector_store(kb_name=self.kb_name, embed_model=self.embed_model)
def save_vector_store(self, vector_store: FAISS = None):
vector_store = vector_store or self.load_vector_store()
vector_store.save_local(self.vs_path)
return vector_store
def refresh_vs_cache(self):
refresh_vs_cache(self.kb_name)
def save_vector_store(self):
self.load_vector_store().save(self.vs_path)
def get_doc_by_id(self, id: str) -> Optional[Document]:
vector_store = self.load_vector_store()
return vector_store.docstore._dict.get(id)
with self.load_vector_store().acquire() as vs:
return vs.docstore._dict.get(id)
def do_init(self):
self.kb_path = self.get_kb_path()
@ -111,43 +57,38 @@ class FaissKBService(KBService):
score_threshold: float = SCORE_THRESHOLD,
embeddings: Embeddings = None,
) -> List[Document]:
search_index = self.load_vector_store()
docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
with self.load_vector_store().acquire() as vs:
docs = vs.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
return docs
def do_add_doc(self,
docs: List[Document],
**kwargs,
) -> List[Dict]:
vector_store = self.load_vector_store()
ids = vector_store.add_documents(docs)
with self.load_vector_store().acquire() as vs:
ids = vs.add_documents(docs)
if not kwargs.get("not_refresh_vs_cache"):
vs.save_local(self.vs_path)
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
torch_gc()
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
self.refresh_vs_cache()
return doc_infos
def do_delete_doc(self,
kb_file: KnowledgeFile,
**kwargs):
vector_store = self.load_vector_store()
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
if len(ids) == 0:
return None
vector_store.delete(ids)
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
self.refresh_vs_cache()
return vector_store
with self.load_vector_store().acquire() as vs:
ids = [k for k, v in vs.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath]
if len(ids) > 0:
vs.delete(ids)
if not kwargs.get("not_refresh_vs_cache"):
vs.save_local(self.vs_path)
return ids
def do_clear_vs(self):
with kb_faiss_pool.atomic:
kb_faiss_pool.pop(self.kb_name)
shutil.rmtree(self.vs_path)
os.makedirs(self.vs_path)
self.refresh_vs_cache()
def exist_doc(self, file_name: str):
if super().exist_doc(file_name):
@ -165,4 +106,4 @@ if __name__ == '__main__':
faissService.add_doc(KnowledgeFile("README.md", "test"))
faissService.delete_doc(KnowledgeFile("README.md", "test"))
faissService.do_drop_kb()
print(faissService.search_docs("如何启动api服务"))
print(faissService.search_docs("如何启动api服务"))

View File

@ -22,6 +22,10 @@ class MilvusKBService(KBService):
from pymilvus import Collection
return Collection(milvus_name)
def save_vector_store(self):
if self.milvus.col:
self.milvus.col.flush()
def get_doc_by_id(self, id: str) -> Optional[Document]:
if self.milvus.col:
data_list = self.milvus.col.query(expr=f'pk == {id}', output_fields=["*"])
@ -56,6 +60,7 @@ class MilvusKBService(KBService):
def do_drop_kb(self):
if self.milvus.col:
self.milvus.col.release()
self.milvus.col.drop()
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
@ -63,6 +68,15 @@ class MilvusKBService(KBService):
return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
# TODO: workaround for bug #10492 in langchain
for doc in docs:
for k, v in doc.metadata.items():
doc.metadata[k] = str(v)
for field in self.milvus.fields:
doc.metadata.setdefault(field, "")
doc.metadata.pop(self.milvus._text_field, None)
doc.metadata.pop(self.milvus._vector_field, None)
ids = self.milvus.add_documents(docs)
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
return doc_infos
@ -76,7 +90,8 @@ class MilvusKBService(KBService):
def do_clear_vs(self):
if self.milvus.col:
self.milvus.col.drop()
self.do_drop_kb()
self.do_init()
if __name__ == '__main__':

View File

@ -26,6 +26,7 @@ class PGKBService(KBService):
collection_name=self.kb_name,
distance_strategy=DistanceStrategy.EUCLIDEAN,
connection_string=kbs_config.get("pg").get("connection_uri"))
def get_doc_by_id(self, id: str) -> Optional[Document]:
with self.pg_vector.connect() as connect:
stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id=:id")
@ -77,6 +78,7 @@ class PGKBService(KBService):
def do_clear_vs(self):
self.pg_vector.delete_collection()
self.pg_vector.create_collection()
if __name__ == '__main__':

View File

@ -1,19 +1,15 @@
from configs.model_config import EMBEDDING_MODEL, DEFAULT_VS_TYPE
from configs.model_config import (EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
logger, log_verbose)
from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
list_files_from_folder, run_in_thread_pool,
files2docs_in_thread,
list_files_from_folder,files2docs_in_thread,
KnowledgeFile,)
from server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType
from server.db.repository.knowledge_file_repository import add_file_to_db
from server.db.base import Base, engine
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Literal, Any, List
pool = ThreadPoolExecutor(os.cpu_count())
def create_tables():
Base.metadata.create_all(bind=engine)
@ -30,7 +26,9 @@ def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]:
kb_file = KnowledgeFile(filename=file, knowledge_base_name=kb_name)
kb_files.append(kb_file)
except Exception as e:
print(f"{e},已跳过")
msg = f"{e},已跳过"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
return kb_files
@ -39,6 +37,9 @@ def folder2db(
mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"],
vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
chunk_size: int = -1,
chunk_overlap: int = -1,
zh_title_enhance: bool = ZH_TITLE_ENHANCE,
):
'''
use existed files in local folder to populate database and/or vector store.
@ -59,7 +60,10 @@ def folder2db(
print(f"清理后,知识库 {kb_name} 中共有 {files_count} 个文档。")
kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name))
for success, result in files2docs_in_thread(kb_files, pool=pool):
for success, result in files2docs_in_thread(kb_files,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
zh_title_enhance=zh_title_enhance):
if success:
_, filename, docs = result
print(f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档")
@ -67,10 +71,7 @@ def folder2db(
kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
else:
print(result)
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache()
kb.save_vector_store()
elif mode == "fill_info_only":
files = list_files_from_folder(kb_name)
kb_files = file_to_kbfile(kb_name, files)
@ -84,17 +85,17 @@ def folder2db(
for kb_file in kb_files:
kb.update_doc(kb_file, not_refresh_vs_cache=True)
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache()
kb.save_vector_store()
elif mode == "increament":
db_files = kb.list_files()
folder_files = list_files_from_folder(kb_name)
files = list(set(folder_files) - set(db_files))
kb_files = file_to_kbfile(kb_name, files)
for success, result in files2docs_in_thread(kb_files, pool=pool):
for success, result in files2docs_in_thread(kb_files,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
zh_title_enhance=zh_title_enhance):
if success:
_, filename, docs = result
print(f"正在将 {kb_name}/{filename} 添加到向量库")
@ -102,10 +103,7 @@ def folder2db(
kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
else:
print(result)
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache()
kb.save_vector_store()
else:
print(f"unspported migrate mode: {mode}")
@ -135,9 +133,7 @@ def prune_db_files(kb_name: str):
kb_files = file_to_kbfile(kb_name, files)
for kb_file in kb_files:
kb.delete_doc(kb_file, not_refresh_vs_cache=True)
if kb.vs_type() == SupportedVSType.FAISS:
kb.save_vector_store()
kb.refresh_vs_cache()
kb.save_vector_store()
return kb_files
def prune_folder_files(kb_name: str):

View File

@ -1,39 +1,33 @@
import os
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceBgeEmbeddings
from transformers import AutoTokenizer
from configs.model_config import (
embedding_model_dict,
EMBEDDING_MODEL,
KB_ROOT_PATH,
CHUNK_SIZE,
OVERLAP_SIZE,
ZH_TITLE_ENHANCE
ZH_TITLE_ENHANCE,
logger,
log_verbose,
text_splitter_dict,
llm_model_dict,
LLM_MODEL,
TEXT_SPLITTER
)
from functools import lru_cache
import importlib
from text_splitter import zh_title_enhance
import langchain.document_loaders
from langchain.docstore.document import Document
from langchain.text_splitter import TextSplitter
from pathlib import Path
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import ThreadPoolExecutor
from server.utils import run_in_thread_pool, embedding_device
import io
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
# make HuggingFaceEmbeddings hashable
def _embeddings_hash(self):
if isinstance(self, HuggingFaceEmbeddings):
return hash(self.model_name)
elif isinstance(self, HuggingFaceBgeEmbeddings):
return hash(self.model_name)
elif isinstance(self, OpenAIEmbeddings):
return hash(self.model)
HuggingFaceEmbeddings.__hash__ = _embeddings_hash
OpenAIEmbeddings.__hash__ = _embeddings_hash
HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash
def validate_kb_name(knowledge_base_id: str) -> bool:
# 检查是否包含预期外的字符或路径攻击关键字
if "../" in knowledge_base_id:
@ -68,19 +62,12 @@ def list_files_from_folder(kb_name: str):
if os.path.isfile(os.path.join(doc_path, file))]
@lru_cache(1)
def load_embeddings(model: str, device: str):
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
elif 'bge-' in model:
embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
model_kwargs={'device': device},
query_instruction="为这个句子生成表示以用于检索相关文章:")
if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding
embeddings.query_instruction = ""
else:
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
return embeddings
def load_embeddings(model: str = EMBEDDING_MODEL, device: str = embedding_device()):
'''
从缓存中加载embeddings可以避免多线程时竞争加载
'''
from server.knowledge_base.kb_cache.base import embeddings_pool
return embeddings_pool.load_embeddings(model=model, device=device)
LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
@ -99,16 +86,16 @@ SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
class CustomJSONLoader(langchain.document_loaders.JSONLoader):
'''
langchain的JSONLoader需要jq,在win上使用不便,进行替代
langchain的JSONLoader需要jq,在win上使用不便,进行替代。针对langchain==0.0.286
'''
def __init__(
self,
file_path: Union[str, Path],
content_key: Optional[str] = None,
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
text_content: bool = True,
json_lines: bool = False,
self,
file_path: Union[str, Path],
content_key: Optional[str] = None,
metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None,
text_content: bool = True,
json_lines: bool = False,
):
"""Initialize the JSONLoader.
@ -130,21 +117,6 @@ class CustomJSONLoader(langchain.document_loaders.JSONLoader):
self._text_content = text_content
self._json_lines = json_lines
# TODO: langchain's JSONLoader.load has a encoding bug, raise gbk encoding error on windows.
# This is a workaround for langchain==0.0.266. I have make a pr(#9785) to langchain, it should be deleted after langchain upgraded.
def load(self) -> List[Document]:
"""Load and return documents from the JSON file."""
docs: List[Document] = []
if self._json_lines:
with self.file_path.open(encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
self._parse(line, docs)
else:
self._parse(self.file_path.read_text(encoding="utf-8"), docs)
return docs
def _parse(self, content: str, docs: List[Document]) -> None:
"""Convert given content to documents."""
data = json.loads(content)
@ -154,13 +126,14 @@ class CustomJSONLoader(langchain.document_loaders.JSONLoader):
# and prevent the user from getting a cryptic error later on.
if self._content_key is not None:
self._validate_content_key(data)
if self._metadata_func is not None:
self._validate_metadata_func(data)
for i, sample in enumerate(data, len(docs) + 1):
metadata = dict(
source=str(self.file_path),
seq_num=i,
text = self._get_text(sample=sample)
metadata = self._get_metadata(
sample=sample, source=str(self.file_path), seq_num=i
)
text = self._get_text(sample=sample, metadata=metadata)
docs.append(Document(page_content=text, metadata=metadata))
@ -173,12 +146,124 @@ def get_LoaderClass(file_extension):
return LoaderClass
# 把一些向量化共用逻辑从KnowledgeFile抽取出来,等langchain支持内存文件的时候,可以将非磁盘文件向量化
def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.StringIO, io.BytesIO]):
'''
根据loader_name和文件路径或内容返回文档加载器
'''
try:
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
document_loaders_module = importlib.import_module('document_loaders')
else:
document_loaders_module = importlib.import_module('langchain.document_loaders')
DocumentLoader = getattr(document_loaders_module, loader_name)
except Exception as e:
msg = f"为文件{file_path_or_content}查找加载器{loader_name}时出错:{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
document_loaders_module = importlib.import_module('langchain.document_loaders')
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
if loader_name == "UnstructuredFileLoader":
loader = DocumentLoader(file_path_or_content, autodetect_encoding=True)
elif loader_name == "CSVLoader":
loader = DocumentLoader(file_path_or_content, encoding="utf-8")
elif loader_name == "JSONLoader":
loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False)
elif loader_name == "CustomJSONLoader":
loader = DocumentLoader(file_path_or_content, text_content=False)
elif loader_name == "UnstructuredMarkdownLoader":
loader = DocumentLoader(file_path_or_content, mode="elements")
elif loader_name == "UnstructuredHTMLLoader":
loader = DocumentLoader(file_path_or_content, mode="elements")
else:
loader = DocumentLoader(file_path_or_content)
return loader
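
def _get_loader_example():
    # 使用示意(编辑者补充的草图):为磁盘上的某个文件构造加载器并读取 Document 列表。
    # 文件 "samples/test.txt" 仅为假设,需替换为知识库目录中真实存在的文件。
    loader = get_loader("UnstructuredFileLoader", get_file_path("samples", "test.txt"))
    return loader.load()  # List[Document]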
def make_text_splitter(
splitter_name: str = TEXT_SPLITTER,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
):
"""
根据参数获取特定的分词器
"""
splitter_name = splitter_name or "SpacyTextSplitter"
try:
if splitter_name == "MarkdownHeaderTextSplitter": # MarkdownHeaderTextSplitter特殊判定
headers_to_split_on = text_splitter_dict[splitter_name]['headers_to_split_on']
text_splitter = langchain.text_splitter.MarkdownHeaderTextSplitter(
headers_to_split_on=headers_to_split_on)
else:
try: ## 优先使用用户自定义的text_splitter
text_splitter_module = importlib.import_module('text_splitter')
TextSplitter = getattr(text_splitter_module, splitter_name)
except Exception: ## 否则使用langchain的text_splitter
text_splitter_module = importlib.import_module('langchain.text_splitter')
TextSplitter = getattr(text_splitter_module, splitter_name)
if text_splitter_dict[splitter_name]["source"] == "tiktoken": ## 从tiktoken加载
try:
text_splitter = TextSplitter.from_tiktoken_encoder(
encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
pipeline="zh_core_web_sm",
chunk_size=chunk_size,
chunk_overlap=chunk_overlap
)
except Exception:
text_splitter = TextSplitter.from_tiktoken_encoder(
encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
chunk_size=chunk_size,
chunk_overlap=chunk_overlap
)
elif text_splitter_dict[splitter_name]["source"] == "huggingface": ## 从huggingface加载
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
llm_model_dict[LLM_MODEL]["local_model_path"]
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
from transformers import GPT2TokenizerFast
from langchain.text_splitter import CharacterTextSplitter
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
else: ## 字符长度加载
tokenizer = AutoTokenizer.from_pretrained(
text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
trust_remote_code=True)
text_splitter = TextSplitter.from_huggingface_tokenizer(
tokenizer=tokenizer,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap
)
else:
try:
text_splitter = TextSplitter(
pipeline="zh_core_web_sm",
chunk_size=chunk_size,
chunk_overlap=chunk_overlap
)
except Exception:
text_splitter = TextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap
)
except Exception as e:
logger.error(f'{e.__class__.__name__}: 构造 text_splitter 时出错:{e}', exc_info=e if log_verbose else None)
text_splitter_module = importlib.import_module('langchain.text_splitter')
TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
text_splitter = TextSplitter(chunk_size=250, chunk_overlap=50)
return text_splitter
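
def _make_text_splitter_example():
    # 使用示意(编辑者补充的草图):按配置构造分词器并切分一段文本;
    # 若配置的分词器不可用,make_text_splitter 内部会回退到 RecursiveCharacterTextSplitter。
    ts = make_text_splitter(chunk_size=CHUNK_SIZE, chunk_overlap=OVERLAP_SIZE)
    chunks = ts.split_text("这是一段用于演示切分效果的测试文本。" * 30)
    return chunks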
class KnowledgeFile:
def __init__(
self,
filename: str,
knowledge_base_name: str
):
'''
对应知识库目录中的文件,必须是磁盘上存在的,才能进行向量化等操作
'''
self.kb_name = knowledge_base_name
self.filename = filename
self.ext = os.path.splitext(filename)[-1].lower()
@ -186,75 +271,67 @@ class KnowledgeFile:
raise ValueError(f"暂未支持的文件格式 {self.ext}")
self.filepath = get_file_path(knowledge_base_name, filename)
self.docs = None
self.splited_docs = None
self.document_loader_name = get_LoaderClass(self.ext)
self.text_splitter_name = TEXT_SPLITTER
# TODO: 增加依据文件格式匹配text_splitter
self.text_splitter_name = None
def file2docs(self, refresh: bool=False):
if self.docs is None or refresh:
logger.info(f"{self.document_loader_name} used for {self.filepath}")
loader = get_loader(self.document_loader_name, self.filepath)
self.docs = loader.load()
return self.docs
def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE, refresh: bool = False):
if self.docs is not None and not refresh:
return self.docs
print(f"{self.document_loader_name} used for {self.filepath}")
try:
if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]:
document_loaders_module = importlib.import_module('document_loaders')
def docs2texts(
self,
docs: List[Document] = None,
zh_title_enhance: bool = ZH_TITLE_ENHANCE,
refresh: bool = False,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
text_splitter: TextSplitter = None,
):
docs = docs or self.file2docs(refresh=refresh)
if not docs:
return []
if self.ext not in [".csv"]:
if text_splitter is None:
text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
if self.text_splitter_name == "MarkdownHeaderTextSplitter":
docs = text_splitter.split_text(docs[0].page_content)
for doc in docs:
# 如果文档有元数据
if doc.metadata:
doc.metadata["source"] = os.path.basename(self.filepath)
else:
document_loaders_module = importlib.import_module('langchain.document_loaders')
DocumentLoader = getattr(document_loaders_module, self.document_loader_name)
except Exception as e:
print(e)
document_loaders_module = importlib.import_module('langchain.document_loaders')
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
if self.document_loader_name == "UnstructuredFileLoader":
loader = DocumentLoader(self.filepath, autodetect_encoding=True)
elif self.document_loader_name == "CSVLoader":
loader = DocumentLoader(self.filepath, encoding="utf-8")
elif self.document_loader_name == "JSONLoader":
loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False)
elif self.document_loader_name == "CustomJSONLoader":
loader = DocumentLoader(self.filepath, text_content=False)
elif self.document_loader_name == "UnstructuredMarkdownLoader":
loader = DocumentLoader(self.filepath, mode="elements")
elif self.document_loader_name == "UnstructuredHTMLLoader":
loader = DocumentLoader(self.filepath, mode="elements")
else:
loader = DocumentLoader(self.filepath)
docs = text_splitter.split_documents(docs)
if self.ext in ".csv":
docs = loader.load()
else:
try:
if self.text_splitter_name is None:
text_splitter_module = importlib.import_module('langchain.text_splitter')
TextSplitter = getattr(text_splitter_module, "SpacyTextSplitter")
text_splitter = TextSplitter(
pipeline="zh_core_web_sm",
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
)
self.text_splitter_name = "SpacyTextSplitter"
else:
text_splitter_module = importlib.import_module('langchain.text_splitter')
TextSplitter = getattr(text_splitter_module, self.text_splitter_name)
text_splitter = TextSplitter(
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE)
except Exception as e:
print(e)
text_splitter_module = importlib.import_module('langchain.text_splitter')
TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
text_splitter = TextSplitter(
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
)
docs = loader.load_and_split(text_splitter)
print(docs[0])
if using_zh_title_enhance:
print(f"文档切分示例:{docs[0]}")
if zh_title_enhance:
# 注意:参数 zh_title_enhance(bool)遮蔽了同名的标题增强函数,这里重新导入函数本身,避免把 bool 当作函数调用
from text_splitter import zh_title_enhance as zh_title_enhance_fn
docs = zh_title_enhance_fn(docs)
self.docs = docs
return docs
self.splited_docs = docs
return self.splited_docs
def file2text(
self,
zh_title_enhance: bool = ZH_TITLE_ENHANCE,
refresh: bool = False,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
text_splitter: TextSplitter = None,
):
if self.splited_docs is None or refresh:
docs = self.file2docs()
self.splited_docs = self.docs2texts(docs=docs,
zh_title_enhance=zh_title_enhance,
refresh=refresh,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
text_splitter=text_splitter)
return self.splited_docs
def file_exist(self):
return os.path.isfile(self.filepath)
def get_mtime(self):
return os.path.getmtime(self.filepath)
@ -263,53 +340,54 @@ class KnowledgeFile:
return os.path.getsize(self.filepath)
def run_in_thread_pool(
func: Callable,
params: List[Dict] = [],
pool: ThreadPoolExecutor = None,
) -> Generator:
'''
在线程池中批量运行任务,并将运行结果以生成器的形式返回。
请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。
'''
tasks = []
if pool is None:
pool = ThreadPoolExecutor()
for kwargs in params:
thread = pool.submit(func, **kwargs)
tasks.append(thread)
for obj in as_completed(tasks):
yield obj.result()
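
# 使用示意(编辑者补充的草图):run_in_thread_pool 现从 server.utils 导入,
# 任务函数必须全部使用关键字参数,结果按完成顺序产出。
def _run_in_thread_pool_example():
    def square(*, x: int) -> int:
        return x * x
    return sorted(run_in_thread_pool(square, params=[{"x": i} for i in range(5)]))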
def files2docs_in_thread(
files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
pool: ThreadPoolExecutor = None,
files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
zh_title_enhance: bool = ZH_TITLE_ENHANCE,
pool: ThreadPoolExecutor = None,
) -> Generator:
'''
利用多线程批量将文件转化成langchain Document.
生成器返回值为{(kb_name, file_name): docs}
利用多线程批量将磁盘文件转化成langchain Document.
如果传入参数是Tuple,形式为(filename, kb_name)
生成器返回值为 status, (kb_name, file_name, docs | error)
'''
def task(*, file: KnowledgeFile, **kwargs) -> Dict[Tuple[str, str], List[Document]]:
def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
try:
return True, (file.kb_name, file.filename, file.file2text(**kwargs))
except Exception as e:
return False, e
msg = f"从文件 {file.kb_name}/{file.filename} 加载文档时出错:{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
return False, (file.kb_name, file.filename, msg)
kwargs_list = []
for i, file in enumerate(files):
kwargs = {}
if isinstance(file, tuple) and len(file) >= 2:
files[i] = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
file = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
elif isinstance(file, dict):
filename = file.pop("filename")
kb_name = file.pop("kb_name")
files[i] = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
kwargs = file
file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
kwargs["file"] = file
kwargs["chunk_size"] = chunk_size
kwargs["chunk_overlap"] = chunk_overlap
kwargs["zh_title_enhance"] = zh_title_enhance
kwargs_list.append(kwargs)
for result in run_in_thread_pool(func=task, params=kwargs_list, pool=pool):
for result in run_in_thread_pool(func=file2docs, params=kwargs_list, pool=pool):
yield result
if __name__ == "__main__":
from pprint import pprint
kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples")
# kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
docs = kb_file.file2docs()
pprint(docs[-1])
docs = kb_file.file2text()
pprint(docs[-1])

View File

@ -1,279 +1,71 @@
from multiprocessing import Process, Queue
import multiprocessing as mp
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import llm_model_dict, LLM_MODEL, LOG_PATH, logger
from server.utils import MakeFastAPIOffline, set_httpx_timeout, llm_device
from fastapi import Body
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
from server.utils import BaseResponse, fschat_controller_address
import httpx
host_ip = "0.0.0.0"
controller_port = 20001
model_worker_port = 20002
openai_api_port = 8888
base_url = "http://127.0.0.1:{}"
def list_llm_models(
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]),
placeholder: str = Body(None, description="该参数未使用,占位用"),
) -> BaseResponse:
'''
从fastchat controller获取已加载模型列表
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(controller_address + "/list_models")
return BaseResponse(data=r.json()["models"])
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
return BaseResponse(
code=500,
data=[],
msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
def create_controller_app(
dispatch_method="shortest_queue",
def stop_llm_model(
model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
) -> BaseResponse:
'''
向fastchat controller请求停止某个LLM模型
注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(
controller_address + "/release_worker",
json={"model_name": model_name},
)
return r.json()
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
return BaseResponse(
code=500,
msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}")
def change_llm_model(
model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]),
new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
):
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.controller import app, Controller
controller = Controller(dispatch_method)
sys.modules["fastchat.serve.controller"].controller = controller
MakeFastAPIOffline(app)
app.title = "FastChat Controller"
return app
def create_model_worker_app(
worker_address=base_url.format(model_worker_port),
controller_address=base_url.format(controller_port),
model_path=llm_model_dict[LLM_MODEL].get("local_model_path"),
device=llm_device(),
gpus=None,
max_gpu_memory="20GiB",
load_8bit=False,
cpu_offloading=None,
gptq_ckpt=None,
gptq_wbits=16,
gptq_groupsize=-1,
gptq_act_order=False,
awq_ckpt=None,
awq_wbits=16,
awq_groupsize=-1,
model_names=[LLM_MODEL],
num_gpus=1, # not in fastchat
conv_template=None,
limit_worker_concurrency=5,
stream_interval=2,
no_register=False,
):
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
import argparse
import threading
import fastchat.serve.model_worker
# workaround to make program exit with Ctrl+c
# it should be deleted after pr is merged by fastchat
def _new_init_heart_beat(self):
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
'''
向fastchat controller请求切换LLM模型
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(
controller_address + "/release_worker",
json={"model_name": model_name, "new_model_name": new_model_name},
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
)
self.heart_beat_thread.start()
ModelWorker.init_heart_beat = _new_init_heart_beat
parser = argparse.ArgumentParser()
args = parser.parse_args()
args.model_path = model_path
args.model_names = model_names
args.device = device
args.load_8bit = load_8bit
args.gptq_ckpt = gptq_ckpt
args.gptq_wbits = gptq_wbits
args.gptq_groupsize = gptq_groupsize
args.gptq_act_order = gptq_act_order
args.awq_ckpt = awq_ckpt
args.awq_wbits = awq_wbits
args.awq_groupsize = awq_groupsize
args.gpus = gpus
args.num_gpus = num_gpus
args.max_gpu_memory = max_gpu_memory
args.cpu_offloading = cpu_offloading
args.worker_address = worker_address
args.controller_address = controller_address
args.conv_template = conv_template
args.limit_worker_concurrency = limit_worker_concurrency
args.stream_interval = stream_interval
args.no_register = no_register
if args.gpus:
if len(args.gpus.split(",")) < args.num_gpus:
raise ValueError(
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
if gpus and num_gpus is None:
num_gpus = len(gpus.split(','))
args.num_gpus = num_gpus
gptq_config = GptqConfig(
ckpt=gptq_ckpt or model_path,
wbits=args.gptq_wbits,
groupsize=args.gptq_groupsize,
act_order=args.gptq_act_order,
)
awq_config = AWQConfig(
ckpt=args.awq_ckpt or args.model_path,
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
)
# torch.multiprocessing.set_start_method('spawn')
worker = ModelWorker(
controller_addr=args.controller_address,
worker_addr=args.worker_address,
worker_id=worker_id,
model_path=args.model_path,
model_names=args.model_names,
limit_worker_concurrency=args.limit_worker_concurrency,
no_register=args.no_register,
device=args.device,
num_gpus=args.num_gpus,
max_gpu_memory=args.max_gpu_memory,
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
)
sys.modules["fastchat.serve.model_worker"].worker = worker
sys.modules["fastchat.serve.model_worker"].args = args
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
MakeFastAPIOffline(app)
app.title = f"FastChat LLM Server ({LLM_MODEL})"
return app
def create_openai_api_app(
controller_address=base_url.format(controller_port),
api_keys=[],
):
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
app.add_middleware(
CORSMiddleware,
allow_credentials=True,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
app_settings.controller_address = controller_address
app_settings.api_keys = api_keys
MakeFastAPIOffline(app)
app.title = "FastChat OpeanAI API Server"
return app
def run_controller(q):
import uvicorn
app = create_controller_app()
@app.on_event("startup")
async def on_startup():
set_httpx_timeout()
q.put(1)
uvicorn.run(app, host=host_ip, port=controller_port)
def run_model_worker(q, *args, **kwargs):
import uvicorn
app = create_model_worker_app(*args, **kwargs)
@app.on_event("startup")
async def on_startup():
set_httpx_timeout()
while True:
no = q.get()
if no != 1:
q.put(no)
else:
break
q.put(2)
uvicorn.run(app, host=host_ip, port=model_worker_port)
def run_openai_api(q):
import uvicorn
app = create_openai_api_app()
@app.on_event("startup")
async def on_startup():
set_httpx_timeout()
while True:
no = q.get()
if no != 2:
q.put(no)
else:
break
q.put(3)
uvicorn.run(app, host=host_ip, port=openai_api_port)
if __name__ == "__main__":
mp.set_start_method("spawn")
queue = Queue()
logger.info(llm_model_dict[LLM_MODEL])
model_path = llm_model_dict[LLM_MODEL]["local_model_path"]
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
if not model_path:
logger.error("local_model_path 不能为空")
else:
controller_process = Process(
target=run_controller,
name=f"controller({os.getpid()})",
args=(queue,),
daemon=True,
)
controller_process.start()
model_worker_process = Process(
target=run_model_worker,
name=f"model_worker({os.getpid()})",
args=(queue,),
# kwargs={"load_8bit": True},
daemon=True,
)
model_worker_process.start()
openai_api_process = Process(
target=run_openai_api,
name=f"openai_api({os.getpid()})",
args=(queue,),
daemon=True,
)
openai_api_process.start()
try:
model_worker_process.join()
controller_process.join()
openai_api_process.join()
except KeyboardInterrupt:
model_worker_process.terminate()
controller_process.terminate()
openai_api_process.terminate()
# 服务启动后接口调用示例:
# import openai
# openai.api_key = "EMPTY" # Not support yet
# openai.api_base = "http://localhost:8888/v1"
# model = "chatglm2-6b"
# # create a chat completion
# completion = openai.ChatCompletion.create(
# model=model,
# messages=[{"role": "user", "content": "Hello! What is your name?"}]
# )
# # print the completion
# print(completion.choices[0].message.content)
return r.json()
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
return BaseResponse(
code=500,
msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")

View File

@ -0,0 +1,79 @@
import base64
import datetime
import hashlib
import hmac
from urllib.parse import urlparse
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
class Ws_Param(object):
# 初始化
def __init__(self, APPID, APIKey, APISecret, Spark_url):
self.APPID = APPID
self.APIKey = APIKey
self.APISecret = APISecret
self.host = urlparse(Spark_url).netloc
self.path = urlparse(Spark_url).path
self.Spark_url = Spark_url
# 生成url
def create_url(self):
# 生成RFC1123格式的时间戳
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
# 拼接字符串
signature_origin = "host: " + self.host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"
# 使用hmac-sha256进行加密
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
# 将请求的鉴权参数组合为字典
v = {
"authorization": authorization,
"date": date,
"host": self.host
}
# 拼接鉴权参数生成url
url = self.Spark_url + '?' + urlencode(v)
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释比对相同参数时生成的url与自己代码生成的url是否一致
return url
def gen_params(appid, domain,question, temperature):
"""
通过appid和用户的提问来生成请求参数
"""
data = {
"header": {
"app_id": appid,
"uid": "1234"
},
"parameter": {
"chat": {
"domain": domain,
"random_threshold": 0.5,
"max_tokens": 2048,
"auditing": "default",
"temperature": temperature,
}
},
"payload": {
"message": {
"text": question
}
}
}
return data
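Editor's note: a minimal usage sketch (not part of this commit) showing how `Ws_Param` and `gen_params` above fit together. The credentials, the v1.5 endpoint URL, and the single-turn message list are placeholder assumptions; the receive loop mirrors the `request()` coroutine in `xinghuo.py` further below.

```python
import asyncio
import json
import websockets  # same client library used by xinghuo.py below

async def demo_spark_call():
    # Placeholder credentials -- replace with real values from the XingHuo console.
    appid, api_key, api_secret = "YOUR_APPID", "YOUR_API_KEY", "YOUR_API_SECRET"
    spark_url = "ws://spark-api.xf-yun.com/v1.1/chat"  # v1.5 endpoint, per xinghuo.py
    ws_param = Ws_Param(appid, api_key, api_secret, spark_url)
    url = ws_param.create_url()  # signed URL carrying the HMAC-SHA256 auth params
    payload = gen_params(appid, "general", [{"role": "user", "content": "你好"}], 0.7)
    async with websockets.connect(url) as ws:
        await ws.send(json.dumps(payload, ensure_ascii=False))
        while True:
            resp = json.loads(await ws.recv())
            if text := resp.get("payload", {}).get("choices", {}).get("text"):
                print(text[0]["content"], end="", flush=True)
            if resp.get("header", {}).get("status") == 2:  # 2 marks the final frame
                break

# asyncio.run(demo_spark_call())
```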

View File

@ -1 +1,4 @@
from .zhipu import ChatGLMWorker
from .minimax import MiniMaxWorker
from .xinghuo import XingHuoWorker
from .qianfan import QianFanWorker

View File

@ -69,3 +69,28 @@ class ApiModelWorker(BaseModelWorker):
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
)
self.heart_beat_thread.start()
# help methods
def get_config(self):
from server.utils import get_model_worker_config
return get_model_worker_config(self.model_names[0])
def prompt_to_messages(self, prompt: str) -> List[Dict]:
'''
将prompt字符串拆分成messages.
'''
result = []
user_role = self.conv.roles[0]
ai_role = self.conv.roles[1]
user_start = user_role + ":"
ai_start = ai_role + ":"
for msg in prompt.split(self.conv.sep)[1:-1]:
if msg.startswith(user_start):
if content := msg[len(user_start):].strip():
result.append({"role": user_role, "content": content})
elif msg.startswith(ai_start):
if content := msg[len(ai_start):].strip():
result.append({"role": ai_role, "content": content})
else:
raise RuntimeError(f"unknow role in msg: {msg}")
return result
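Editor's note: a hypothetical illustration (not in the diff) of what `prompt_to_messages` produces, assuming a conversation template with roles `["USER", "BOT"]` and separator `"\n### "`, such as the MiniMax template defined later in this commit.

```python
prompt = (
    "\n### USER: 你好"
    "\n### BOT: 你好,请问有什么可以帮你?"
    "\n### USER: 介绍一下 Langchain-Chatchat"
    "\n### BOT:"
)
# worker.prompt_to_messages(prompt) would return roughly:
# [{"role": "USER", "content": "你好"},
#  {"role": "BOT",  "content": "你好,请问有什么可以帮你?"},
#  {"role": "USER", "content": "介绍一下 Langchain-Chatchat"}]
# The leading segment (system message) and the trailing empty assistant
# turn are dropped by the [1:-1] slice above.
```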

View File

@ -0,0 +1,100 @@
from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv
import sys
import json
import httpx
from pprint import pprint
from typing import List, Dict
class MiniMaxWorker(ApiModelWorker):
BASE_URL = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}'
def __init__(
self,
*,
model_names: List[str] = ["minimax-api"],
controller_addr: str,
worker_addr: str,
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384)
super().__init__(**kwargs)
# TODO: 确认模板是否需要修改
self.conv = conv.Conversation(
name=self.model_names[0],
system_message="",
messages=[],
roles=["USER", "BOT"],
sep="\n### ",
stop_str="###",
)
def prompt_to_messages(self, prompt: str) -> List[Dict]:
result = super().prompt_to_messages(prompt)
messages = [{"sender_type": x["role"], "text": x["content"]} for x in result]
return messages
def generate_stream_gate(self, params):
# 按照官网推荐直接调用abab 5.5模型
# TODO: 支持指定回复要求支持指定用户名称、AI名称
super().generate_stream_gate(params)
config = self.get_config()
group_id = config.get("group_id")
api_key = config.get("api_key")
pro = "_pro" if config.get("is_pro") else ""
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
data = {
"model": "abab5.5-chat",
"stream": True,
"tokens_to_generate": 1024, # TODO: 1024为官网默认值
"mask_sensitive_info": True,
"messages": self.prompt_to_messages(params["prompt"]),
"temperature": params.get("temperature"),
"top_p": params.get("top_p"),
"bot_setting": [],
}
print("request data sent to minimax:")
pprint(data)
response = httpx.stream("POST",
self.BASE_URL.format(pro=pro, group_id=group_id),
headers=headers,
json=data)
with response as r:
text = ""
for e in r.iter_text():
if e.startswith("data: "): # 真是优秀的返回
data = json.loads(e[6:])
if not data.get("usage"):
if choices := data.get("choices"):
chunk = choices[0].get("delta", "").strip()
if chunk:
print(chunk)
text += chunk
yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0"
def get_embeddings(self, params):
# TODO: 支持embeddings
print("embedding")
print(params)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = MiniMaxWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:20004",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=20003)
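Editor's note: `generate_stream_gate` emits JSON chunks terminated by a NUL byte, the framing FastChat's model worker uses on its streaming endpoint. The sketch below is not part of the commit; the worker address, the `/worker_generate_stream` path, and the payload fields are assumptions based on FastChat's conventions.

```python
import json
import httpx

# Assumed worker address (see the __main__ block above) and FastChat-style payload.
params = {"prompt": "\n### USER: 你好\n### BOT:", "temperature": 0.7, "top_p": 0.9}
buffer, last = b"", {}
with httpx.stream("POST", "http://127.0.0.1:20003/worker_generate_stream",
                  json=params, timeout=60) as r:
    for raw in r.iter_raw():
        buffer += raw
        while b"\0" in buffer:                  # each JSON message ends with a NUL byte
            chunk, buffer = buffer.split(b"\0", 1)
            if chunk:
                last = json.loads(chunk)
print(last.get("text", ""))  # "text" is cumulative, so the last chunk holds the full reply
```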

View File

@ -0,0 +1,172 @@
from server.model_workers.base import ApiModelWorker
from configs.model_config import TEMPERATURE
from fastchat import conversation as conv
import sys
import json
import httpx
from cachetools import cached, TTLCache
from server.utils import get_model_worker_config
from typing import List, Literal, Dict
MODEL_VERSIONS = {
"ernie-bot": "completions",
"ernie-bot-turbo": "eb-instant",
"bloomz-7b": "bloomz_7b1",
"qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed",
"llama2-7b-chat": "llama_2_7b",
"llama2-13b-chat": "llama_2_13b",
"llama2-70b-chat": "llama_2_70b",
"qianfan-llama2-ch-7b": "qianfan_chinese_llama_2_7b",
"chatglm2-6b-32k": "chatglm2_6b_32k",
"aquilachat-7b": "aquilachat_7b",
# "linly-llama2-ch-7b": "", # 暂未发布
# "linly-llama2-ch-13b": "", # 暂未发布
# "chatglm2-6b": "", # 暂未发布
# "chatglm2-6b-int4": "", # 暂未发布
# "falcon-7b": "", # 暂未发布
# "falcon-180b-chat": "", # 暂未发布
# "falcon-40b": "", # 暂未发布
# "rwkv4-world": "", # 暂未发布
# "rwkv5-world": "", # 暂未发布
# "rwkv4-pile-14b": "", # 暂未发布
# "rwkv4-raven-14b": "", # 暂未发布
# "open-llama-7b": "", # 暂未发布
# "dolly-12b": "", # 暂未发布
# "mpt-7b-instruct": "", # 暂未发布
# "mpt-30b-instruct": "", # 暂未发布
# "OA-Pythia-12B-SFT-4": "", # 暂未发布
# "xverse-13b": "", # 暂未发布
# # 以下为企业测试,需要单独申请
# "flan-ul2": "",
# "Cerebras-GPT-6.7B": ""
# "Pythia-6.9B": ""
}
@cached(TTLCache(1, 1800)) # 经过测试缓存的token可以使用目前每30分钟刷新一次
def get_baidu_access_token(api_key: str, secret_key: str) -> str:
"""
使用 AK/SK 生成鉴权签名(Access Token)
:return: access_token或是None(如果错误)
"""
url = "https://aip.baidubce.com/oauth/2.0/token"
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
try:
return httpx.get(url, params=params).json().get("access_token")
except Exception as e:
print(f"failed to get token from baidu: {e}")
def request_qianfan_api(
messages: List[Dict[str, str]],
temperature: float = TEMPERATURE,
model_name: str = "qianfan-api",
version: str = None,
) -> Dict:
BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat'\
'/{model_version}?access_token={access_token}'
config = get_model_worker_config(model_name)
version = version or config.get("version")
version_url = config.get("version_url")
access_token = get_baidu_access_token(config.get("api_key"), config.get("secret_key"))
if not access_token:
raise RuntimeError(f"failed to get access token. have you set the correct api_key and secret key?")
url = BASE_URL.format(
model_version=version_url or MODEL_VERSIONS[version],
access_token=access_token,
)
payload = {
"messages": messages,
"temperature": temperature,
"stream": True
}
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json',
}
with httpx.stream("POST", url, headers=headers, json=payload) as response:
for line in response.iter_lines():
if not line.strip():
continue
if line.startswith("data: "):
line = line[6:]
resp = json.loads(line)
yield resp
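Editor's note: a hedged usage sketch (not part of the commit) for the `request_qianfan_api` generator above. It assumes `api_key` and `secret_key` have been configured for the `qianfan-api` worker, and that `ernie-bot-turbo` (one of the `MODEL_VERSIONS` keys) is available on the account.

```python
messages = [{"role": "user", "content": "你好"}]
answer = ""
for resp in request_qianfan_api(messages, temperature=0.7, version="ernie-bot-turbo"):
    if "result" in resp:            # normal streaming chunk
        answer += resp["result"]
    elif "error_code" in resp:      # API-side error
        raise RuntimeError(f"{resp['error_code']}: {resp.get('error_msg')}")
print(answer)
```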
class QianFanWorker(ApiModelWorker):
"""
百度千帆
"""
def __init__(
self,
*,
version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot",
model_names: List[str] = ["ernie-api"],
controller_addr: str,
worker_addr: str,
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384)
super().__init__(**kwargs)
# TODO: 确认模板是否需要修改
self.conv = conv.Conversation(
name=self.model_names[0],
system_message="",
messages=[],
roles=["user", "assistant"],
sep="\n### ",
stop_str="###",
)
config = self.get_config()
self.version = version
self.api_key = config.get("api_key")
self.secret_key = config.get("secret_key")
def generate_stream_gate(self, params):
messages = self.prompt_to_messages(params["prompt"])
text=""
for resp in request_qianfan_api(messages,
temperature=params.get("temperature"),
model_name=self.model_names[0]):
if "result" in resp.keys():
text += resp["result"]
yield json.dumps({
"error_code": 0,
"text": text
},
ensure_ascii=False
).encode() + b"\0"
else:
yield json.dumps({
"error_code": resp["error_code"],
"text": resp["error_msg"]
},
ensure_ascii=False
).encode() + b"\0"
def get_embeddings(self, params):
# TODO: 支持embeddings
print("embedding")
print(params)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = QianFanWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:20006",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=20006)

View File

@ -0,0 +1,101 @@
from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv
import sys
import json
from server.model_workers import SparkApi
import websockets
from server.utils import iter_over_async, asyncio
from typing import List
async def request(appid, api_key, api_secret, Spark_url,domain, question, temperature):
# print("星火:")
wsParam = SparkApi.Ws_Param(appid, api_key, api_secret, Spark_url)
wsUrl = wsParam.create_url()
data = SparkApi.gen_params(appid, domain, question, temperature)
async with websockets.connect(wsUrl) as ws:
await ws.send(json.dumps(data, ensure_ascii=False))
finish = False
while not finish:
chunk = await ws.recv()
response = json.loads(chunk)
if response.get("header", {}).get("status") == 2:
finish = True
if text := response.get("payload", {}).get("choices", {}).get("text"):
yield text[0]["content"]
class XingHuoWorker(ApiModelWorker):
def __init__(
self,
*,
model_names: List[str] = ["xinghuo-api"],
controller_addr: str,
worker_addr: str,
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 8192)
super().__init__(**kwargs)
# TODO: 确认模板是否需要修改
self.conv = conv.Conversation(
name=self.model_names[0],
system_message="",
messages=[],
roles=["user", "assistant"],
sep="\n### ",
stop_str="###",
)
def generate_stream_gate(self, params):
# TODO: 当前每次对话都要重新连接websocket确认是否可以保持连接
super().generate_stream_gate(params)
config = self.get_config()
appid = config.get("APPID")
api_secret = config.get("APISecret")
api_key = config.get("api_key")
if config.get("is_v2"):
domain = "generalv2" # v2.0版本
Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" # v2.0环境的地址
else:
domain = "general" # v1.5版本
Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat" # v1.5环境的地址
question = self.prompt_to_messages(params["prompt"])
text = ""
try:
loop = asyncio.get_event_loop()
except:
loop = asyncio.new_event_loop()
for chunk in iter_over_async(
request(appid, api_key, api_secret, Spark_url, domain, question, params.get("temperature")),
loop=loop,
):
if chunk:
print(chunk)
text += chunk
yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0"
def get_embeddings(self, params):
# TODO: 支持embeddings
print("embedding")
print(params)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = XingHuoWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:20005",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=20005)

View File

@ -1,4 +1,3 @@
import zhipuai
from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv
import sys
@ -13,7 +12,7 @@ class ChatGLMWorker(ApiModelWorker):
def __init__(
self,
*,
model_names: List[str] = ["chatglm-api"],
model_names: List[str] = ["zhipu-api"],
version: Literal["chatglm_pro", "chatglm_std", "chatglm_lite"] = "chatglm_std",
controller_addr: str,
worker_addr: str,
@ -26,7 +25,7 @@ class ChatGLMWorker(ApiModelWorker):
# 这里的是chatglm api的模板其它API的conv_template需要定制
self.conv = conv.Conversation(
name="chatglm-api",
name=self.model_names[0],
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
messages=[],
roles=["Human", "Assistant"],
@ -35,11 +34,11 @@ class ChatGLMWorker(ApiModelWorker):
)
def generate_stream_gate(self, params):
# TODO: 支持stream参数维护request_id传过来的prompt也有问题
from server.utils import get_model_worker_config
# TODO: 维护request_id
import zhipuai
super().generate_stream_gate(params)
zhipuai.api_key = get_model_worker_config("chatglm-api").get("api_key")
zhipuai.api_key = self.get_config().get("api_key")
response = zhipuai.model_api.sse_invoke(
model=self.version,

View File

@ -4,11 +4,14 @@ from typing import List
from fastapi import FastAPI
from pathlib import Path
import asyncio
from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE
from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE, logger, log_verbose
from configs.server_config import FSCHAT_MODEL_WORKERS
import os
from server import model_workers
from typing import Literal, Optional, Any
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Literal, Optional, Callable, Generator, Dict, Any
thread_pool = ThreadPoolExecutor(os.cpu_count())
class BaseResponse(BaseModel):
@ -82,9 +85,10 @@ def torch_gc():
from torch.mps import empty_cache
empty_cache()
except Exception as e:
print(e)
print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")
msg=("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,"
"以支持及时清理 torch 产生的内存占用。")
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
def run_async(cor):
'''
@ -200,6 +204,7 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
优先级:FSCHAT_MODEL_WORKERS[model_name] > llm_model_dict[model_name] > FSCHAT_MODEL_WORKERS["default"]
'''
from configs.server_config import FSCHAT_MODEL_WORKERS
from server import model_workers
from configs.model_config import llm_model_dict
config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
@ -213,7 +218,9 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
try:
config["worker_class"] = getattr(model_workers, provider)
except Exception as e:
print(f"在线模型 {model_name} 的provider没有正确配置")
msg = f"在线模型 {model_name} 的provider没有正确配置"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
config["device"] = llm_device(config.get("device") or LLM_DEVICE)
return config
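Editor's note: the priority described in the docstring above can be read as a chain of `dict.update` calls, with later updates winning. The snippet below is an editorial paraphrase under that assumption, not a verbatim excerpt of the function.

```python
config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()     # lowest priority
config.update(llm_model_dict.get("chatglm2-6b", {}))        # per-model LLM config
config.update(FSCHAT_MODEL_WORKERS.get("chatglm2-6b", {}))  # highest priority
```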
@ -305,3 +312,24 @@ def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "
if device not in ["cuda", "mps", "cpu"]:
device = detect_device()
return device
def run_in_thread_pool(
func: Callable,
params: List[Dict] = [],
pool: ThreadPoolExecutor = None,
) -> Generator:
'''
在线程池中批量运行任务并将运行结果以生成器的形式返回
请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。
'''
tasks = []
pool = pool or thread_pool
for kwargs in params:
thread = pool.submit(func, **kwargs)
tasks.append(thread)
for obj in as_completed(tasks):
yield obj.result()
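Editor's note: a small self-contained sketch (not part of the commit) of how `run_in_thread_pool` is meant to be called, per its docstring: every dict in `params` becomes one keyword-argument invocation, and results are yielded in completion order.

```python
def _square(*, x: int) -> int:      # keyword-only, as the docstring requires
    return x * x

for result in run_in_thread_pool(_square, params=[{"x": i} for i in range(5)]):
    print(result)                   # order follows completion, not submission
```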

View File

@ -3,7 +3,8 @@ import multiprocessing as mp
import os
import subprocess
import sys
from multiprocessing import Process, Queue
from multiprocessing import Process
from datetime import datetime
from pprint import pprint
# 设置numexpr最大线程数默认为CPU核心数
@ -17,9 +18,9 @@ except:
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
logger
logger, log_verbose, TEXT_SPLITTER
from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
FSCHAT_OPENAI_API, )
FSCHAT_OPENAI_API, HTTPX_DEFAULT_TIMEOUT)
from server.utils import (fschat_controller_address, fschat_model_worker_address,
fschat_openai_api_address, set_httpx_timeout,
get_model_worker_config, get_all_model_worker_configs,
@ -47,7 +48,7 @@ def create_controller_app(
return app
def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]:
def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger
@ -87,6 +88,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse
args.limit_worker_concurrency = 5
args.stream_interval = 2
args.no_register = False
args.embed_in_truncate = False
for k, v in kwargs.items():
setattr(args, k, v)
@ -147,6 +149,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse
awq_config=awq_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
embed_in_truncate=args.embed_in_truncate,
)
sys.modules["fastchat.serve.model_worker"].args = args
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
@ -188,29 +191,15 @@ def create_openai_api_app(
return app
def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
if q is None or not isinstance(run_seq, int):
return
if run_seq == 1:
@app.on_event("startup")
async def on_startup():
set_httpx_timeout()
q.put(run_seq)
elif run_seq > 1:
@app.on_event("startup")
async def on_startup():
set_httpx_timeout()
while True:
no = q.get()
if no != run_seq - 1:
q.put(no)
else:
break
q.put(run_seq)
def _set_app_event(app: FastAPI, started_event: mp.Event = None):
@app.on_event("startup")
async def on_startup():
set_httpx_timeout()
if started_event is not None:
started_event.set()
def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None):
def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
import uvicorn
import httpx
from fastapi import Body
@ -221,12 +210,7 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Ev
dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
log_level=log_level,
)
_set_app_seq(app, q, run_seq)
@app.on_event("startup")
def on_startup():
if e is not None:
e.set()
_set_app_event(app, started_event)
# add interface to release and load model worker
@app.post("/release_worker")
@ -266,7 +250,7 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Ev
return {"code": 500, "msg": msg}
if new_model_name:
timer = 300 # wait 5 minutes for new model_worker register
timer = HTTPX_DEFAULT_TIMEOUT * 2 # wait for new model_worker register
while timer > 0:
models = app._controller.list_models()
if new_model_name in models:
@ -299,9 +283,9 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Ev
def run_model_worker(
model_name: str = LLM_MODEL,
controller_address: str = "",
q: Queue = None,
run_seq: int = 2,
log_level: str = "INFO",
q: mp.Queue = None,
started_event: mp.Event = None,
):
import uvicorn
from fastapi import Body
@ -317,7 +301,7 @@ def run_model_worker(
kwargs["model_path"] = model_path
app = create_model_worker_app(log_level=log_level, **kwargs)
_set_app_seq(app, q, run_seq)
_set_app_event(app, started_event)
if log_level == "ERROR":
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
@ -325,29 +309,29 @@ def run_model_worker(
# add interface to release and load model
@app.post("/release")
def release_model(
new_model_name: str = Body(None, description="释放后加载该模型"),
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
new_model_name: str = Body(None, description="释放后加载该模型"),
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
) -> Dict:
if keep_origin:
if new_model_name:
q.put(["start", new_model_name])
q.put([model_name, "start", new_model_name])
else:
if new_model_name:
q.put(["replace", new_model_name])
q.put([model_name, "replace", new_model_name])
else:
q.put(["stop"])
q.put([model_name, "stop", None])
return {"code": 200, "msg": "done"}
uvicorn.run(app, host=host, port=port, log_level=log_level.lower())
def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"):
def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
import uvicorn
import sys
controller_addr = fschat_controller_address()
app = create_openai_api_app(controller_addr, log_level=log_level) # TODO: not support keys yet.
_set_app_seq(app, q, run_seq)
_set_app_event(app, started_event)
host = FSCHAT_OPENAI_API["host"]
port = FSCHAT_OPENAI_API["port"]
@ -357,12 +341,12 @@ def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"):
uvicorn.run(app, host=host, port=port)
def run_api_server(q: Queue, run_seq: int = 4):
def run_api_server(started_event: mp.Event = None):
from server.api import create_app
import uvicorn
app = create_app()
_set_app_seq(app, q, run_seq)
_set_app_event(app, started_event)
host = API_SERVER["host"]
port = API_SERVER["port"]
@ -370,21 +354,19 @@ def run_api_server(q: Queue, run_seq: int = 4):
uvicorn.run(app, host=host, port=port)
def run_webui(q: Queue, run_seq: int = 5):
def run_webui(started_event: mp.Event = None):
host = WEBUI_SERVER["host"]
port = WEBUI_SERVER["port"]
if q is not None and isinstance(run_seq, int):
while True:
no = q.get()
if no != run_seq - 1:
q.put(no)
else:
break
q.put(run_seq)
p = subprocess.Popen(["streamlit", "run", "webui.py",
"--server.address", host,
"--server.port", str(port)])
"--server.port", str(port),
"--theme.base", "light",
"--theme.primaryColor", "#165dff",
"--theme.secondaryBackgroundColor", "#f5f5f5",
"--theme.textColor", "#000000",
])
started_event.set()
p.wait()
@ -427,8 +409,9 @@ def parse_args() -> argparse.ArgumentParser:
"-n",
"--model-name",
type=str,
default=LLM_MODEL,
help="specify model name for model worker.",
nargs="+",
default=[LLM_MODEL],
help="specify model name for model worker. add addition names with space seperated to start multiple model workers.",
dest="model_name",
)
parser.add_argument(
@ -483,11 +466,15 @@ def dump_server_info(after_start=False, args=None):
print(f"langchain版本{langchain.__version__}. fastchat版本{fastchat.__version__}")
print("\n")
model = LLM_MODEL
models = [LLM_MODEL]
if args and args.model_name:
model = args.model_name
print(f"当前LLM模型{model} @ {llm_device()}")
pprint(llm_model_dict[model])
models = args.model_name
print(f"当前使用的分词器:{TEXT_SPLITTER}")
print(f"当前启动的LLM模型{models} @ {llm_device()}")
for model in models:
pprint(llm_model_dict[model])
print(f"当前Embbedings模型 {EMBEDDING_MODEL} @ {embedding_device()}")
if after_start:
@ -554,12 +541,12 @@ async def start_main_server():
logger.info(f"正在启动服务:")
logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
processes = {"online-api": []}
processes = {"online_api": {}, "model_worker": {}}
def process_count():
return len(processes) + len(processes["online-api"]) - 1
return len(processes) + len(processes["online_api"]) + len(processes["model_worker"]) - 2
if args.quiet:
if args.quiet or not log_verbose:
log_level = "ERROR"
else:
log_level = "INFO"
@ -569,63 +556,73 @@ async def start_main_server():
process = Process(
target=run_controller,
name=f"controller",
args=(queue, process_count() + 1, log_level, controller_started),
kwargs=dict(log_level=log_level, started_event=controller_started),
daemon=True,
)
processes["controller"] = process
process = Process(
target=run_openai_api,
name=f"openai_api",
args=(queue, process_count() + 1),
daemon=True,
)
processes["openai_api"] = process
model_worker_started = []
if args.model_worker:
config = get_model_worker_config(args.model_name)
if not config.get("online_api"):
process = Process(
target=run_model_worker,
name=f"model_worker - {args.model_name}",
args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level),
daemon=True,
)
processes["model_worker"] = process
for model_name in args.model_name:
config = get_model_worker_config(model_name)
if not config.get("online_api"):
e = manager.Event()
model_worker_started.append(e)
process = Process(
target=run_model_worker,
name=f"model_worker - {model_name}",
kwargs=dict(model_name=model_name,
controller_address=args.controller_address,
log_level=log_level,
q=queue,
started_event=e),
daemon=True,
)
processes["model_worker"][model_name] = process
if args.api_worker:
configs = get_all_model_worker_configs()
for model_name, config in configs.items():
if config.get("online_api") and config.get("worker_class"):
e = manager.Event()
model_worker_started.append(e)
process = Process(
target=run_model_worker,
name=f"model_worker - {model_name}",
args=(model_name, args.controller_address, queue, process_count() + 1, log_level),
name=f"api_worker - {model_name}",
kwargs=dict(model_name=model_name,
controller_address=args.controller_address,
log_level=log_level,
q=queue,
started_event=e),
daemon=True,
)
processes["online_api"][model_name] = process
processes["online-api"].append(process)
api_started = manager.Event()
if args.api:
process = Process(
target=run_api_server,
name=f"API Server",
args=(queue, process_count() + 1),
kwargs=dict(started_event=api_started),
daemon=True,
)
processes["api"] = process
webui_started = manager.Event()
if args.webui:
process = Process(
target=run_webui,
name=f"WEBUI Server",
args=(queue, process_count() + 1),
kwargs=dict(started_event=webui_started),
daemon=True,
)
processes["webui"] = process
if process_count() == 0:
@ -636,60 +633,106 @@ async def start_main_server():
if p:= processes.get("controller"):
p.start()
p.name = f"{p.name} ({p.pid})"
controller_started.wait()
controller_started.wait() # 等待controller启动完成
if p:= processes.get("openai_api"):
p.start()
p.name = f"{p.name} ({p.pid})"
if p:= processes.get("model_worker"):
for n, p in processes.get("model_worker", {}).items():
p.start()
p.name = f"{p.name} ({p.pid})"
for p in processes.get("online-api", []):
for n, p in processes.get("online_api", []).items():
p.start()
p.name = f"{p.name} ({p.pid})"
# 等待所有model_worker启动完成
for e in model_worker_started:
e.wait()
if p:= processes.get("api"):
p.start()
p.name = f"{p.name} ({p.pid})"
api_started.wait() # 等待api.py启动完成
if p:= processes.get("webui"):
p.start()
p.name = f"{p.name} ({p.pid})"
webui_started.wait() # 等待webui.py启动完成
dump_server_info(after_start=True, args=args)
while True:
no = queue.get()
if no == process_count():
time.sleep(0.5)
dump_server_info(after_start=True, args=args)
break
else:
queue.put(no)
cmd = queue.get() # 收到切换模型的消息
e = manager.Event()
if isinstance(cmd, list):
model_name, cmd, new_model_name = cmd
if cmd == "start": # 运行新模型
logger.info(f"准备启动新模型进程:{new_model_name}")
process = Process(
target=run_model_worker,
name=f"model_worker - {new_model_name}",
kwargs=dict(model_name=new_model_name,
controller_address=args.controller_address,
log_level=log_level,
q=queue,
started_event=e),
daemon=True,
)
process.start()
process.name = f"{process.name} ({process.pid})"
processes["model_worker"][new_model_name] = process
e.wait()
logger.info(f"成功启动新模型进程:{new_model_name}")
elif cmd == "stop":
if process := processes["model_worker"].get(model_name):
time.sleep(1)
process.terminate()
process.join()
logger.info(f"停止模型进程:{model_name}")
else:
logger.error(f"未找到模型进程:{model_name}")
elif cmd == "replace":
if process := processes["model_worker"].pop(model_name, None):
logger.info(f"停止模型进程:{model_name}")
start_time = datetime.now()
time.sleep(1)
process.terminate()
process.join()
process = Process(
target=run_model_worker,
name=f"model_worker - {new_model_name}",
kwargs=dict(model_name=new_model_name,
controller_address=args.controller_address,
log_level=log_level,
q=queue,
started_event=e),
daemon=True,
)
process.start()
process.name = f"{process.name} ({process.pid})"
processes["model_worker"][new_model_name] = process
e.wait()
timing = datetime.now() - start_time
logger.info(f"成功启动新模型进程:{new_model_name}。用时:{timing}")
else:
logger.error(f"未找到模型进程:{model_name}")
if model_worker_process := processes.get("model_worker"):
model_worker_process.join()
for process in processes.get("online-api", []):
process.join()
for name, process in processes.items():
if name not in ["model_worker", "online-api"]:
if isinstance(p, list):
for work_process in p:
work_process.join()
else:
process.join()
# for process in processes.get("model_worker", {}).values():
# process.join()
# for process in processes.get("online_api", {}).values():
# process.join()
# for name, process in processes.items():
# if name not in ["model_worker", "online_api"]:
# if isinstance(p, dict):
# for work_process in p.values():
# work_process.join()
# else:
# process.join()
except Exception as e:
# if model_worker_process := processes.pop("model_worker", None):
# model_worker_process.terminate()
# for process in processes.pop("online-api", []):
# process.terminate()
# for process in processes.values():
#
# if isinstance(process, list):
# for work_process in process:
# work_process.terminate()
# else:
# process.terminate()
logger.error(e)
logger.warning("Caught KeyboardInterrupt! Setting stop event...")
finally:
@ -702,10 +745,9 @@ async def start_main_server():
# Queues and other inter-process communication primitives can break when
# process is killed, but we don't care here
if isinstance(p, list):
for process in p:
if isinstance(p, dict):
for process in p.values():
process.kill()
else:
p.kill()

View File

@ -7,19 +7,23 @@ root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.utils import api_address
from configs.model_config import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path
from server.knowledge_base.utils import get_kb_path, get_file_path
from pprint import pprint
api_base_url = api_address()
kb = "kb_for_api_test"
test_files = {
"FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
"README.MD": str(root_path / "README.MD"),
"FAQ.MD": str(root_path / "docs" / "FAQ.MD")
"test.txt": get_file_path("samples", "test.txt"),
}
print("\n\n直接url访问\n")
def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"):
if not Path(get_kb_path(kb)).exists():
@ -78,37 +82,36 @@ def test_list_kbs(api="/knowledge_base/list_knowledge_bases"):
assert kb in data["data"]
def test_upload_doc(api="/knowledge_base/upload_doc"):
def test_upload_docs(api="/knowledge_base/upload_docs"):
url = api_base_url + api
for name, path in test_files.items():
print(f"\n上传知识文件: {name}")
data = {"knowledge_base_name": kb, "override": True}
files = {"file": (name, open(path, "rb"))}
r = requests.post(url, data=data, files=files)
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"成功上传文件 {name}"
files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
for name, path in test_files.items():
print(f"\n尝试重新上传知识文件: {name} 不覆盖")
data = {"knowledge_base_name": kb, "override": False}
files = {"file": (name, open(path, "rb"))}
r = requests.post(url, data=data, files=files)
data = r.json()
pprint(data)
assert data["code"] == 404
assert data["msg"] == f"文件 {name} 已存在。"
print(f"\n上传知识文件")
data = {"knowledge_base_name": kb, "override": True}
r = requests.post(url, data=data, files=files)
data = r.json()
pprint(data)
assert data["code"] == 200
assert len(data["data"]["failed_files"]) == 0
for name, path in test_files.items():
print(f"\n尝试重新上传知识文件: {name} 覆盖")
data = {"knowledge_base_name": kb, "override": True}
files = {"file": (name, open(path, "rb"))}
r = requests.post(url, data=data, files=files)
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"成功上传文件 {name}"
print(f"\n尝试重新上传知识文件, 不覆盖")
data = {"knowledge_base_name": kb, "override": False}
files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
r = requests.post(url, data=data, files=files)
data = r.json()
pprint(data)
assert data["code"] == 200
assert len(data["data"]["failed_files"]) == len(test_files)
print(f"\n尝试重新上传知识文件, 覆盖自定义docs")
docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)}
files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()]
r = requests.post(url, data=data, files=files)
data = r.json()
pprint(data)
assert data["code"] == 200
assert len(data["data"]["failed_files"]) == 0
def test_list_files(api="/knowledge_base/list_files"):
@ -134,26 +137,26 @@ def test_search_docs(api="/knowledge_base/search_docs"):
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
def test_update_doc(api="/knowledge_base/update_doc"):
def test_update_docs(api="/knowledge_base/update_docs"):
url = api_base_url + api
for name, path in test_files.items():
print(f"\n更新知识文件 {name}")
r = requests.post(url, json={"knowledge_base_name": kb, "file_name": name})
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"成功更新文件 {name}"
print(f"\n更新知识文件")
r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)})
data = r.json()
pprint(data)
assert data["code"] == 200
assert len(data["data"]["failed_files"]) == 0
def test_delete_doc(api="/knowledge_base/delete_doc"):
def test_delete_docs(api="/knowledge_base/delete_docs"):
url = api_base_url + api
for name, path in test_files.items():
print(f"\n删除知识文件 {name}")
r = requests.post(url, json={"knowledge_base_name": kb, "doc_name": name})
data = r.json()
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"{name} 文件删除成功"
print(f"\n删除知识文件")
r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)})
data = r.json()
pprint(data)
assert data["code"] == 200
assert len(data["data"]["failed_files"]) == 0
url = api_base_url + "/knowledge_base/search_docs"
query = "介绍一下langchain-chatchat项目"

View File

@ -0,0 +1,161 @@
import requests
import json
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.utils import api_address
from configs.model_config import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path, get_file_path
from webui_pages.utils import ApiRequest
from pprint import pprint
api_base_url = api_address()
api: ApiRequest = ApiRequest(api_base_url, no_remote_api=False)
kb = "kb_for_api_test"
test_files = {
"FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
"README.MD": str(root_path / "README.MD"),
"test.txt": get_file_path("samples", "test.txt"),
}
print("\n\nApiRquest调用\n")
def test_delete_kb_before():
if not Path(get_kb_path(kb)).exists():
return
data = api.delete_knowledge_base(kb)
pprint(data)
assert data["code"] == 200
assert isinstance(data["data"], list) and len(data["data"]) > 0
assert kb not in data["data"]
def test_create_kb():
print(f"\n尝试用空名称创建知识库:")
data = api.create_knowledge_base(" ")
pprint(data)
assert data["code"] == 404
assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称"
print(f"\n创建新知识库: {kb}")
data = api.create_knowledge_base(kb)
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"已新增知识库 {kb}"
print(f"\n尝试创建同名知识库: {kb}")
data = api.create_knowledge_base(kb)
pprint(data)
assert data["code"] == 404
assert data["msg"] == f"已存在同名知识库 {kb}"
def test_list_kbs():
data = api.list_knowledge_bases()
pprint(data)
assert isinstance(data, list) and len(data) > 0
assert kb in data
def test_upload_docs():
files = list(test_files.values())
print(f"\n上传知识文件")
data = {"knowledge_base_name": kb, "override": True}
data = api.upload_kb_docs(files, **data)
pprint(data)
assert data["code"] == 200
assert len(data["data"]["failed_files"]) == 0
print(f"\n尝试重新上传知识文件, 不覆盖")
data = {"knowledge_base_name": kb, "override": False}
data = api.upload_kb_docs(files, **data)
pprint(data)
assert data["code"] == 200
assert len(data["data"]["failed_files"]) == len(test_files)
print(f"\n尝试重新上传知识文件, 覆盖自定义docs")
docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
data = {"knowledge_base_name": kb, "override": True, "docs": docs}
data = api.upload_kb_docs(files, **data)
pprint(data)
assert data["code"] == 200
assert len(data["data"]["failed_files"]) == 0
def test_list_files():
print("\n获取知识库中文件列表:")
data = api.list_kb_docs(knowledge_base_name=kb)
pprint(data)
assert isinstance(data, list)
for name in test_files:
assert name in data
def test_search_docs():
query = "介绍一下langchain-chatchat项目"
print("\n检索知识库:")
print(query)
data = api.search_kb_docs(query, kb)
pprint(data)
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
def test_update_docs():
print(f"\n更新知识文件")
data = api.update_kb_docs(knowledge_base_name=kb, file_names=list(test_files))
pprint(data)
assert data["code"] == 200
assert len(data["data"]["failed_files"]) == 0
def test_delete_docs():
print(f"\n删除知识文件")
data = api.delete_kb_docs(knowledge_base_name=kb, file_names=list(test_files))
pprint(data)
assert data["code"] == 200
assert len(data["data"]["failed_files"]) == 0
query = "介绍一下langchain-chatchat项目"
print("\n尝试检索删除后的检索知识库:")
print(query)
data = api.search_kb_docs(query, kb)
pprint(data)
assert isinstance(data, list) and len(data) == 0
def test_recreate_vs():
print("\n重建知识库:")
r = api.recreate_vector_store(kb)
for data in r:
assert isinstance(data, dict)
assert data["code"] == 200
print(data["msg"])
query = "本项目支持哪些文件格式?"
print("\n尝试检索重建后的检索知识库:")
print(query)
data = api.search_kb_docs(query, kb)
pprint(data)
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
def test_delete_kb_after():
print("\n删除知识库")
data = api.delete_knowledge_base(kb)
pprint(data)
# check kb not exists anymore
print("\n获取知识库列表:")
data = api.list_knowledge_bases()
pprint(data)
assert isinstance(data, list) and len(data) > 0
assert kb not in data

View File

@ -5,8 +5,9 @@ from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from configs.server_config import api_address, FSCHAT_MODEL_WORKERS
from configs.server_config import FSCHAT_MODEL_WORKERS
from configs.model_config import LLM_MODEL, llm_model_dict
from server.utils import api_address, get_model_worker_config
from pprint import pprint
import random
@ -64,7 +65,8 @@ def test_change_model(api="/llm_model/change"):
assert len(availabel_new_models) > 0
print(availabel_new_models)
model_name = random.choice(running_models)
local_models = [x for x in running_models if not get_model_worker_config(x).get("online_api")]
model_name = random.choice(local_models)
new_model_name = random.choice(availabel_new_models)
print(f"\n尝试将模型从 {model_name} 切换到 {new_model_name}")
r = requests.post(url, json={"model_name": model_name, "new_model_name": new_model_name})

View File

@ -43,11 +43,11 @@ data = {
"content": "你好,我是人工智能大模型"
}
],
"stream": True
"stream": True,
"temperature": 0.7,
}
def test_chat_fastchat(api="/chat/fastchat"):
url = f"{api_base_url}{api}"
data2 = {
@ -89,14 +89,12 @@ def test_knowledge_chat(api="/chat/knowledge_base_chat"):
response = requests.post(url, headers=headers, json=data, stream=True)
print("\n")
print("=" * 30 + api + " output" + "="*30)
first = True
for line in response.iter_content(None, decode_unicode=True):
data = json.loads(line)
if first:
for doc in data["docs"]:
print(doc)
first = False
print(data["answer"], end="", flush=True)
if "anser" in data:
print(data["answer"], end="", flush=True)
assert "docs" in data and len(data["docs"]) > 0
pprint(data["docs"])
assert response.status_code == 200
@ -117,14 +115,11 @@ def test_search_engine_chat(api="/chat/search_engine_chat"):
print("\n")
print("=" * 30 + api + " by {se} output" + "="*30)
first = True
for line in response.iter_content(None, decode_unicode=True):
data = json.loads(line)
assert "docs" in data and len(data["docs"]) > 0
if first:
for doc in data.get("docs", []):
print(doc)
first = False
print(data["answer"], end="", flush=True)
if "answer" in data:
print(data["answer"], end="", flush=True)
assert "docs" in data and len(data["docs"]) > 0
pprint(data["docs"])
assert response.status_code == 200

View File

@ -0,0 +1,53 @@
import os
from transformers import AutoTokenizer
import sys
sys.path.append("../..")
from configs.model_config import (
CHUNK_SIZE,
OVERLAP_SIZE
)
from server.knowledge_base.utils import make_text_splitter
def text(splitter_name):
from langchain import document_loaders
# 使用DocumentLoader读取文件
filepath = "../../knowledge_base/samples/content/test.txt"
loader = document_loaders.UnstructuredFileLoader(filepath, autodetect_encoding=True)
docs = loader.load()
text_splitter = make_text_splitter(splitter_name, CHUNK_SIZE, OVERLAP_SIZE)
if splitter_name == "MarkdownHeaderTextSplitter":
docs = text_splitter.split_text(docs[0].page_content)
for doc in docs:
if doc.metadata:
doc.metadata["source"] = os.path.basename(filepath)
else:
docs = text_splitter.split_documents(docs)
for doc in docs:
print(doc)
return docs
import pytest
from langchain.docstore.document import Document
@pytest.mark.parametrize("splitter_name",
[
"ChineseRecursiveTextSplitter",
"SpacyTextSplitter",
"RecursiveCharacterTextSplitter",
"MarkdownHeaderTextSplitter"
])
def test_different_splitter(splitter_name):
try:
docs = text(splitter_name)
assert isinstance(docs, list)
if len(docs)>0:
assert isinstance(docs[0], Document)
except Exception as e:
pytest.fail(f"test_different_splitter failed with {splitter_name}, error: {str(e)}")

View File

@ -6,14 +6,14 @@ sys.path.append(str(root_path))
from pprint import pprint
test_files = {
"ocr_test.pdf": str(root_path / "tests" / "samples" / "ocr_test.pdf"),
"ocr_test.jpg": str(root_path / "tests" / "samples" / "ocr_test.jpg"),
}
def test_rapidocrpdfloader():
pdf_path = test_files["ocr_test.pdf"]
from document_loaders import RapidOCRPDFLoader
def test_rapidocrloader():
img_path = test_files["ocr_test.jpg"]
from document_loaders import RapidOCRLoader
loader = RapidOCRPDFLoader(pdf_path)
loader = RapidOCRLoader(img_path)
docs = loader.load()
pprint(docs)
assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str)

View File

@ -6,14 +6,14 @@ sys.path.append(str(root_path))
from pprint import pprint
test_files = {
"ocr_test.jpg": str(root_path / "tests" / "samples" / "ocr_test.jpg"),
"ocr_test.pdf": str(root_path / "tests" / "samples" / "ocr_test.pdf"),
}
def test_rapidocrloader():
img_path = test_files["ocr_test.jpg"]
from document_loaders import RapidOCRLoader
def test_rapidocrpdfloader():
pdf_path = test_files["ocr_test.pdf"]
from document_loaders import RapidOCRPDFLoader
loader = RapidOCRLoader(img_path)
loader = RapidOCRPDFLoader(pdf_path)
docs = loader.load()
pprint(docs)
assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str)

View File

@ -0,0 +1,20 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.model_workers.qianfan import request_qianfan_api, MODEL_VERSIONS
from pprint import pprint
import pytest
@pytest.mark.parametrize("version", MODEL_VERSIONS.keys())
def test_qianfan(version):
messages = [{"role": "user", "content": "你好"}]
print("\n" + version + "\n")
i = 1
for x in request_qianfan_api(messages, version=version):
pprint(x)
assert isinstance(x, dict)
assert "error_code" not in x
i += 1

View File

@ -1,3 +1,4 @@
from .chinese_text_splitter import ChineseTextSplitter
from .ali_text_splitter import AliTextSplitter
from .zh_title_enhance import zh_title_enhance
from .chinese_recursive_text_splitter import ChineseRecursiveTextSplitter

View File

@ -0,0 +1,104 @@
import re
from typing import List, Optional, Any
from langchain.text_splitter import RecursiveCharacterTextSplitter
import logging
logger = logging.getLogger(__name__)
def _split_text_with_regex_from_end(
text: str, separator: str, keep_separator: bool
) -> List[str]:
# Now that we have the separator, split the text
if separator:
if keep_separator:
# The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({separator})", text)
splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])]
if len(_splits) % 2 == 1:
splits += _splits[-1:]
# splits = [_splits[0]] + splits
else:
splits = re.split(separator, text)
else:
splits = list(text)
return [s for s in splits if s != ""]
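Editor's note: a quick illustration (not in the diff) of the "split from end" behavior: with `keep_separator=True` the separator stays attached to the text before it, unlike a plain `re.split` with a capture group.

```python
_split_text_with_regex_from_end("今天天气不错。我们出去走走。", "。", keep_separator=True)
# -> ['今天天气不错。', '我们出去走走。']
```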
class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
def __init__(
self,
separators: Optional[List[str]] = None,
keep_separator: bool = True,
is_separator_regex: bool = True,
**kwargs: Any,
) -> None:
"""Create a new TextSplitter."""
super().__init__(keep_separator=keep_separator, **kwargs)
self._separators = separators or [
"\n\n",
"\n",
"。||",
"\.\s|\!\s|\?\s",
"|;\s",
"|,\s"
]
self._is_separator_regex = is_separator_regex
def _split_text(self, text: str, separators: List[str]) -> List[str]:
"""Split incoming text and return chunks."""
final_chunks = []
# Get appropriate separator to use
separator = separators[-1]
new_separators = []
for i, _s in enumerate(separators):
_separator = _s if self._is_separator_regex else re.escape(_s)
if _s == "":
separator = _s
break
if re.search(_separator, text):
separator = _s
new_separators = separators[i + 1:]
break
_separator = separator if self._is_separator_regex else re.escape(separator)
splits = _split_text_with_regex_from_end(text, _separator, self._keep_separator)
# Now go merging things, recursively splitting longer texts.
_good_splits = []
_separator = "" if self._keep_separator else separator
for s in splits:
if self._length_function(s) < self._chunk_size:
_good_splits.append(s)
else:
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator)
final_chunks.extend(merged_text)
_good_splits = []
if not new_separators:
final_chunks.append(s)
else:
other_info = self._split_text(s, new_separators)
final_chunks.extend(other_info)
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator)
final_chunks.extend(merged_text)
return [re.sub(r"\n{2,}", "\n", chunk.strip()) for chunk in final_chunks if chunk.strip()!=""]
if __name__ == "__main__":
text_splitter = ChineseRecursiveTextSplitter(
keep_separator=True,
is_separator_regex=True,
chunk_size=50,
chunk_overlap=0
)
ls = [
"""中国对外贸易形势报告75页。前 10 个月,一般贸易进出口 19.5 万亿元,增长 25.1% 比整体进出口增速高出 2.9 个百分点,占进出口总额的 61.7%,较去年同期提升 1.6 个百分点。其中,一般贸易出口 10.6 万亿元,增长 25.3%,占出口总额的 60.9%,提升 1.5 个百分点进口8.9万亿元增长24.9%占进口总额的62.7% 提升 1.8 个百分点。加工贸易进出口 6.8 万亿元,增长 11.8% 占进出口总额的 21.5%,减少 2.0 个百分点。其中,出口增 长 10.4%,占出口总额的 24.3%,减少 2.6 个百分点;进口增 长 14.2%,占进口总额的 18.0%,减少 1.2 个百分点。此外, 以保税物流方式进出口 3.96 万亿元,增长 27.9%。其中,出 口 1.47 万亿元,增长 38.9%;进口 2.49 万亿元,增长 22.2%。前三季度,中国服务贸易继续保持快速增长态势。服务 进出口总额 37834.3 亿元,增长 11.6%;其中服务出口 17820.9 亿元,增长 27.3%;进口 20013.4 亿元,增长 0.5%,进口增 速实现了疫情以来的首次转正。服务出口增幅大于进口 26.8 个百分点,带动服务贸易逆差下降 62.9%至 2192.5 亿元。服 务贸易结构持续优化,知识密集型服务进出口 16917.7 亿元, 增长 13.3%,占服务进出口总额的比重达到 44.7%,提升 0.7 个百分点。 二、中国对外贸易发展环境分析和展望 全球疫情起伏反复,经济复苏分化加剧,大宗商品价格 上涨、能源紧缺、运力紧张及发达经济体政策调整外溢等风 险交织叠加。同时也要看到,我国经济长期向好的趋势没有 改变,外贸企业韧性和活力不断增强,新业态新模式加快发 展,创新转型步伐提速。产业链供应链面临挑战。美欧等加快出台制造业回迁计 划,加速产业链供应链本土布局,跨国公司调整产业链供应 链,全球双链面临新一轮重构,区域化、近岸化、本土化、 短链化趋势凸显。疫苗供应不足,制造业“缺芯”、物流受限、 运价高企,全球产业链供应链面临压力。 全球通胀持续高位运行。能源价格上涨加大主要经济体 的通胀压力,增加全球经济复苏的不确定性。世界银行今年 10 月发布《大宗商品市场展望》指出,能源价格在 2021 年 大涨逾 80%,并且仍将在 2022 年小幅上涨。IMF 指出,全 球通胀上行风险加剧,通胀前景存在巨大不确定性。""",
]
# text = """"""
for inum, text in enumerate(ls):
print(inum)
chunks = text_splitter.split_text(text)
for chunk in chunks:
print(chunk)

View File

@ -1,11 +1,10 @@
from langchain.text_splitter import CharacterTextSplitter
import re
from typing import List
from configs.model_config import CHUNK_SIZE
class ChineseTextSplitter(CharacterTextSplitter):
def __init__(self, pdf: bool = False, sentence_size: int = CHUNK_SIZE, **kwargs):
def __init__(self, pdf: bool = False, sentence_size: int = 250, **kwargs):
super().__init__(**kwargs)
self.pdf = pdf
self.sentence_size = sentence_size

View File

@ -1,9 +1,3 @@
# 运行方式:
# 1. 安装必要的包pip install streamlit-option-menu streamlit-chatbox>=1.1.6
# 2. 运行本机fastchat服务python server\llm_api.py 或者 运行对应的sh文件
# 3. 运行API服务器python server/api.py。如果使用api = ApiRequest(no_remote_api=True),该步可以跳过。
# 4. 运行WEB UIstreamlit run webui.py --server.port 7860
import streamlit as st
from webui_pages.utils import *
from streamlit_option_menu import option_menu

View File

@ -5,7 +5,7 @@ from streamlit_chatbox import *
from datetime import datetime
from server.chat.search_engine_chat import SEARCH_ENGINES
import os
from configs.model_config import llm_model_dict, LLM_MODEL
from configs.model_config import LLM_MODEL, TEMPERATURE
from server.utils import get_model_worker_config
from typing import List, Dict
@ -55,17 +55,20 @@ def dialogue_page(api: ApiRequest):
st.toast(text)
# sac.alert(text, description="descp", type="success", closable=True, banner=True)
dialogue_mode = st.selectbox("请选择对话模式",
dialogue_mode = st.selectbox("请选择对话模式",
["LLM 对话",
"知识库问答",
"搜索引擎问答",
],
index=1,
on_change=on_mode_change,
key="dialogue_mode",
)
def on_llm_change():
st.session_state["prev_llm_model"] = llm_model
config = get_model_worker_config(llm_model)
if not config.get("online_api"): # 只有本地model_worker可以切换模型
st.session_state["prev_llm_model"] = llm_model
def llm_model_format_func(x):
if x in running_models:
@ -78,10 +81,8 @@ def dialogue_page(api: ApiRequest):
if x in config_models:
config_models.remove(x)
llm_models = running_models + config_models
if "prev_llm_model" not in st.session_state:
index = llm_models.index(LLM_MODEL)
else:
index = 0
cur_model = st.session_state.get("cur_llm_model", LLM_MODEL)
index = llm_models.index(cur_model)
llm_model = st.selectbox("选择LLM模型",
llm_models,
index,
@ -91,10 +92,11 @@ def dialogue_page(api: ApiRequest):
)
if (st.session_state.get("prev_llm_model") != llm_model
and not get_model_worker_config(llm_model).get("online_api")):
with st.spinner(f"正在加载模型: {llm_model}"):
with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
st.session_state["prev_llm_model"] = llm_model
st.session_state["cur_llm_model"] = llm_model
temperature = st.slider("Temperature", 0.0, 1.0, TEMPERATURE, 0.05)
history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)
def on_kb_change():
@ -110,7 +112,7 @@ def dialogue_page(api: ApiRequest):
key="selected_kb",
)
kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
score_threshold = st.number_input("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
score_threshold = st.slider("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
# chunk_content = st.checkbox("关联上下文", False, disabled=True)
# chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
elif dialogue_mode == "搜索引擎问答":
@ -135,7 +137,7 @@ def dialogue_page(api: ApiRequest):
if dialogue_mode == "LLM 对话":
chat_box.ai_say("正在思考...")
text = ""
r = api.chat_chat(prompt, history=history, model=llm_model)
r = api.chat_chat(prompt, history=history, model=llm_model, temperature=temperature)
for t in r:
if error_msg := check_error_msg(t): # check whether error occured
st.error(error_msg)
@ -150,28 +152,38 @@ def dialogue_page(api: ApiRequest):
Markdown("...", in_expander=True, title="知识库匹配结果"),
])
text = ""
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history, model=llm_model):
if error_msg := check_error_msg(d): # check whether error occured
for d in api.knowledge_base_chat(prompt,
knowledge_base_name=selected_kb,
top_k=kb_top_k,
score_threshold=score_threshold,
history=history,
model=llm_model,
temperature=temperature):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
else:
text += d["answer"]
elif chunk := d.get("answer"):
text += chunk
chat_box.update_msg(text, 0)
chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False)
chat_box.update_msg(text, 0, streaming=False)
chat_box.update_msg("\n\n".join(d.get("docs", [])), 1, streaming=False)
elif dialogue_mode == "搜索引擎问答":
chat_box.ai_say([
f"正在执行 `{search_engine}` 搜索...",
Markdown("...", in_expander=True, title="网络搜索结果"),
])
text = ""
for d in api.search_engine_chat(prompt, search_engine, se_top_k, model=llm_model):
if error_msg := check_error_msg(d): # check whether error occured
for d in api.search_engine_chat(prompt,
search_engine_name=search_engine,
top_k=se_top_k,
model=llm_model,
temperature=temperature):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
else:
text += d["answer"]
elif chunk := d.get("answer"):
text += chunk
chat_box.update_msg(text, 0)
chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False)
chat_box.update_msg(text, 0, streaming=False)
chat_box.update_msg("\n\n".join(d.get("docs", [])), 1, streaming=False)
now = datetime.now()
with st.sidebar:

View File

@ -6,7 +6,9 @@ import pandas as pd
from server.knowledge_base.utils import get_file_path, LOADER_DICT
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
from typing import Literal, Dict, Tuple
from configs.model_config import embedding_model_dict, kbs_config, EMBEDDING_MODEL, DEFAULT_VS_TYPE
from configs.model_config import (embedding_model_dict, kbs_config,
EMBEDDING_MODEL, DEFAULT_VS_TYPE,
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
import os
import time
@ -125,28 +127,41 @@ def knowledge_base_page(api: ApiRequest):
elif selected_kb:
kb = selected_kb
# 上传文件
# sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
files = st.file_uploader("上传知识文件",
files = st.file_uploader("上传知识文件:",
[i for ls in LOADER_DICT.values() for i in ls],
accept_multiple_files=True,
)
# with st.sidebar:
with st.expander(
"文件处理配置",
expanded=True,
):
cols = st.columns(3)
chunk_size = cols[0].number_input("单段文本最大长度:", 1, 1000, CHUNK_SIZE)
chunk_overlap = cols[1].number_input("相邻文本重合长度:", 0, chunk_size, OVERLAP_SIZE)
cols[2].write("")
cols[2].write("")
zh_title_enhance = cols[2].checkbox("开启中文标题加强", ZH_TITLE_ENHANCE)
if st.button(
"添加文件到知识库",
# help="请先上传文件,再点击添加",
# use_container_width=True,
disabled=len(files) == 0,
):
data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files]
data[-1]["not_refresh_vs_cache"]=False
for k in data:
ret = api.upload_kb_doc(**k)
if msg := check_success_msg(ret):
st.toast(msg, icon="")
elif msg := check_error_msg(ret):
st.toast(msg, icon="")
st.session_state.files = []
ret = api.upload_kb_docs(files,
knowledge_base_name=kb,
override=True,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
zh_title_enhance=zh_title_enhance)
if msg := check_success_msg(ret):
st.toast(msg, icon="")
elif msg := check_error_msg(ret):
st.toast(msg, icon="")
st.divider()
@ -160,7 +175,7 @@ def knowledge_base_page(api: ApiRequest):
st.info("知识库中包含源文件与向量库,请从下表中选择文件后操作")
doc_details.drop(columns=["kb_name"], inplace=True)
doc_details = doc_details[[
"No", "file_name", "document_loader", "docs_count", "in_folder", "in_db",
"No", "file_name", "document_loader", "text_splitter", "docs_count", "in_folder", "in_db",
]]
# doc_details["in_folder"] = doc_details["in_folder"].replace(True, "✓").replace(False, "×")
# doc_details["in_db"] = doc_details["in_db"].replace(True, "✓").replace(False, "×")
@ -173,7 +188,7 @@ def knowledge_base_page(api: ApiRequest):
# ("file_version", "文档版本"): {},
("document_loader", "文档加载器"): {},
("docs_count", "文档数量"): {},
# ("text_splitter", "分词器"): {},
("text_splitter", "分词器"): {},
# ("create_time", "创建时间"): {},
("in_folder", "源文件"): {"cellRenderer": cell_renderer},
("in_db", "向量库"): {"cellRenderer": cell_renderer},
@ -218,8 +233,12 @@ def knowledge_base_page(api: ApiRequest):
disabled=not file_exists(kb, selected_rows)[0],
use_container_width=True,
):
for row in selected_rows:
api.update_kb_doc(kb, row["file_name"])
file_names = [row["file_name"] for row in selected_rows]
api.update_kb_docs(kb,
file_names=file_names,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
zh_title_enhance=zh_title_enhance)
st.experimental_rerun()
# 将文件从向量库中删除,但不删除文件本身。
@ -228,8 +247,8 @@ def knowledge_base_page(api: ApiRequest):
disabled=not (selected_rows and selected_rows[0]["in_db"]),
use_container_width=True,
):
for row in selected_rows:
api.delete_kb_doc(kb, row["file_name"])
file_names = [row["file_name"] for row in selected_rows]
api.delete_kb_docs(kb, file_names=file_names)
st.experimental_rerun()
if cols[3].button(
@ -237,9 +256,8 @@ def knowledge_base_page(api: ApiRequest):
type="primary",
use_container_width=True,
):
for row in selected_rows:
ret = api.delete_kb_doc(kb, row["file_name"], True)
st.toast(ret.get("msg", " "))
file_names = [row["file_name"] for row in selected_rows]
api.delete_kb_docs(kb, file_names=file_names, delete_content=True)
st.experimental_rerun()
st.divider()
@ -255,7 +273,10 @@ def knowledge_base_page(api: ApiRequest):
with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
empty = st.empty()
empty.progress(0.0, "")
for d in api.recreate_vector_store(kb):
for d in api.recreate_vector_store(kb,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
zh_title_enhance=zh_title_enhance):
if msg := check_error_msg(d):
st.toast(msg)
else:

View File

@ -8,10 +8,14 @@ from configs.model_config import (
LLM_MODEL,
llm_model_dict,
HISTORY_LEN,
TEMPERATURE,
SCORE_THRESHOLD,
CHUNK_SIZE,
OVERLAP_SIZE,
ZH_TITLE_ENHANCE,
VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K,
logger,
logger, log_verbose,
)
from configs.server_config import HTTPX_DEFAULT_TIMEOUT
import httpx
@ -20,10 +24,9 @@ from server.chat.openai_chat import OpenAiChatMsgIn
from fastapi.responses import StreamingResponse
import contextlib
import json
import os
from io import BytesIO
from server.db.repository.knowledge_base_repository import get_kb_detail
from server.db.repository.knowledge_file_repository import get_file_detail
from server.utils import run_async, iter_over_async, set_httpx_timeout
from server.utils import run_async, iter_over_async, set_httpx_timeout, api_address
from configs.model_config import NLTK_DATA_PATH
import nltk
@ -43,13 +46,15 @@ class ApiRequest:
'''
def __init__(
self,
base_url: str = "http://127.0.0.1:7861",
base_url: str = api_address(),
timeout: float = HTTPX_DEFAULT_TIMEOUT,
no_remote_api: bool = False, # call api view function directly
):
self.base_url = base_url
self.timeout = timeout
self.no_remote_api = no_remote_api
if no_remote_api:
logger.warn("将来可能取消对no_remote_api的支持更新版本时请注意。")
def _parse_url(self, url: str) -> str:
if (not url.startswith("http")
@ -78,7 +83,9 @@ class ApiRequest:
else:
return httpx.get(url, params=params, **kwargs)
except Exception as e:
logger.error(e)
msg = f"error when get {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
retry -= 1
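Each retry wrapper now logs a one-line summary and only attaches the traceback when `log_verbose` is enabled; the same pattern repeats in the async and delete variants below. A standalone sketch of the convention, with `log_verbose` as a local stand-in for the imported config switch:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("api_request")
log_verbose = False  # stand-in for configs.model_config.log_verbose

try:
    raise ConnectionError("connection refused")
except Exception as e:
    msg = f"error when get http://127.0.0.1:7861/: {e}"
    # The summary is always emitted; the full traceback only when log_verbose is on.
    logger.error(f"{e.__class__.__name__}: {msg}",
                 exc_info=e if log_verbose else None)
```

Keeping tracebacks behind a switch keeps the web UI logs readable while still allowing full diagnostics when debugging.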
async def aget(
@ -99,7 +106,9 @@ class ApiRequest:
else:
return await client.get(url, params=params, **kwargs)
except Exception as e:
logger.error(e)
msg = f"error when aget {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
retry -= 1
def post(
@ -121,7 +130,9 @@ class ApiRequest:
else:
return httpx.post(url, data=data, json=json, **kwargs)
except Exception as e:
logger.error(e)
msg = f"error when post {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
retry -= 1
async def apost(
@ -143,7 +154,9 @@ class ApiRequest:
else:
return await client.post(url, data=data, json=json, **kwargs)
except Exception as e:
logger.error(e)
msg = f"error when apost {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
retry -= 1
def delete(
@ -164,7 +177,9 @@ class ApiRequest:
else:
return httpx.delete(url, data=data, json=json, **kwargs)
except Exception as e:
logger.error(e)
msg = f"error when delete {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
retry -= 1
async def adelete(
@ -186,7 +201,9 @@ class ApiRequest:
else:
return await client.delete(url, data=data, json=json, **kwargs)
except Exception as e:
logger.error(e)
msg = f"error when adelete {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
retry -= 1
def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False):
@ -197,7 +214,7 @@ class ApiRequest:
loop = asyncio.get_event_loop()
except:
loop = asyncio.new_event_loop()
try:
for chunk in iter_over_async(response.body_iterator, loop):
if as_json and chunk:
@ -205,7 +222,9 @@ class ApiRequest:
elif chunk.strip():
yield chunk
except Exception as e:
logger.error(e)
msg = f"error when run fastapi router: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
def _httpx_stream2generator(
self,
@ -226,23 +245,26 @@ class ApiRequest:
pprint(data, depth=1)
yield data
except Exception as e:
logger.error(f"接口返回json错误 {chunk}’。错误信息是:{e}")
msg = f"接口返回json错误 {chunk}’。错误信息是:{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
else:
print(chunk, end="", flush=True)
yield chunk
except httpx.ConnectError as e:
msg = f"无法连接API服务器请确认 api.py 已正常启动。"
msg = f"无法连接API服务器请确认 api.py 已正常启动。({e})"
logger.error(msg)
logger.error(msg)
logger.error(e)
yield {"code": 500, "msg": msg}
except httpx.ReadTimeout as e:
msg = f"API通信超时请确认已启动FastChat与API服务详见RADME '5. 启动 API 服务或 Web UI'"
msg = f"API通信超时请确认已启动FastChat与API服务详见RADME '5. 启动 API 服务或 Web UI'。({e}"
logger.error(msg)
logger.error(e)
yield {"code": 500, "msg": msg}
except Exception as e:
logger.error(e)
yield {"code": 500, "msg": str(e)}
msg = f"API通信遇到错误{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
yield {"code": 500, "msg": msg}
# Chat-related operations
@ -251,7 +273,7 @@ class ApiRequest:
messages: List[Dict],
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = 0.7,
temperature: float = TEMPERATURE,
max_tokens: int = 1024,  # TODO: derive max_tokens automatically from the message content
no_remote_api: bool = None,
**kwargs: Any,
@ -272,7 +294,7 @@ class ApiRequest:
if no_remote_api:
from server.chat.openai_chat import openai_chat
response = openai_chat(msg)
response = run_async(openai_chat(msg))
return self._fastapi_stream2generator(response)
else:
data = msg.dict(exclude_unset=True, exclude_none=True)
@ -282,7 +304,7 @@ class ApiRequest:
response = self.post(
"/chat/fastchat",
json=data,
stream=stream,
stream=True,
)
return self._httpx_stream2generator(response)
@ -292,6 +314,7 @@ class ApiRequest:
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
no_remote_api: bool = None,
):
'''
@ -305,6 +328,7 @@ class ApiRequest:
"history": history,
"stream": stream,
"model_name": model,
"temperature": temperature,
}
print(f"received input message:")
@ -312,7 +336,7 @@ class ApiRequest:
if no_remote_api:
from server.chat.chat import chat
response = chat(**data)
response = run_async(chat(**data))
return self._fastapi_stream2generator(response)
else:
response = self.post("/chat/chat", json=data, stream=True)
@ -327,6 +351,7 @@ class ApiRequest:
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
no_remote_api: bool = None,
):
'''
@ -343,6 +368,7 @@ class ApiRequest:
"history": history,
"stream": stream,
"model_name": model,
"temperature": temperature,
"local_doc_url": no_remote_api,
}
@ -351,7 +377,7 @@ class ApiRequest:
if no_remote_api:
from server.chat.knowledge_base_chat import knowledge_base_chat
response = knowledge_base_chat(**data)
response = run_async(knowledge_base_chat(**data))
return self._fastapi_stream2generator(response, as_json=True)
else:
response = self.post(
@ -368,6 +394,7 @@ class ApiRequest:
top_k: int = SEARCH_ENGINE_TOP_K,
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
no_remote_api: bool = None,
):
'''
@ -382,6 +409,7 @@ class ApiRequest:
"top_k": top_k,
"stream": stream,
"model_name": model,
"temperature": temperature,
}
print(f"received input message:")
@ -389,7 +417,7 @@ class ApiRequest:
if no_remote_api:
from server.chat.search_engine_chat import search_engine_chat
response = search_engine_chat(**data)
response = run_async(search_engine_chat(**data))
return self._fastapi_stream2generator(response, as_json=True)
else:
response = self.post(
@ -413,8 +441,10 @@ class ApiRequest:
try:
return response.json()
except Exception as e:
logger.error(e)
return {"code": 500, "msg": errorMsg or str(e)}
msg = "API未能返回正确的JSON。" + (errorMsg or str(e))
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
return {"code": 500, "msg": msg}
def list_knowledge_bases(
self,
@ -428,7 +458,7 @@ class ApiRequest:
if no_remote_api:
from server.knowledge_base.kb_api import list_kbs
response = run_async(list_kbs())
response = list_kbs()
return response.data
else:
response = self.get("/knowledge_base/list_knowledge_bases")
@ -456,7 +486,7 @@ class ApiRequest:
if no_remote_api:
from server.knowledge_base.kb_api import create_kb
response = run_async(create_kb(**data))
response = create_kb(**data)
return response.dict()
else:
response = self.post(
@ -478,7 +508,7 @@ class ApiRequest:
if no_remote_api:
from server.knowledge_base.kb_api import delete_kb
response = run_async(delete_kb(knowledge_base_name))
response = delete_kb(knowledge_base_name)
return response.dict()
else:
response = self.post(
@ -500,7 +530,7 @@ class ApiRequest:
if no_remote_api:
from server.knowledge_base.kb_doc_api import list_files
response = run_async(list_files(knowledge_base_name))
response = list_files(knowledge_base_name)
return response.data
else:
response = self.get(
@ -510,12 +540,48 @@ class ApiRequest:
data = self._check_httpx_json_response(response)
return data.get("data", [])
def upload_kb_doc(
def search_kb_docs(
self,
file: Union[str, Path, bytes],
query: str,
knowledge_base_name: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: int = SCORE_THRESHOLD,
no_remote_api: bool = None,
) -> List:
'''
Corresponds to the api.py /knowledge_base/search_docs endpoint.
'''
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = {
"query": query,
"knowledge_base_name": knowledge_base_name,
"top_k": top_k,
"score_threshold": score_threshold,
}
if no_remote_api:
from server.knowledge_base.kb_doc_api import search_docs
return search_docs(**data)
else:
response = self.post(
"/knowledge_base/search_docs",
json=data,
)
data = self._check_httpx_json_response(response)
return data
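A minimal usage sketch of the new `search_kb_docs` wrapper (query and knowledge base name are example values):

```python
from webui_pages.utils import ApiRequest

api = ApiRequest()
docs = api.search_kb_docs(
    query="如何接入在线 LLM API?",
    knowledge_base_name="samples",
    top_k=3,
)
for doc in docs:
    print(doc)   # matched document payloads as returned by the API
```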
def upload_kb_docs(
self,
files: List[Union[str, Path, bytes]],
knowledge_base_name: str,
filename: str = None,
override: bool = False,
to_vector_store: bool = True,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
docs: Dict = {},
not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
):
@ -525,97 +591,122 @@ class ApiRequest:
if no_remote_api is None:
no_remote_api = self.no_remote_api
if isinstance(file, bytes): # raw bytes
file = BytesIO(file)
elif hasattr(file, "read"): # a file io like object
filename = filename or file.name
else: # a local path
file = Path(file).absolute().open("rb")
filename = filename or file.name
def convert_file(file, filename=None):
if isinstance(file, bytes): # raw bytes
file = BytesIO(file)
elif hasattr(file, "read"): # a file io like object
filename = filename or file.name
else: # a local path
file = Path(file).absolute().open("rb")
filename = filename or os.path.split(file.name)[-1]
return filename, file
files = [convert_file(file) for file in files]
data={
"knowledge_base_name": knowledge_base_name,
"override": override,
"to_vector_store": to_vector_store,
"chunk_size": chunk_size,
"chunk_overlap": chunk_overlap,
"zh_title_enhance": zh_title_enhance,
"docs": docs,
"not_refresh_vs_cache": not_refresh_vs_cache,
}
if no_remote_api:
from server.knowledge_base.kb_doc_api import upload_doc
from server.knowledge_base.kb_doc_api import upload_docs
from fastapi import UploadFile
from tempfile import SpooledTemporaryFile
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
temp_file.write(file.read())
temp_file.seek(0)
response = run_async(upload_doc(
UploadFile(file=temp_file, filename=filename),
knowledge_base_name,
override,
))
upload_files = []
for filename, file in files:
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
temp_file.write(file.read())
temp_file.seek(0)
upload_files.append(UploadFile(file=temp_file, filename=filename))
response = upload_docs(upload_files, **data)
return response.dict()
else:
if isinstance(data["docs"], dict):
data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
response = self.post(
"/knowledge_base/upload_doc",
data={
"knowledge_base_name": knowledge_base_name,
"override": override,
"not_refresh_vs_cache": not_refresh_vs_cache,
},
files={"file": (filename, file)},
"/knowledge_base/upload_docs",
data=data,
files=[("files", (filename, file)) for filename, file in files],
)
return self._check_httpx_json_response(response)
def delete_kb_doc(
def delete_kb_docs(
self,
knowledge_base_name: str,
doc_name: str,
file_names: List[str],
delete_content: bool = False,
not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
):
'''
Corresponds to the api.py /knowledge_base/delete_doc endpoint.
Corresponds to the api.py /knowledge_base/delete_docs endpoint.
'''
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = {
"knowledge_base_name": knowledge_base_name,
"doc_name": doc_name,
"file_names": file_names,
"delete_content": delete_content,
"not_refresh_vs_cache": not_refresh_vs_cache,
}
if no_remote_api:
from server.knowledge_base.kb_doc_api import delete_doc
response = run_async(delete_doc(**data))
from server.knowledge_base.kb_doc_api import delete_docs
response = delete_docs(**data)
return response.dict()
else:
response = self.post(
"/knowledge_base/delete_doc",
"/knowledge_base/delete_docs",
json=data,
)
return self._check_httpx_json_response(response)
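A minimal sketch of the batched delete (knowledge base and file names are placeholders):

```python
from webui_pages.utils import ApiRequest

api = ApiRequest()
ret = api.delete_kb_docs(
    "samples",
    file_names=["a.md", "b.pdf"],
    delete_content=True,   # also remove the source files, not only the vectors
)
print(ret)                 # dict with "code" / "msg"
```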
def update_kb_doc(
def update_kb_docs(
self,
knowledge_base_name: str,
file_name: str,
file_names: List[str],
override_custom_docs: bool = False,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
docs: Dict = {},
not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
):
'''
Corresponds to the api.py /knowledge_base/update_doc endpoint.
Corresponds to the api.py /knowledge_base/update_docs endpoint.
'''
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = {
"knowledge_base_name": knowledge_base_name,
"file_names": file_names,
"override_custom_docs": override_custom_docs,
"chunk_size": chunk_size,
"chunk_overlap": chunk_overlap,
"zh_title_enhance": zh_title_enhance,
"docs": docs,
"not_refresh_vs_cache": not_refresh_vs_cache,
}
if no_remote_api:
from server.knowledge_base.kb_doc_api import update_doc
response = run_async(update_doc(knowledge_base_name, file_name))
from server.knowledge_base.kb_doc_api import update_docs
response = update_docs(**data)
return response.dict()
else:
if isinstance(data["docs"], dict):
data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
response = self.post(
"/knowledge_base/update_doc",
json={
"knowledge_base_name": knowledge_base_name,
"file_name": file_name,
"not_refresh_vs_cache": not_refresh_vs_cache,
},
"/knowledge_base/update_docs",
json=data,
)
return self._check_httpx_json_response(response)
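A minimal sketch of `update_kb_docs`, re-splitting selected files with new chunking parameters (all values are placeholders):

```python
from webui_pages.utils import ApiRequest

api = ApiRequest()
ret = api.update_kb_docs(
    "samples",
    file_names=["a.md", "b.pdf"],
    chunk_size=500,
    chunk_overlap=100,
    zh_title_enhance=True,
)
print(ret)
```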
@ -625,6 +716,9 @@ class ApiRequest:
allow_empty_kb: bool = True,
vs_type: str = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
no_remote_api: bool = None,
):
'''
@ -638,11 +732,14 @@ class ApiRequest:
"allow_empty_kb": allow_empty_kb,
"vs_type": vs_type,
"embed_model": embed_model,
"chunk_size": chunk_size,
"chunk_overlap": chunk_overlap,
"zh_title_enhance": zh_title_enhance,
}
if no_remote_api:
from server.knowledge_base.kb_doc_api import recreate_vector_store
response = run_async(recreate_vector_store(**data))
response = recreate_vector_store(**data)
return self._fastapi_stream2generator(response, as_json=True)
else:
response = self.post(
@ -653,14 +750,30 @@ class ApiRequest:
)
return self._httpx_stream2generator(response, as_json=True)
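The rebuild call returns a generator of JSON payloads, so callers can report progress while the knowledge base is re-vectorized. A minimal sketch (knowledge base name is a placeholder):

```python
from webui_pages.utils import ApiRequest, check_error_msg

api = ApiRequest()
for d in api.recreate_vector_store("samples",
                                   chunk_size=250,
                                   chunk_overlap=50,
                                   zh_title_enhance=False):
    if msg := check_error_msg(d):
        print("error:", msg)
    else:
        print(d)   # per-file progress payload
```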
def list_running_models(self, controller_address: str = None):
# LLM model management operations
def list_running_models(
self,
controller_address: str = None,
no_remote_api: bool = None,
):
'''
Get the list of models currently running in FastChat.
'''
r = self.post(
"/llm_model/list_models",
)
return r.json().get("data", [])
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = {
"controller_address": controller_address,
}
if no_remote_api:
from server.llm_api import list_llm_models
return list_llm_models(**data).data
else:
r = self.post(
"/llm_model/list_models",
json=data,
)
return r.json().get("data", [])
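A minimal sketch of the two call paths now available for `list_running_models`; both assume a reachable fastchat controller:

```python
from webui_pages.utils import ApiRequest

api = ApiRequest()
print(api.list_running_models())                    # over HTTP via /llm_model/list_models
print(api.list_running_models(no_remote_api=True))  # in-process via server.llm_api
```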
def list_config_models(self):
'''
@ -672,30 +785,43 @@ class ApiRequest:
self,
model_name: str,
controller_address: str = None,
no_remote_api: bool = None,
):
'''
Stop an LLM model.
Note: given how FastChat is implemented, this actually shuts down the model_worker that hosts the model.
'''
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = {
"model_name": model_name,
"controller_address": controller_address,
}
r = self.post(
"/llm_model/stop",
json=data,
)
return r.json()
if no_remote_api:
from server.llm_api import stop_llm_model
return stop_llm_model(**data).dict()
else:
r = self.post(
"/llm_model/stop",
json=data,
)
return r.json()
def change_llm_model(
self,
model_name: str,
new_model_name: str,
controller_address: str = None,
no_remote_api: bool = None,
):
'''
Ask the fastchat controller to switch to another LLM model.
'''
if no_remote_api is None:
no_remote_api = self.no_remote_api
if not model_name or not new_model_name:
return
@ -724,12 +850,17 @@ class ApiRequest:
"new_model_name": new_model_name,
"controller_address": controller_address,
}
r = self.post(
"/llm_model/change",
json=data,
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
)
return r.json()
if no_remote_api:
from server.llm_api import change_llm_model
return change_llm_model(**data).dict()
else:
r = self.post(
"/llm_model/change",
json=data,
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
)
return r.json()
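A minimal sketch of switching the served model through `change_llm_model`; the model names are examples and must be configured in `llm_model_dict`:

```python
from webui_pages.utils import ApiRequest

api = ApiRequest()
ret = api.change_llm_model("chatglm2-6b", "Qwen-7B-Chat")
print(ret)   # controller response as a dict
```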
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str: