diff --git a/.gitignore b/.gitignore index f5bd3e4b..eac0805d 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ __pycache__/ /configs/*.py .vscode/ .pytest_cache/ +*.bak diff --git a/README.md b/README.md index 453a4913..93784df7 100644 --- a/README.md +++ b/README.md @@ -14,8 +14,7 @@ * [2. 下载模型至本地](README.md#2.-下载模型至本地) * [3. 设置配置项](README.md#3.-设置配置项) * [4. 知识库初始化与迁移](README.md#4.-知识库初始化与迁移) - * [5. 一键启动API服务或WebUI服务](README.md#6.-一键启动) - * [6. 分步启动 API 服务或 Web UI](README.md#5.-启动-API-服务或-Web-UI) + * [5. 一键启动 API 服务或 Web UI](README.md#5.-一键启动-API-服务或-Web-UI) * [常见问题](README.md#常见问题) * [路线图](README.md#路线图) * [项目交流群](README.md#项目交流群) @@ -78,7 +77,9 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch ### LLM 模型支持 -本项目最新版本中基于 [FastChat](https://github.com/lm-sys/FastChat) 进行本地 LLM 模型接入,支持模型如下: +本项目最新版本中支持接入**本地模型**与**在线 LLM API**。 + +本地 LLM 模型接入基于 [FastChat](https://github.com/lm-sys/FastChat) 实现,支持模型如下: - [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) - Vicuna, Alpaca, LLaMA, Koala @@ -109,12 +110,27 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch - [internlm/internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b) - [Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat) - [HuggingFaceH4/starchat-beta](https://huggingface.co/HuggingFaceH4/starchat-beta) +- [FlagAlpha/Llama2-Chinese-13b-Chat](https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat) and others +- [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B) +- [all models of OpenOrca](https://huggingface.co/Open-Orca) +- [Spicyboros](https://huggingface.co/jondurbin/spicyboros-7b-2.2?not-for-all-audiences=true) + [airoboros 2.2](https://huggingface.co/jondurbin/airoboros-l2-13b-2.2) +- [VMware's OpenLLaMa OpenInstruct](https://huggingface.co/VMware/open-llama-7b-open-instruct) - 任何 [EleutherAI](https://huggingface.co/EleutherAI) 的 pythia 模型,如 [pythia-6.9b](https://huggingface.co/EleutherAI/pythia-6.9b) - 在以上模型基础上训练的任何 [Peft](https://github.com/huggingface/peft) 适配器。为了激活,模型路径中必须有 `peft` 。注意:如果加载多个peft模型,你可以通过在任何模型工作器中设置环境变量 `PEFT_SHARE_BASE_WEIGHTS=true` 来使它们共享基础模型的权重。 以上模型支持列表可能随 [FastChat](https://github.com/lm-sys/FastChat) 更新而持续更新,可参考 [FastChat 已支持模型列表](https://github.com/lm-sys/FastChat/blob/main/docs/model_support.md)。 -除本地模型外,本项目也支持直接接入 OpenAI API,具体设置可参考 `configs/model_configs.py.example` 中的 `llm_model_dict` 的 `openai-chatgpt-3.5` 配置信息。 + +除本地模型外,本项目也支持直接接入 OpenAI API、智谱AI等在线模型,具体设置可参考 `configs/model_configs.py.example` 中的 `llm_model_dict` 的配置信息。 + +在线 LLM 模型目前已支持: +- [ChatGPT](https://api.openai.com) +- [智谱AI](http://open.bigmodel.cn) +- [MiniMax](https://api.minimax.chat) +- [讯飞星火](https://xinghuo.xfyun.cn) +- [百度千帆](https://cloud.baidu.com/product/wenxinworkshop?track=dingbutonglan) + +项目中默认使用的 LLM 类型为 `THUDM/chatglm2-6b`,如需使用其他 LLM 类型,请在 [configs/model_config.py] 中对 `llm_model_dict` 和 `LLM_MODEL` 进行修改。 ### Embedding 模型支持 @@ -139,8 +155,34 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch - [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh) - [OpenAI/text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings) +项目中默认使用的 Embedding 类型为 `moka-ai/m3e-base`,如需使用其他 Embedding 类型,请在 [configs/model_config.py] 中对 `embedding_model_dict` 和 `EMBEDDING_MODEL` 进行修改。 + --- +### Text Splitter 个性化支持 + +本项目支持调用 [Langchain](https://api.python.langchain.com/en/latest/api_reference.html#module-langchain.text_splitter) 的 Text Splitter 分词器以及基于此改进的自定义分词器,已支持的 Text Splitter 类型如下: + +- CharacterTextSplitter +- LatexTextSplitter +- MarkdownHeaderTextSplitter +- MarkdownTextSplitter +- NLTKTextSplitter +- PythonCodeTextSplitter +- RecursiveCharacterTextSplitter +- SentenceTransformersTokenTextSplitter +- SpacyTextSplitter + +已经支持的定制分词器如下: + +- [AliTextSplitter](text_splitter/ali_text_splitter.py) +- [ChineseRecursiveTextSplitter](text_splitter/chinese_recursive_text_splitter.py) +- [ChineseTextSplitter](text_splitter/chinese_text_splitter.py) + +项目中默认使用的 Text Splitter 类型为 `ChineseRecursiveTextSplitter`,如需使用其他 Text Splitter 类型,请在 [configs/model_config.py] 中对 `text_splitter_dict` 和 `TEXT_SPLITTER` 进行修改。 + +关于如何使用自定义分词器和贡献自己的分词器,可以参考[Text Splitter 贡献说明](docs/splitter.md)。 + ## Docker 部署 🐳 Docker 镜像地址: `registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.3)` @@ -212,6 +254,17 @@ embedding_model_dict = { } ``` +- 请确认本地分词器路径是否已经填写,如: + +```python +text_splitter_dict = { + "ChineseRecursiveTextSplitter": { + "source": "huggingface", ## 选择tiktoken则使用openai的方法,不填写则默认为字符长度切割方法。 + "tokenizer_name_or_path": "", ## 空格不填则默认使用大模型的分词器。 + } +} +``` + 如果你选择使用OpenAI的Embedding模型,请将模型的 ``key``写入 `embedding_model_dict`中。使用该模型,你需要能够访问OpenAI官的API,或设置代理。 ### 4. 知识库初始化与迁移 @@ -229,7 +282,7 @@ embedding_model_dict = { $ python init_database.py --recreate-vs ``` -### 5. 一键启动API 服务或 Web UI +### 5. 一键启动 API 服务或 Web UI #### 5.1 启动命令 @@ -307,139 +360,12 @@ CUDA_VISIBLE_DEVICES=0,1 python startup.py -a 2. webui启动界面示例: - Web UI 对话界面: - ![img](img/webui_0813_0.png) -- Web UI 知识库管理页面: - ![](img/webui_0813_1.png) -### 6 分步启动 API 服务或 Web UI +![img](img/webui_0915_0.png) -注意:如使用了一键启动方式,可忽略本节。 - -#### 6.1 启动 LLM 服务 - -如需使用开源模型进行本地部署,需首先启动 LLM 服务,启动方式分为三种: - -- [基于多进程脚本 llm_api.py 启动 LLM 服务](README.md#5.1.1-基于多进程脚本-llm_api.py-启动-LLM-服务) -- [基于命令行脚本 llm_api_stale.py 启动 LLM 服务](README.md#5.1.2-基于命令行脚本-llm_api_stale.py-启动-LLM-服务) -- [PEFT 加载](README.md#5.1.3-PEFT-加载) - -三种方式只需选择一个即可,具体操作方式详见 5.1.1 - 5.1.3。 - -如果启动在线的API服务(如 OPENAI 的 API 接口),则无需启动 LLM 服务,即 5.1 小节的任何命令均无需启动。 - -##### 6.1.1 基于多进程脚本 llm_api.py 启动 LLM 服务 - -在项目根目录下,执行 [server/llm_api.py](server/llm_api.py) 脚本启动 **LLM 模型**服务: - -```shell -$ python server/llm_api.py -``` - -项目支持多卡加载,需在 llm_api.py 中的 create_model_worker_app 函数中,修改如下三个参数: - -```python -gpus=None, -num_gpus=1, -max_gpu_memory="20GiB" -``` - -其中,`gpus` 控制使用的显卡的ID,如果 "0,1"; - -`num_gpus` 控制使用的卡数; - -`max_gpu_memory` 控制每个卡使用的显存容量。 - -##### 6.1.2 基于命令行脚本 llm_api_stale.py 启动 LLM 服务 - -⚠️ **注意:** - -**1.llm_api_stale.py脚本原生仅适用于linux,mac设备需要安装对应的linux命令,win平台请使用wsl;** - -**2.加载非默认模型需要用命令行参数--model-path-address指定模型,不会读取model_config.py配置;** - -在项目根目录下,执行 [server/llm_api_stale.py](server/llm_api_stale.py) 脚本启动 **LLM 模型**服务: - -```shell -$ python server/llm_api_stale.py -``` - -该方式支持启动多个worker,示例启动方式: - -```shell -$ python server/llm_api_stale.py --model-path-address model1@host1@port1 model2@host2@port2 -``` - -如果出现server端口占用情况,需手动指定server端口,并同步修改model_config.py下对应模型的base_api_url为指定端口: - -```shell -$ python server/llm_api_stale.py --server-port 8887 -``` - -如果要启动多卡加载,示例命令如下: - -```shell -$ python server/llm_api_stale.py --gpus 0,1 --num-gpus 2 --max-gpu-memory 10GiB -``` - -注:以如上方式启动LLM服务会以nohup命令在后台运行 FastChat 服务,如需停止服务,可以运行如下命令: - -```shell -$ python server/llm_api_shutdown.py --serve all -``` - -亦可单独停止一个 FastChat 服务模块,可选 [`all`, `controller`, `model_worker`, `openai_api_server`] - -##### 6.1.3 PEFT 加载(包括lora,p-tuning,prefix tuning, prompt tuning,ia3等) - -本项目基于 FastChat 加载 LLM 服务,故需以 FastChat 加载 PEFT 路径,即保证路径名称里必须有 peft 这个词,配置文件的名字为 adapter_config.json,peft 路径下包含 model.bin 格式的 PEFT 权重。 -详细步骤参考[加载lora微调后模型失效](https://github.com/chatchat-space/Langchain-Chatchat/issues/1130#issuecomment-1685291822) - -![image](https://github.com/chatchat-space/Langchain-Chatchat/assets/22924096/4e056c1c-5c4b-4865-a1af-859cd58a625d) - -#### 6.2 启动 API 服务 - -本地部署情况下,按照 [5.1 节](README.md#5.1-启动-LLM-服务)**启动 LLM 服务后**,再执行 [server/api.py](server/api.py) 脚本启动 **API** 服务; - -在线调用API服务的情况下,直接执执行 [server/api.py](server/api.py) 脚本启动 **API** 服务; - -调用命令示例: - -```shell -$ python server/api.py -``` - -启动 API 服务后,可访问 `localhost:7861` 或 `{API 所在服务器 IP}:7861` FastAPI 自动生成的 docs 进行接口查看与测试。 - -- FastAPI docs 界面 - - ![](img/fastapi_docs_020_0.png) - -#### 6.3 启动 Web UI 服务 - -按照 [5.2 节](README.md#5.2-启动-API-服务)**启动 API 服务后**,执行 [webui.py](webui.py) 启动 **Web UI** 服务(默认使用端口 `8501`) - -```shell -$ streamlit run webui.py -``` - -使用 Langchain-Chatchat 主题色启动 **Web UI** 服务(默认使用端口 `8501`) - -```shell -$ streamlit run webui.py --theme.base "light" --theme.primaryColor "#165dff" --theme.secondaryBackgroundColor "#f5f5f5" --theme.textColor "#000000" -``` - -或使用以下命令指定启动 **Web UI** 服务并指定端口号 - -```shell -$ streamlit run webui.py --server.port 666 -``` - -- Web UI 对话界面: - - ![](img/webui_0813_0.png) - Web UI 知识库管理页面: - ![](img/webui_0813_1.png) +![](img/webui_0915_1.png) --- diff --git a/README_en.md b/README_en.md new file mode 100644 index 00000000..3e46f3eb --- /dev/null +++ b/README_en.md @@ -0,0 +1,363 @@ +![](img/logo-long-chatchat-trans-v2.png) + +**LangChain-Chatchat** (former Langchain-ChatGLM): A LLM application aims to implement knowledge- and search engineer- based QA based on Langchain and open-source or remote LLM api. + +## Content + +* Introduction +* Change Log +* Docker Deployment +* Deployment + + * Enviroment Preresiquisite + * Preparing Depolyment Enviroment + * Downloading model to local disk(for offline deployment only) + * Setting Configuration + * Knowledge Base Migration + * Luanching API Service or WebUI with One Command + * Luanching API Service or WebUI step-by-step +* FAQ +* Roadmap +* Wechat Group + +--- + +## Introduction + +🤖️ A Q&A application based on local knowledge base implemented using the idea of [langchain](https://github.com/hwchase17/langchain). The goal is to build a KBQA(Knowledge based Q&A) solution that is friendly to Chinese scenarios and open source models and can run both offline and online. + +💡 Inspried by [document.ai](https://github.com/GanymedeNil/document.ai) and [ChatGLM-6B Pull Request](https://github.com/THUDM/ChatGLM-6B/pull/216) , we build a local knowledge base question answering application that can be implemented using an open source model or remote LLM api throughout the process. In the latest version of this project, [FastChat](https://github.com/lm-sys/FastChat) is used to access Vicuna, Alpaca, LLaMA, Koala, RWKV and many other models. Relying on [langchain](https:// github.com/langchain-ai/langchain) , this project supports calling services through the API provided based on [FastAPI](https://github.com/tiangolo/fastapi), or using the WebUI based on [Streamlit](https://github.com /streamlit/streamlit) . + +✅ Relying on the open source LLM and Embedding models, this project can realize full-process **offline private deployment**. At the same time, this project also supports the call of OpenAI GPT API- and Zhipu API, and will continue to expand the access to various models and remote APIs in the future. + +⛓️ The implementation principle of this project is shown in the graph below. The main process includes: loading files -> reading text -> text segmentation -> text vectorization -> question vectorization -> matching the `top-k` most similar to the question vector in the text vector -> The matched text is added to `prompt `as context and question -> submitted to `LLM` to generate an answer. + +📺[video introdution](https://www.bilibili.com/video/BV13M4y1e7cN/?share_source=copy_web&vd_source=e6c5aafe684f30fbe41925d61ca6d514) + +![实现原理图](img/langchain+chatglm.png) + +The main process analysis from the aspect of document process: + +![实现原理图2](img/langchain+chatglm2.png) + +🚩 The training or fined-tuning are not involved in the project, but still, one always can improve performance by do these. + +🌐 [AutoDL image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.0) is supported, and in v7 the codes are update to v0.2.3. + +🐳 [Docker image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.0) + +💻 Run Docker with one command: + +```shell +docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.0 +``` + +--- + +## Change Log + +plese refer to [version change log](https://github.com/imClumsyPanda/langchain-ChatGLM/releases) + +### Current Features + +* **Consistent LLM Service based on FastChat**. The project use [FastChat](https://github.com/lm-sys/FastChat) to provide the API service of the open source LLM models and access it in the form of OpenAI API interface to improve the loading effect of the LLM model; +* **Chain and Agent based on Langchian**. Use the existing Chain implementation in [langchain](https://github.com/langchain-ai/langchain) to facilitate subsequent access to different types of Chain, and will test Agent access; +* **Full fuction API service based on FastAPI**. All interfaces can be tested in the docs automatically generated by [FastAPI](https://github.com/tiangolo/fastapi), and all dialogue interfaces support streaming or non-streaming output through parameters. ; +* **WebUI service based on Streamlit**. With [Streamlit](https://github.com/streamlit/streamlit), you can choose whether to start WebUI based on API services, add session management, customize session themes and switch, and will support different display of content forms of output in the future; +* **Abundant open source LLM and Embedding models**. The default LLM model in the project is changed to [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b), and the default Embedding model is changed to [moka-ai/m3e-base](https:// huggingface.co/moka-ai/m3e-base), the file loading method and the paragraph division method have also been adjusted. In the future, context expansion will be re-implemented and optional settings will be added; +* **Multiply vector libraries**. The project has expanded support for different types of vector libraries. Including [FAISS](https://github.com/facebookresearch/faiss), [Milvus](https://github.com/milvus -io/milvus), and [PGVector](https://github.com/pgvector/pgvector); +* **Varied Search engines**. We provide two search engines now: Bing and DuckDuckGo. DuckDuckGo search does not require configuring an API Key and can be used directly in environments with access to foreign services. + +## Supported Models + +The default LLM model in the project is changed to [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b), and the default Embedding model is changed to [moka-ai/m3e-base](https:// huggingface.co/moka-ai/m3e-base). + +### Supported LLM models + +The project use [FastChat](https://github.com/lm-sys/FastChat) to provide the API service of the open source LLM models, supported models include: + +- [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) +- Vicuna, Alpaca, LLaMA, Koala +- [BlinkDL/RWKV-4-Raven](https://huggingface.co/BlinkDL/rwkv-4-raven) +- [camel-ai/CAMEL-13B-Combined-Data](https://huggingface.co/camel-ai/CAMEL-13B-Combined-Data) +- [databricks/dolly-v2-12b](https://huggingface.co/databricks/dolly-v2-12b) +- [FreedomIntelligence/phoenix-inst-chat-7b](https://huggingface.co/FreedomIntelligence/phoenix-inst-chat-7b) +- [h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b](https://huggingface.co/h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b) +- [lcw99/polyglot-ko-12.8b-chang-instruct-chat](https://huggingface.co/lcw99/polyglot-ko-12.8b-chang-instruct-chat) +- [lmsys/fastchat-t5-3b-v1.0](https://huggingface.co/lmsys/fastchat-t5) +- [mosaicml/mpt-7b-chat](https://huggingface.co/mosaicml/mpt-7b-chat) +- [Neutralzz/BiLLa-7B-SFT](https://huggingface.co/Neutralzz/BiLLa-7B-SFT) +- [nomic-ai/gpt4all-13b-snoozy](https://huggingface.co/nomic-ai/gpt4all-13b-snoozy) +- [NousResearch/Nous-Hermes-13b](https://huggingface.co/NousResearch/Nous-Hermes-13b) +- [openaccess-ai-collective/manticore-13b-chat-pyg](https://huggingface.co/openaccess-ai-collective/manticore-13b-chat-pyg) +- [OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5](https://huggingface.co/OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5) +- [project-baize/baize-v2-7b](https://huggingface.co/project-baize/baize-v2-7b) +- [Salesforce/codet5p-6b](https://huggingface.co/Salesforce/codet5p-6b) +- [StabilityAI/stablelm-tuned-alpha-7b](https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b) +- [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b) +- [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) +- [tiiuae/falcon-40b](https://huggingface.co/tiiuae/falcon-40b) +- [timdettmers/guanaco-33b-merged](https://huggingface.co/timdettmers/guanaco-33b-merged) +- [togethercomputer/RedPajama-INCITE-7B-Chat](https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Chat) +- [WizardLM/WizardLM-13B-V1.0](https://huggingface.co/WizardLM/WizardLM-13B-V1.0) +- [WizardLM/WizardCoder-15B-V1.0](https://huggingface.co/WizardLM/WizardCoder-15B-V1.0) +- [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B) +- [internlm/internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b) +- [Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat) +- [HuggingFaceH4/starchat-beta](https://huggingface.co/HuggingFaceH4/starchat-beta) +- [FlagAlpha/Llama2-Chinese-13b-Chat](https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat) and other models of FlagAlpha +- [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B) +- [all models of OpenOrca](https://huggingface.co/Open-Orca) +- [Spicyboros](https://huggingface.co/jondurbin/spicyboros-7b-2.2?not-for-all-audiences=true) + [airoboros 2.2](https://huggingface.co/jondurbin/airoboros-l2-13b-2.2) +- [VMware's OpenLLaMa OpenInstruct](https://huggingface.co/VMware/open-llama-7b-open-instruct) + +* Any [EleutherAI](https://huggingface.co/EleutherAI) pythia model such as [pythia-6.9b](https://huggingface.co/EleutherAI/pythia-6.9b)(任何 [EleutherAI](https://huggingface.co/EleutherAI) 的 pythia 模型,如 [pythia-6.9b](https://huggingface.co/EleutherAI/pythia-6.9b)) +* Any [Peft](https://github.com/huggingface/peft) adapter trained on top of a model above. To activate, must have `peft` in the model path. Note: If loading multiple peft models, you can have them share the base model weights by setting the environment variable `PEFT_SHARE_BASE_WEIGHTS=true` in any model worker. + +Please refer to `llm_model_dict` in `configs.model_configs.py.example` to invoke OpenAI API. + +### Supported Embedding models + +Following models are tested by developers with Embedding class of [HuggingFace](https://huggingface.co/models?pipeline_tag=sentence-similarity): + +- [moka-ai/m3e-small](https://huggingface.co/moka-ai/m3e-small) +- [moka-ai/m3e-base](https://huggingface.co/moka-ai/m3e-base) +- [moka-ai/m3e-large](https://huggingface.co/moka-ai/m3e-large) +- [BAAI/bge-small-zh](https://huggingface.co/BAAI/bge-small-zh) +- [BAAI/bge-base-zh](https://huggingface.co/BAAI/bge-base-zh) +- [BAAI/bge-large-zh](https://huggingface.co/BAAI/bge-large-zh) +- [BAAI/bge-large-zh-noinstruct](https://huggingface.co/BAAI/bge-large-zh-noinstruct) +- [shibing624/text2vec-base-chinese-sentence](https://huggingface.co/shibing624/text2vec-base-chinese-sentence) +- [shibing624/text2vec-base-chinese-paraphrase](https://huggingface.co/shibing624/text2vec-base-chinese-paraphrase) +- [shibing624/text2vec-base-multilingual](https://huggingface.co/shibing624/text2vec-base-multilingual) +- [shibing624/text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese) +- [shibing624/text2vec-bge-large-chinese](https://huggingface.co/shibing624/text2vec-bge-large-chinese) +- [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese) +- [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh) +- [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh) +- [OpenAI/text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings) + +--- + +## Docker image + +🐳 Docker image path: `registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.0)` + +```shell +docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.0 +``` + +- The image size of this version is `33.9GB`, using `v0.2.0`, with `nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04` as the base image +- This version has a built-in `embedding` model: `m3e-large`, built-in `chatglm2-6b-32k` +- This version is designed to facilitate one-click deployment. Please make sure you have installed the NVIDIA driver on your Linux distribution. +- Please note that you do not need to install the CUDA toolkit on the host system, but you need to install the `NVIDIA Driver` and the `NVIDIA Container Toolkit`, please refer to the [Installation Guide](https://docs.nvidia.com/datacenter/cloud -native/container-toolkit/latest/install-guide.html) +- It takes a certain amount of time to pull and start for the first time. When starting for the first time, please refer to the figure below to use `docker logs -f ` to view the log. +- If the startup process is stuck in the `Waiting..` step, it is recommended to use `docker exec -it bash` to enter the `/logs/` directory to view the corresponding stage logs + +--- + +## Deployment + +### Enviroment Preresiquisite + +The project is tested under Python3.8-python 3.10, CUDA 11.0-CUDA11.7, Windows, macOS of ARM architecture, and Linux platform. + +### 1. Preparing Depolyment Enviroment + +Please refer to [install.md](docs/INSTALL.md) + +### 2. Downloading model to local disk + +**For offline deployment only!** + +If you want to run this project in a local or offline environment, you need to first download the models required for the project to your local computer. Usually the open source LLM and Embedding models can be downloaded from [HuggingFace](https://huggingface.co/models). + +Take the LLM model [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) and Embedding model [moka-ai/m3e-base](https://huggingface. co/moka-ai/m3e-base) for example: + +To download the model, you need to [install Git LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage), and then run: + +```Shell +$ git clone https://huggingface.co/THUDM/chatglm2-6b + +$ git clone https://huggingface.co/moka-ai/m3e-base +``` + +### 3. Setting Configuration + +Copy the model-related parameter configuration template file [configs/model_config.py.example](configs/model_config.py.example) and save it in the `./configs` path under the project path, and rename it to `model_config.py`. + +Copy the service-related parameter configuration template file [configs/server_config.py.example](configs/server_config.py.example) to save in the `./configs` path under the project path, and rename it to `server_config.py`. + +Before starting to execute Web UI or command line interaction, please check whether each model parameter in `configs/model_config.py` and `configs/server_config.py` meets the requirements. + +* Please confirm that the path to local LLM model and embedding model have been written in `llm_dict` of `configs/model_config.py`, here is an example: +* If you choose to use OpenAI's Embedding model, please write the model's ``key`` into `embedding_model_dict`. To use this model, you need to be able to access the OpenAI official API, or set up a proxy. + +```python +llm_model_dict={ + "chatglm2-6b": { + "local_model_path": "/Users/xxx/Downloads/chatglm2-6b", + "api_base_url": "http://localhost:8888/v1", # "name"修改为 FastChat 服务中的"api_base_url" + "api_key": "EMPTY" + }, + } +``` + +```python +embedding_model_dict = { + "m3e-base": "/Users/xxx/Downloads/m3e-base", + } +``` + +### 4. Knowledge Base Migration + +The knowledge base information is stored in the database, please initialize the database before running the project (we strongly recommend one back up the knowledge files before performing operations). + +- If you migrate from `0.1.x`, for the established knowledge base, please confirm that the vector library type and Embedding model of the knowledge base are consistent with the default settings in `configs/model_config.py`, if there is no change, simply add the existing repository information to the database with the following command: + + ```shell + $ python init_database.py + ``` +- If you are a beginner of the project whose knowledge base has not been established, or the knowledge base type and embedding model in the configuration file have changed, or the previous vector library did not enable `normalize_L2`, you need the following command to initialize or rebuild the knowledge base: + + ```shell + $ python init_database.py --recreate-vs + ``` + +### 5. Luanching API Service or WebUI with One Command + +#### 5.1 Command + +The script is `startuppy`, you can luanch all fastchat related, API,WebUI service with is, here is an example: + +```shell +$ python startup.py -a +``` + +optional args including: `-a(or --all-webui), --all-api, --llm-api, -c(or --controller),--openai-api, -m(or --model-worker), --api, --webui`, where: + +* `--all-webui` means to launch all related services of WEBUI +* `--all-api` means to launch all related services of API +* `--llm-api` means to launch all related services of FastChat +* `--openai-api` means to launch controller and openai-api-server of FastChat only +* `model-worker` means to launch model worker of FastChat only +* any other optional arg is to launch one particular function only + +#### 5.2 Launch none-default model + +If you want to specify a none-default model, use `--model-name` arg, here is a example: + +```shell +$ python startup.py --all-webui --model-name Qwen-7B-Chat +``` + +#### 5.3 Load model with multi-gpus + +If you want to load model with multi-gpus, then the following three parameters in `startup.create_model_worker_app` should be changed: + +```python +gpus=None, +num_gpus=1, +max_gpu_memory="20GiB" +``` + +where: + +* `gpus` is about specifying the gpus' ID, such as '0,1'; +* `num_gpus` is about specifying the number of gpus to be used under `gpus`; +* `max_gpu_memory` is about specifying the gpu memory of every gpu. + +note: + +* These parameters now can be specified by `server_config.FSCHST_MODEL_WORKERD`. +* In some extreme senses, `gpus` doesn't work, then one should specify the used gpus with environment variable `CUDA_VISIBLE_DEVICES`, here is an example: + +```shell +CUDA_VISIBLE_DEVICES=0,1 python startup.py -a +``` + +#### 5.4 Load PEFT + +Including lora,p-tuning,prefix tuning, prompt tuning,ia3 + +This project loads the LLM service based on FastChat, so one must load the PEFT in a FastChat way, that is, ensure that the word `peft` must be in the path name, the name of the configuration file must be `adapter_config.json`, and the path contains PEFT weights in `.bin` format. The peft path is specified in `args.model_names` of the `create_model_worker_app` function in `startup.py`, and enable the environment variable `PEFT_SHARE_BASE_WEIGHTS=true` parameter. + +If the above method fails, you need to start standard fastchat service step by step. Step-by-step procedure could be found Section 6. For further steps, please refer to [Model invalid after loading lora fine-tuning](https://github. com/chatchat-space/Langchain-Chatchat/issues/1130#issuecomment-1685291822). + +#### **5.5 Some Notes** + +1. **The `startup.py` uses multi-process mode to start the services of each module, which may cause printing order problems. Please wait for all services to be initiated before calling, and call the service according to the default or specified port (default LLM API service port: `127.0.0.1:8888 `, default API service port:`127.0.0.1:7861 `, default WebUI service port: `127.0.0.1: 8501`)** +2. **The startup time of the service differs across devices, usually it takes 3-10 minutes. If it does not start for a long time, please go to the `./logs` directory to monitor the logs and locate the problem.** +3. **Using ctrl+C to exit on Linux may cause orphan processes due to the multi-process mechanism of Linux. You can exit through `shutdown_all.sh`** + +#### 5.6 Interface Examples + +The API, chat interface of WebUI, and knowledge management interface of WebUI are list below respectively. + +1. FastAPI docs + +![](img/fastapi_docs_020_0.png) + +2. Chat Interface of WebUI + +- Dialogue interface of WebUI + +![img](img/webui_0915_0.png) + +- Knowledge management interface of WebUI + +![img](img/webui_0915_1.png) + +### 6 Luanching API Service or WebUI step-by-step + +**The developers will depreciate step-by-step procudure in the future one or two version, feel free to ignore this part.** + +## FAQ + +Please refer to [FAQ](docs/FAQ.md) + +--- + +## Roadmap + +- [X] Langchain applications + + - [X] Load local documents + - [X] Unstructed documents + - [X] .md + - [X] .txt + - [X] .docx + - [ ] Structed documents + - [X] .csv + - [ ] .xlsx + - [ ] TextSplliter and Retriever + - [ ] multipy TextSplitter + - [ ] ChineseTextSplitter + - [ ] Recontructed Context Retriever + - [ ] Webpage + - [ ] SQL + - [ ] Knowledge Database + - [X] Search Engines + - [X] Bing + - [X] DuckDuckGo + - [ ] Agent +- [X] LLM Models + + - [X] [FastChat](https://github.com/lm-sys/fastchat) -based LLM Models + - [ ] Mutiply Remote LLM API +- [X] Embedding Models + + - [X] HuggingFace -based Embedding models + - [ ] Mutiply Remote Embedding API +- [X] 基于 FastAPI -based API +- [X] Web UI + + - [X] Streamlit -based Web UI + +--- + +## WeChat Group QR Code + +二维码 + +**WeChat Group** diff --git a/configs/__init__.py b/configs/__init__.py index d47abf14..41169e8b 100644 --- a/configs/__init__.py +++ b/configs/__init__.py @@ -1,4 +1,4 @@ from .model_config import * from .server_config import * -VERSION = "v0.2.4-preview" +VERSION = "v0.2.4" diff --git a/configs/model_config.py.example b/configs/model_config.py.example index 39c51ca2..0aa0cf5b 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -5,6 +5,8 @@ LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(mes logger = logging.getLogger() logger.setLevel(logging.INFO) logging.basicConfig(format=LOG_FORMAT) +# 是否显示详细日志 +log_verbose = False # 在以下字典中修改属性值,以指定本地embedding模型存储位置 @@ -73,19 +75,43 @@ llm_model_dict = { # 比如: "openai_proxy": 'http://127.0.0.1:4780' "gpt-3.5-turbo": { "api_base_url": "https://api.openai.com/v1", - "api_key": os.environ.get("OPENAI_API_KEY"), - "openai_proxy": os.environ.get("OPENAI_PROXY") + "api_key": "", + "openai_proxy": "" }, # 线上模型。当前支持智谱AI。 # 如果没有设置有效的local_model_path,则认为是在线模型API。 # 请在server_config中为每个在线API设置不同的端口 # 具体注册及api key获取请前往 http://open.bigmodel.cn - "chatglm-api": { + "zhipu-api": { "api_base_url": "http://127.0.0.1:8888/v1", - "api_key": os.environ.get("ZHIPUAI_API_KEY"), + "api_key": "", "provider": "ChatGLMWorker", "version": "chatglm_pro", # 可选包括 "chatglm_lite", "chatglm_std", "chatglm_pro" }, + "minimax-api": { + "api_base_url": "http://127.0.0.1:8888/v1", + "group_id": "", + "api_key": "", + "is_pro": False, + "provider": "MiniMaxWorker", + }, + "xinghuo-api": { + "api_base_url": "http://127.0.0.1:8888/v1", + "APPID": "", + "APISecret": "", + "api_key": "", + "is_v2": False, + "provider": "XingHuoWorker", + }, + # 百度千帆 API,申请方式请参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf + "qianfan-api": { + "version": "ernie-bot", # 当前支持 "ernie-bot" 或 "ernie-bot-turbo", 更多的见文档模型支持列表中千帆部分。 + "version_url": "", # 可以不填写version,直接填写在千帆申请模型发布的API地址 + "api_base_url": "http://127.0.0.1:8888/v1", + "api_key": "", + "secret_key": "", + "provider": "ErnieWorker", + } } # LLM 名称 @@ -94,9 +120,51 @@ LLM_MODEL = "chatglm2-6b" # 历史对话轮数 HISTORY_LEN = 3 +# LLM通用对话参数 +TEMPERATURE = 0.7 +# TOP_P = 0.95 # ChatOpenAI暂不支持该参数 + + # LLM 运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。 LLM_DEVICE = "auto" +# TextSplitter + +text_splitter_dict = { + "ChineseRecursiveTextSplitter": { + "source": "", + "tokenizer_name_or_path": "", + }, + "SpacyTextSplitter": { + "source": "huggingface", + "tokenizer_name_or_path": "gpt2", + }, + "RecursiveCharacterTextSplitter": { + "source": "tiktoken", + "tokenizer_name_or_path": "cl100k_base", + }, + + "MarkdownHeaderTextSplitter": { + "headers_to_split_on": + [ + ("#", "head1"), + ("##", "head2"), + ("###", "head3"), + ("####", "head4"), + ] + }, +} + +# TEXT_SPLITTER 名称 +TEXT_SPLITTER = "ChineseRecursiveTextSplitter" + +# 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter) +CHUNK_SIZE = 250 + +# 知识库中相邻文本重合长度(不适用MarkdownHeaderTextSplitter) +OVERLAP_SIZE = 0 + + # 日志存储路径 LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs") if not os.path.exists(LOG_PATH): @@ -104,12 +172,14 @@ if not os.path.exists(LOG_PATH): # 知识库默认存储路径 KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base") - +if not os.path.exists(KB_ROOT_PATH): + os.mkdir(KB_ROOT_PATH) # 数据库默认存储路径。 # 如果使用sqlite,可以直接修改DB_ROOT_PATH;如果使用其它数据库,请直接修改SQLALCHEMY_DATABASE_URI。 DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db") SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}" + # 可选向量库类型及对应配置 kbs_config = { "faiss": { @@ -132,20 +202,14 @@ DEFAULT_VS_TYPE = "faiss" # 缓存向量库数量 CACHED_VS_NUM = 1 -# 知识库中单段文本长度 -CHUNK_SIZE = 250 - -# 知识库中相邻文本重合长度 -OVERLAP_SIZE = 50 - # 知识库匹配向量数量 -VECTOR_SEARCH_TOP_K = 5 +VECTOR_SEARCH_TOP_K = 3 # 知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右 SCORE_THRESHOLD = 1 # 搜索引擎匹配结题数量 -SEARCH_ENGINE_TOP_K = 5 +SEARCH_ENGINE_TOP_K = 3 # nltk 模型存储路径 NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data") diff --git a/configs/server_config.py.example b/configs/server_config.py.example index ad731e37..51f53dc3 100644 --- a/configs/server_config.py.example +++ b/configs/server_config.py.example @@ -59,13 +59,23 @@ FSCHAT_MODEL_WORKERS = { # "limit_worker_concurrency": 5, # "stream_interval": 2, # "no_register": False, + # "embed_in_truncate": False, }, "baichuan-7b": { # 使用default中的IP和端口 "device": "cpu", }, - "chatglm-api": { # 请为每个在线API设置不同的端口 + "zhipu-api": { # 请为每个在线API设置不同的端口 "port": 20003, }, + "minimax-api": { # 请为每个在线API设置不同的端口 + "port": 20004, + }, + "xinghuo-api": { # 请为每个在线API设置不同的端口 + "port": 20005, + }, + "qianfan-api": { + "port": 20006, + }, } # fastchat multi model worker server diff --git a/docs/splitter.md b/docs/splitter.md new file mode 100644 index 00000000..5f0e1078 --- /dev/null +++ b/docs/splitter.md @@ -0,0 +1,24 @@ +## 如何自定义分词器 + +### 在哪里写,哪些文件要改 +1. 在```text_splitter```文件夹下新建一个文件,文件名为您的分词器名字,比如`my_splitter.py`,然后在`__init__.py`中导入您的分词器,如下所示: +```python +from .my_splitter import MySplitter +``` + +2. 修改```config/model_config.py```文件,将您的分词器名字添加到```text_splitter_dict```中,如下所示: +```python +MySplitter: { + "source": "huggingface", ## 选择tiktoken则使用openai的方法 + "tokenizer_name_or_path": "your tokenizer", #如果选择huggingface则使用huggingface的方法,部分tokenizer需要从Huggingface下载 + } +TEXT_SPLITTER = "MySplitter" +``` + +完成上述步骤后,就能使用自己的分词器了。 + +### 如何贡献您的分词器 + +1. 将您的分词器所在的代码文件放在```text_splitter```文件夹下,文件名为您的分词器名字,比如`my_splitter.py`,然后在`__init__.py`中导入您的分词器。 +2. 发起PR,并说明您的分词器面向的场景或者改进之处。我们非常期待您能举例一个具体的应用场景。 +3. 在Readme.md中添加您的分词器的使用方法和支持说明。 diff --git a/document_loaders/mypdfloader.py b/document_loaders/mypdfloader.py index a3153a86..6cb77267 100644 --- a/document_loaders/mypdfloader.py +++ b/document_loaders/mypdfloader.py @@ -1,5 +1,6 @@ from typing import List from langchain.document_loaders.unstructured import UnstructuredFileLoader +import tqdm class RapidOCRPDFLoader(UnstructuredFileLoader): @@ -11,7 +12,14 @@ class RapidOCRPDFLoader(UnstructuredFileLoader): ocr = RapidOCR() doc = fitz.open(filepath) resp = "" - for page in doc: + + b_unit = tqdm.tqdm(total=doc.page_count, desc="RapidOCRPDFLoader context page index: 0") + for i, page in enumerate(doc): + + # 更新描述 + b_unit.set_description("RapidOCRPDFLoader context page index: {}".format(i)) + # 立即显示进度条更新结果 + b_unit.refresh() # TODO: 依据文本与图片顺序调整处理方式 text = page.get_text("") resp += text + "\n" @@ -24,6 +32,9 @@ class RapidOCRPDFLoader(UnstructuredFileLoader): if result: ocr_result = [line[1] for line in result] resp += "\n".join(ocr_result) + + # 更新进度 + b_unit.update(1) return resp text = pdf2text(self.file_path) diff --git a/image/README/1694251762513.png b/image/README/1694251762513.png new file mode 100644 index 00000000..a16a6081 Binary files /dev/null and b/image/README/1694251762513.png differ diff --git a/image/README_en/1694251973694.png b/image/README_en/1694251973694.png new file mode 100644 index 00000000..1a6d909f Binary files /dev/null and b/image/README_en/1694251973694.png differ diff --git a/image/README_en/1694252029167.png b/image/README_en/1694252029167.png new file mode 100644 index 00000000..a16a6081 Binary files /dev/null and b/image/README_en/1694252029167.png differ diff --git a/img/chatchat-qrcode.jpg b/img/chatchat-qrcode.jpg new file mode 100644 index 00000000..a16a6081 Binary files /dev/null and b/img/chatchat-qrcode.jpg differ diff --git a/img/webui_020_0.png b/img/webui_020_0.png deleted file mode 100644 index bd07527b..00000000 Binary files a/img/webui_020_0.png and /dev/null differ diff --git a/img/webui_020_1.png b/img/webui_020_1.png deleted file mode 100644 index 1818c5dd..00000000 Binary files a/img/webui_020_1.png and /dev/null differ diff --git a/img/webui_0813_0.png b/img/webui_0813_0.png deleted file mode 100644 index 52022348..00000000 Binary files a/img/webui_0813_0.png and /dev/null differ diff --git a/img/webui_0813_1.png b/img/webui_0813_1.png deleted file mode 100644 index c0647272..00000000 Binary files a/img/webui_0813_1.png and /dev/null differ diff --git a/img/webui_0915_0.png b/img/webui_0915_0.png new file mode 100644 index 00000000..058d7b17 Binary files /dev/null and b/img/webui_0915_0.png differ diff --git a/img/webui_0915_1.png b/img/webui_0915_1.png new file mode 100644 index 00000000..8df1eca9 Binary files /dev/null and b/img/webui_0915_1.png differ diff --git a/knowledge_base/samples/vector_store/index.faiss b/knowledge_base/samples/vector_store/index.faiss index e20f2cff..2404c993 100644 Binary files a/knowledge_base/samples/vector_store/index.faiss and b/knowledge_base/samples/vector_store/index.faiss differ diff --git a/knowledge_base/samples/vector_store/index.pkl b/knowledge_base/samples/vector_store/index.pkl index 7cbf077e..709f9ee7 100644 Binary files a/knowledge_base/samples/vector_store/index.pkl and b/knowledge_base/samples/vector_store/index.pkl differ diff --git a/requirements.txt b/requirements.txt index 910a9ed6..57147881 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,7 @@ -langchain==0.0.266 +langchain==0.0.287 +fschat[model_worker]==0.2.28 openai -zhipuai sentence_transformers -fschat==0.2.24 transformers>=4.31.0 torch~=2.0.0 fastapi~=0.99.1 @@ -19,6 +18,12 @@ spacy PyMuPDF==1.22.5 rapidocr_onnxruntime>=1.3.2 +requests +pathlib +pytest +scikit-learn +numexpr + # uncomment libs if you want to use corresponding vector store # pymilvus==2.1.3 # requires milvus==2.1.3 # psycopg2 @@ -33,3 +38,5 @@ streamlit-chatbox>=1.1.6 streamlit-aggrid>=0.3.4.post3 httpx~=0.24.1 watchdog +tqdm +websockets diff --git a/requirements_api.txt b/requirements_api.txt index bdecf3c7..c56c07bf 100644 --- a/requirements_api.txt +++ b/requirements_api.txt @@ -1,7 +1,7 @@ -langchain==0.0.266 +langchain==0.0.287 +fschat[model_worker]==0.2.28 openai sentence_transformers -fschat==0.2.24 transformers>=4.31.0 torch~=2.0.0 fastapi~=0.99.1 @@ -9,17 +9,22 @@ nltk~=3.8.1 uvicorn~=0.23.1 starlette~=0.27.0 pydantic~=1.10.11 -unstructured[all-docs] +unstructured[all-docs]>=0.10.4 python-magic-bin; sys_platform == 'win32' SQLAlchemy==2.0.19 faiss-cpu -nltk accelerate spacy PyMuPDF==1.22.5 -rapidocr_onnxruntime>=1.3.1 +rapidocr_onnxruntime>=1.3.2 + +requests +pathlib +pytest +scikit-learn +numexpr # uncomment libs if you want to use corresponding vector store # pymilvus==2.1.3 # requires milvus==2.1.3 # psycopg2 -# pgvector +# pgvector \ No newline at end of file diff --git a/requirements_webui.txt b/requirements_webui.txt index da66c307..8d49ae02 100644 --- a/requirements_webui.txt +++ b/requirements_webui.txt @@ -7,4 +7,5 @@ streamlit-chatbox>=1.1.6 streamlit-aggrid>=0.3.4.post3 httpx~=0.24.1 nltk -watchdog \ No newline at end of file +watchdog +websockets diff --git a/server/api.py b/server/api.py index 37954b7f..357a0678 100644 --- a/server/api.py +++ b/server/api.py @@ -4,22 +4,21 @@ import os sys.path.append(os.path.dirname(os.path.dirname(__file__))) -from configs.model_config import LLM_MODEL, NLTK_DATA_PATH -from configs.server_config import OPEN_CROSS_DOMAIN, HTTPX_DEFAULT_TIMEOUT from configs import VERSION +from configs.model_config import NLTK_DATA_PATH +from configs.server_config import OPEN_CROSS_DOMAIN import argparse import uvicorn -from fastapi import Body from fastapi.middleware.cors import CORSMiddleware from starlette.responses import RedirectResponse from server.chat import (chat, knowledge_base_chat, openai_chat, search_engine_chat) from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb -from server.knowledge_base.kb_doc_api import (list_files, upload_doc, delete_doc, - update_doc, download_doc, recreate_vector_store, +from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs, + update_docs, download_doc, recreate_vector_store, search_docs, DocumentWithScore) -from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address -import httpx +from server.llm_api import list_llm_models, change_llm_model, stop_llm_model +from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline from typing import List nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path @@ -98,23 +97,23 @@ def create_app(): summary="搜索知识库" )(search_docs) - app.post("/knowledge_base/upload_doc", + app.post("/knowledge_base/upload_docs", tags=["Knowledge Base Management"], response_model=BaseResponse, - summary="上传文件到知识库" - )(upload_doc) + summary="上传文件到知识库,并/或进行向量化" + )(upload_docs) - app.post("/knowledge_base/delete_doc", + app.post("/knowledge_base/delete_docs", tags=["Knowledge Base Management"], response_model=BaseResponse, summary="删除知识库内指定文件" - )(delete_doc) + )(delete_docs) - app.post("/knowledge_base/update_doc", + app.post("/knowledge_base/update_docs", tags=["Knowledge Base Management"], response_model=BaseResponse, summary="更新现有文件到知识库" - )(update_doc) + )(update_docs) app.get("/knowledge_base/download_doc", tags=["Knowledge Base Management"], @@ -126,73 +125,20 @@ def create_app(): )(recreate_vector_store) # LLM模型相关接口 - @app.post("/llm_model/list_models", + app.post("/llm_model/list_models", tags=["LLM Model Management"], - summary="列出当前已加载的模型") - def list_models( - controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]) - ) -> BaseResponse: - ''' - 从fastchat controller获取已加载模型列表 - ''' - try: - controller_address = controller_address or fschat_controller_address() - r = httpx.post(controller_address + "/list_models") - return BaseResponse(data=r.json()["models"]) - except Exception as e: - return BaseResponse( - code=500, - data=[], - msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}") + summary="列出当前已加载的模型", + )(list_llm_models) - @app.post("/llm_model/stop", + app.post("/llm_model/stop", tags=["LLM Model Management"], summary="停止指定的LLM模型(Model Worker)", - ) - def stop_llm_model( - model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]), - controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]) - ) -> BaseResponse: - ''' - 向fastchat controller请求停止某个LLM模型。 - 注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。 - ''' - try: - controller_address = controller_address or fschat_controller_address() - r = httpx.post( - controller_address + "/release_worker", - json={"model_name": model_name}, - ) - return r.json() - except Exception as e: - return BaseResponse( - code=500, - msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}") + )(stop_llm_model) - @app.post("/llm_model/change", + app.post("/llm_model/change", tags=["LLM Model Management"], summary="切换指定的LLM模型(Model Worker)", - ) - def change_llm_model( - model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]), - new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]), - controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]) - ): - ''' - 向fastchat controller请求切换LLM模型。 - ''' - try: - controller_address = controller_address or fschat_controller_address() - r = httpx.post( - controller_address + "/release_worker", - json={"model_name": model_name, "new_model_name": new_model_name}, - timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model - ) - return r.json() - except Exception as e: - return BaseResponse( - code=500, - msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}") + )(change_llm_model) return app diff --git a/server/chat/chat.py b/server/chat/chat.py index ba23a5a1..c025c3c2 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -1,6 +1,6 @@ from fastapi import Body from fastapi.responses import StreamingResponse -from configs.model_config import llm_model_dict, LLM_MODEL +from configs.model_config import llm_model_dict, LLM_MODEL, TEMPERATURE from server.chat.utils import wrap_done from langchain.chat_models import ChatOpenAI from langchain import LLMChain @@ -12,15 +12,17 @@ from typing import List from server.chat.utils import History -def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), - history: List[History] = Body([], +async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), + history: List[History] = Body([], description="历史对话", examples=[[ {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}, {"role": "assistant", "content": "虎头虎脑"}]] ), - stream: bool = Body(False, description="流式输出"), - model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), + stream: bool = Body(False, description="流式输出"), + model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), + temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0), + # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0), ): history = [History.from_data(h) for h in history] @@ -37,6 +39,7 @@ def chat(query: str = Body(..., description="用户输入", examples=["恼羞成 openai_api_key=llm_model_dict[model_name]["api_key"], openai_api_base=llm_model_dict[model_name]["api_base_url"], model_name=model_name, + temperature=temperature, openai_proxy=llm_model_dict[model_name].get("openai_proxy") ) diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 69ec25dd..b26f2cbc 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -1,7 +1,8 @@ from fastapi import Body, Request from fastapi.responses import StreamingResponse from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE, - VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD) + VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, + TEMPERATURE) from server.chat.utils import wrap_done from server.utils import BaseResponse from langchain.chat_models import ChatOpenAI @@ -18,11 +19,11 @@ from urllib.parse import urlencode from server.knowledge_base.kb_doc_api import search_docs -def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]), - knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), - top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), - score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1), - history: List[History] = Body([], +async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]), + knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), + top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), + score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1), + history: List[History] = Body([], description="历史对话", examples=[[ {"role": "user", @@ -30,10 +31,11 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp {"role": "assistant", "content": "虎头虎脑"}]] ), - stream: bool = Body(False, description="流式输出"), - model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), - local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"), - request: Request = None, + stream: bool = Body(False, description="流式输出"), + model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), + temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0), + local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"), + request: Request = None, ): kb = KBServiceFactory.get_service_by_name(knowledge_base_name) if kb is None: @@ -48,6 +50,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp model_name: str = LLM_MODEL, ) -> AsyncIterable[str]: callback = AsyncIteratorCallbackHandler() + model = ChatOpenAI( streaming=True, verbose=True, @@ -55,6 +58,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp openai_api_key=llm_model_dict[model_name]["api_key"], openai_api_base=llm_model_dict[model_name]["api_base_url"], model_name=model_name, + temperature=temperature, openai_proxy=llm_model_dict[model_name].get("openai_proxy") ) docs = search_docs(query, knowledge_base_name, top_k, score_threshold) @@ -86,9 +90,8 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp if stream: async for token in callback.aiter(): # Use server-sent-events to stream the response - yield json.dumps({"answer": token, - "docs": source_documents}, - ensure_ascii=False) + yield json.dumps({"answer": token}, ensure_ascii=False) + yield json.dumps({"docs": source_documents}, ensure_ascii=False) else: answer = "" async for token in callback.aiter(): diff --git a/server/chat/openai_chat.py b/server/chat/openai_chat.py index a799c623..857ac979 100644 --- a/server/chat/openai_chat.py +++ b/server/chat/openai_chat.py @@ -1,7 +1,7 @@ from fastapi.responses import StreamingResponse from typing import List import openai -from configs.model_config import llm_model_dict, LLM_MODEL, logger +from configs.model_config import llm_model_dict, LLM_MODEL, logger, log_verbose from pydantic import BaseModel @@ -29,13 +29,13 @@ async def openai_chat(msg: OpenAiChatMsgIn): print(f"{openai.api_base=}") print(msg) - def get_response(msg): + async def get_response(msg): data = msg.dict() try: - response = openai.ChatCompletion.create(**data) + response = await openai.ChatCompletion.acreate(**data) if msg.stream: - for data in response: + async for data in response: if choices := data.choices: if chunk := choices[0].get("delta", {}).get("content"): print(chunk, end="", flush=True) @@ -46,8 +46,9 @@ async def openai_chat(msg: OpenAiChatMsgIn): print(answer) yield(answer) except Exception as e: - print(type(e)) - logger.error(e) + msg = f"获取ChatCompletion时出错:{e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) return StreamingResponse( get_response(msg), diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py index 8fe7dae6..f8e4ebe9 100644 --- a/server/chat/search_engine_chat.py +++ b/server/chat/search_engine_chat.py @@ -2,7 +2,9 @@ from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY from fastapi import Body from fastapi.responses import StreamingResponse -from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE) +from fastapi.concurrency import run_in_threadpool +from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, + PROMPT_TEMPLATE, TEMPERATURE) from server.chat.utils import wrap_done from server.utils import BaseResponse from langchain.chat_models import ChatOpenAI @@ -47,29 +49,31 @@ def search_result2docs(search_results): return docs -def lookup_search_engine( +async def lookup_search_engine( query: str, search_engine_name: str, top_k: int = SEARCH_ENGINE_TOP_K, ): - results = SEARCH_ENGINES[search_engine_name](query, result_len=top_k) + search_engine = SEARCH_ENGINES[search_engine_name] + results = await run_in_threadpool(search_engine, query, result_len=top_k) docs = search_result2docs(results) return docs -def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]), - search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]), - top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"), - history: List[History] = Body([], - description="历史对话", - examples=[[ - {"role": "user", - "content": "我们来玩成语接龙,我先来,生龙活虎"}, - {"role": "assistant", - "content": "虎头虎脑"}]] - ), - stream: bool = Body(False, description="流式输出"), - model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), +async def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]), + search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]), + top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"), + history: List[History] = Body([], + description="历史对话", + examples=[[ + {"role": "user", + "content": "我们来玩成语接龙,我先来,生龙活虎"}, + {"role": "assistant", + "content": "虎头虎脑"}]] + ), + stream: bool = Body(False, description="流式输出"), + model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), + temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0), ): if search_engine_name not in SEARCH_ENGINES.keys(): return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}") @@ -93,10 +97,11 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl openai_api_key=llm_model_dict[model_name]["api_key"], openai_api_base=llm_model_dict[model_name]["api_base_url"], model_name=model_name, + temperature=temperature, openai_proxy=llm_model_dict[model_name].get("openai_proxy") ) - docs = lookup_search_engine(query, search_engine_name, top_k) + docs = await lookup_search_engine(query, search_engine_name, top_k) context = "\n".join([doc.page_content for doc in docs]) input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False) @@ -119,9 +124,8 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl if stream: async for token in callback.aiter(): # Use server-sent-events to stream the response - yield json.dumps({"answer": token, - "docs": source_documents}, - ensure_ascii=False) + yield json.dumps({"answer": token}, ensure_ascii=False) + yield json.dumps({"docs": source_documents}, ensure_ascii=False) else: answer = "" async for token in callback.aiter(): diff --git a/server/chat/utils.py b/server/chat/utils.py index 2167f10e..a80648b1 100644 --- a/server/chat/utils.py +++ b/server/chat/utils.py @@ -2,6 +2,7 @@ import asyncio from typing import Awaitable, List, Tuple, Dict, Union from pydantic import BaseModel, Field from langchain.prompts.chat import ChatMessagePromptTemplate +from configs import logger, log_verbose async def wrap_done(fn: Awaitable, event: asyncio.Event): @@ -10,7 +11,9 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event): await fn except Exception as e: # TODO: handle exception - print(f"Caught exception: {e}") + msg = f"Caught exception: {e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) finally: # Signal the aiter to stop. event.set() diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py index b9151b89..c7b703e9 100644 --- a/server/knowledge_base/kb_api.py +++ b/server/knowledge_base/kb_api.py @@ -3,19 +3,19 @@ from server.utils import BaseResponse, ListResponse from server.knowledge_base.utils import validate_kb_name from server.knowledge_base.kb_service.base import KBServiceFactory from server.db.repository.knowledge_base_repository import list_kbs_from_db -from configs.model_config import EMBEDDING_MODEL +from configs.model_config import EMBEDDING_MODEL, logger, log_verbose from fastapi import Body -async def list_kbs(): +def list_kbs(): # Get List of Knowledge Base return ListResponse(data=list_kbs_from_db()) -async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), - vector_store_type: str = Body("faiss"), - embed_model: str = Body(EMBEDDING_MODEL), - ) -> BaseResponse: +def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), + vector_store_type: str = Body("faiss"), + embed_model: str = Body(EMBEDDING_MODEL), + ) -> BaseResponse: # Create selected knowledge base if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -30,14 +30,16 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), try: kb.create_kb() except Exception as e: - print(e) - return BaseResponse(code=500, msg=f"创建知识库出错: {e}") + msg = f"创建知识库出错: {e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) + return BaseResponse(code=500, msg=msg) return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}") -async def delete_kb( - knowledge_base_name: str = Body(..., examples=["samples"]) +def delete_kb( + knowledge_base_name: str = Body(..., examples=["samples"]) ) -> BaseResponse: # Delete selected knowledge base if not validate_kb_name(knowledge_base_name): @@ -55,7 +57,9 @@ async def delete_kb( if status: return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}") except Exception as e: - print(e) - return BaseResponse(code=500, msg=f"删除知识库时出现意外: {e}") + msg = f"删除知识库时出现意外: {e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) + return BaseResponse(code=500, msg=msg) return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}") diff --git a/server/knowledge_base/kb_cache/base.py b/server/knowledge_base/kb_cache/base.py new file mode 100644 index 00000000..f3e6d654 --- /dev/null +++ b/server/knowledge_base/kb_cache/base.py @@ -0,0 +1,137 @@ +from langchain.embeddings.huggingface import HuggingFaceEmbeddings +from langchain.embeddings.openai import OpenAIEmbeddings +from langchain.embeddings import HuggingFaceBgeEmbeddings +from langchain.embeddings.base import Embeddings +from langchain.schema import Document +import threading +from configs.model_config import (CACHED_VS_NUM, EMBEDDING_MODEL, CHUNK_SIZE, + embedding_model_dict, logger, log_verbose) +from server.utils import embedding_device +from contextlib import contextmanager +from collections import OrderedDict +from typing import List, Any, Union, Tuple + + +class ThreadSafeObject: + def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None): + self._obj = obj + self._key = key + self._pool = pool + self._lock = threading.RLock() + self._loaded = threading.Event() + + def __repr__(self) -> str: + cls = type(self).__name__ + return f"<{cls}: key: {self._key}, obj: {self._obj}>" + + @contextmanager + def acquire(self, owner: str = "", msg: str = ""): + owner = owner or f"thread {threading.get_native_id()}" + try: + self._lock.acquire() + if self._pool is not None: + self._pool._cache.move_to_end(self._key) + if log_verbose: + logger.info(f"{owner} 开始操作:{self._key}。{msg}") + yield self._obj + finally: + if log_verbose: + logger.info(f"{owner} 结束操作:{self._key}。{msg}") + self._lock.release() + + def start_loading(self): + self._loaded.clear() + + def finish_loading(self): + self._loaded.set() + + def wait_for_loading(self): + self._loaded.wait() + + @property + def obj(self): + return self._obj + + @obj.setter + def obj(self, val: Any): + self._obj = val + + +class CachePool: + def __init__(self, cache_num: int = -1): + self._cache_num = cache_num + self._cache = OrderedDict() + self.atomic = threading.RLock() + + def keys(self) -> List[str]: + return list(self._cache.keys()) + + def _check_count(self): + if isinstance(self._cache_num, int) and self._cache_num > 0: + while len(self._cache) > self._cache_num: + self._cache.popitem(last=False) + + def get(self, key: str) -> ThreadSafeObject: + if cache := self._cache.get(key): + cache.wait_for_loading() + return cache + + def set(self, key: str, obj: ThreadSafeObject) -> ThreadSafeObject: + self._cache[key] = obj + self._check_count() + return obj + + def pop(self, key: str = None) -> ThreadSafeObject: + if key is None: + return self._cache.popitem(last=False) + else: + return self._cache.pop(key, None) + + def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""): + cache = self.get(key) + if cache is None: + raise RuntimeError(f"请求的资源 {key} 不存在") + elif isinstance(cache, ThreadSafeObject): + self._cache.move_to_end(key) + return cache.acquire(owner=owner, msg=msg) + else: + return cache + + def load_kb_embeddings(self, kb_name: str=None, embed_device: str = embedding_device()) -> Embeddings: + from server.db.repository.knowledge_base_repository import get_kb_detail + + kb_detail = get_kb_detail(kb_name=kb_name) + print(kb_detail) + embed_model = kb_detail.get("embed_model", EMBEDDING_MODEL) + return embeddings_pool.load_embeddings(model=embed_model, device=embed_device) + + +class EmbeddingsPool(CachePool): + def load_embeddings(self, model: str, device: str) -> Embeddings: + self.atomic.acquire() + model = model or EMBEDDING_MODEL + device = device or embedding_device() + key = (model, device) + if not self.get(key): + item = ThreadSafeObject(key, pool=self) + self.set(key, item) + with item.acquire(msg="初始化"): + self.atomic.release() + if model == "text-embedding-ada-002": # openai text-embedding-ada-002 + embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE) + elif 'bge-' in model: + embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model], + model_kwargs={'device': device}, + query_instruction="为这个句子生成表示以用于检索相关文章:") + if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding + embeddings.query_instruction = "" + else: + embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device}) + item.obj = embeddings + item.finish_loading() + else: + self.atomic.release() + return self.get(key).obj + + +embeddings_pool = EmbeddingsPool(cache_num=1) diff --git a/server/knowledge_base/kb_cache/faiss_cache.py b/server/knowledge_base/kb_cache/faiss_cache.py new file mode 100644 index 00000000..325c7bb1 --- /dev/null +++ b/server/knowledge_base/kb_cache/faiss_cache.py @@ -0,0 +1,157 @@ +from server.knowledge_base.kb_cache.base import * +from server.knowledge_base.utils import get_vs_path +from langchain.vectorstores import FAISS +import os + + +class ThreadSafeFaiss(ThreadSafeObject): + def __repr__(self) -> str: + cls = type(self).__name__ + return f"<{cls}: key: {self._key}, obj: {self._obj}, docs_count: {self.docs_count()}>" + + def docs_count(self) -> int: + return len(self._obj.docstore._dict) + + def save(self, path: str, create_path: bool = True): + with self.acquire(): + if not os.path.isdir(path) and create_path: + os.makedirs(path) + ret = self._obj.save_local(path) + logger.info(f"已将向量库 {self._key} 保存到磁盘") + return ret + + def clear(self): + ret = [] + with self.acquire(): + ids = list(self._obj.docstore._dict.keys()) + if ids: + ret = self._obj.delete(ids) + assert len(self._obj.docstore._dict) == 0 + logger.info(f"已将向量库 {self._key} 清空") + return ret + + +class _FaissPool(CachePool): + def new_vector_store( + self, + embed_model: str = EMBEDDING_MODEL, + embed_device: str = embedding_device(), + ) -> FAISS: + embeddings = embeddings_pool.load_embeddings(embed_model, embed_device) + + # create an empty vector store + doc = Document(page_content="init", metadata={}) + vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True) + ids = list(vector_store.docstore._dict.keys()) + vector_store.delete(ids) + return vector_store + + def save_vector_store(self, kb_name: str, path: str=None): + if cache := self.get(kb_name): + return cache.save(path) + + def unload_vector_store(self, kb_name: str): + if cache := self.get(kb_name): + self.pop(kb_name) + logger.info(f"成功释放向量库:{kb_name}") + + +class KBFaissPool(_FaissPool): + def load_vector_store( + self, + kb_name: str, + create: bool = True, + embed_model: str = EMBEDDING_MODEL, + embed_device: str = embedding_device(), + ) -> ThreadSafeFaiss: + self.atomic.acquire() + cache = self.get(kb_name) + if cache is None: + item = ThreadSafeFaiss(kb_name, pool=self) + self.set(kb_name, item) + with item.acquire(msg="初始化"): + self.atomic.release() + logger.info(f"loading vector store in '{kb_name}' from disk.") + vs_path = get_vs_path(kb_name) + + if os.path.isfile(os.path.join(vs_path, "index.faiss")): + embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device) + vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True) + elif create: + # create an empty vector store + if not os.path.exists(vs_path): + os.makedirs(vs_path) + vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device) + vector_store.save_local(vs_path) + else: + raise RuntimeError(f"knowledge base {kb_name} not exist.") + item.obj = vector_store + item.finish_loading() + else: + self.atomic.release() + return self.get(kb_name) + + +class MemoFaissPool(_FaissPool): + def load_vector_store( + self, + kb_name: str, + embed_model: str = EMBEDDING_MODEL, + embed_device: str = embedding_device(), + ) -> ThreadSafeFaiss: + self.atomic.acquire() + cache = self.get(kb_name) + if cache is None: + item = ThreadSafeFaiss(kb_name, pool=self) + self.set(kb_name, item) + with item.acquire(msg="初始化"): + self.atomic.release() + logger.info(f"loading vector store in '{kb_name}' to memory.") + # create an empty vector store + vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device) + item.obj = vector_store + item.finish_loading() + else: + self.atomic.release() + return self.get(kb_name) + + +kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM) +memo_faiss_pool = MemoFaissPool() + + +if __name__ == "__main__": + import time, random + from pprint import pprint + + kb_names = ["vs1", "vs2", "vs3"] + # for name in kb_names: + # memo_faiss_pool.load_vector_store(name) + + def worker(vs_name: str, name: str): + vs_name = "samples" + time.sleep(random.randint(1, 5)) + embeddings = embeddings_pool.load_embeddings() + r = random.randint(1, 3) + + with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs: + if r == 1: # add docs + ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings) + pprint(ids) + elif r == 2: # search docs + docs = vs.similarity_search_with_score(f"{name}", top_k=3, score_threshold=1.0) + pprint(docs) + if r == 3: # delete docs + logger.warning(f"清除 {vs_name} by {name}") + kb_faiss_pool.get(vs_name).clear() + + threads = [] + for n in range(1, 30): + t = threading.Thread(target=worker, + kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"}, + daemon=True) + t.start() + threads.append(t) + + for t in threads: + t.join() diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 7ea5d271..02ad222c 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -1,12 +1,18 @@ import os import urllib from fastapi import File, Form, Body, Query, UploadFile -from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD) -from server.utils import BaseResponse, ListResponse -from server.knowledge_base.utils import validate_kb_name, list_files_from_folder, KnowledgeFile +from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, + VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, + CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE, + logger, log_verbose,) +from server.utils import BaseResponse, ListResponse, run_in_thread_pool +from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder,get_file_path, + files2docs_in_thread, KnowledgeFile) from fastapi.responses import StreamingResponse, FileResponse +from pydantic import Json import json from server.knowledge_base.kb_service.base import KBServiceFactory +from server.db.repository.knowledge_file_repository import get_file_detail from typing import List, Dict from langchain.docstore.document import Document @@ -29,7 +35,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=[" return data -async def list_files( +def list_files( knowledge_base_name: str ) -> ListResponse: if not validate_kb_name(knowledge_base_name): @@ -44,11 +50,87 @@ async def list_files( return ListResponse(data=all_doc_names) -async def upload_doc(file: UploadFile = File(..., description="上传文件"), - knowledge_base_name: str = Form(..., description="知识库名称", examples=["kb1"]), - override: bool = Form(False, description="覆盖已有文件"), - not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"), - ) -> BaseResponse: +def _save_files_in_thread(files: List[UploadFile], + knowledge_base_name: str, + override: bool): + ''' + 通过多线程将上传的文件保存到对应知识库目录内。 + 生成器返回保存结果:{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}} + ''' + def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dict: + ''' + 保存单个文件。 + ''' + try: + filename = file.filename + file_path = get_file_path(knowledge_base_name=knowledge_base_name, doc_name=filename) + data = {"knowledge_base_name": knowledge_base_name, "file_name": filename} + + file_content = file.file.read() # 读取上传文件的内容 + if (os.path.isfile(file_path) + and not override + and os.path.getsize(file_path) == len(file_content) + ): + # TODO: filesize 不同后的处理 + file_status = f"文件 {filename} 已存在。" + logger.warn(file_status) + return dict(code=404, msg=file_status, data=data) + + with open(file_path, "wb") as f: + f.write(file_content) + return dict(code=200, msg=f"成功上传文件 {filename}", data=data) + except Exception as e: + msg = f"{filename} 文件上传失败,报错信息为: {e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) + return dict(code=500, msg=msg, data=data) + + params = [{"file": file, "knowledge_base_name": knowledge_base_name, "override": override} for file in files] + for result in run_in_thread_pool(save_file, params=params): + yield result + + +# 似乎没有单独增加一个文件上传API接口的必要 +# def upload_files(files: List[UploadFile] = File(..., description="上传文件,支持多文件"), +# knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]), +# override: bool = Form(False, description="覆盖已有文件")): +# ''' +# API接口:上传文件。流式返回保存结果:{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}} +# ''' +# def generate(files, knowledge_base_name, override): +# for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override): +# yield json.dumps(result, ensure_ascii=False) + +# return StreamingResponse(generate(files, knowledge_base_name=knowledge_base_name, override=override), media_type="text/event-stream") + + +# TODO: 等langchain.document_loaders支持内存文件的时候再开通 +# def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"), +# knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]), +# override: bool = Form(False, description="覆盖已有文件"), +# save: bool = Form(True, description="是否将文件保存到知识库目录")): +# def save_files(files, knowledge_base_name, override): +# for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override): +# yield json.dumps(result, ensure_ascii=False) + +# def files_to_docs(files): +# for result in files2docs_in_thread(files): +# yield json.dumps(result, ensure_ascii=False) + + +def upload_docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"), + knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]), + override: bool = Form(False, description="覆盖已有文件"), + to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"), + chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"), + chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), + zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), + docs: Json = Form({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]), + not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"), + ) -> BaseResponse: + ''' + API接口:上传文件,并/或向量化 + ''' if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -56,40 +138,42 @@ async def upload_doc(file: UploadFile = File(..., description="上传文件"), if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - file_content = await file.read() # 读取上传文件的内容 + failed_files = {} + file_names = list(docs.keys()) - try: - kb_file = KnowledgeFile(filename=file.filename, - knowledge_base_name=knowledge_base_name) + # 先将上传的文件保存到磁盘 + for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override): + filename = result["data"]["file_name"] + if result["code"] != 200: + failed_files[filename] = result["msg"] - if (os.path.exists(kb_file.filepath) - and not override - and os.path.getsize(kb_file.filepath) == len(file_content) - ): - # TODO: filesize 不同后的处理 - file_status = f"文件 {kb_file.filename} 已存在。" - return BaseResponse(code=404, msg=file_status) + if filename not in file_names: + file_names.append(filename) - with open(kb_file.filepath, "wb") as f: - f.write(file_content) - except Exception as e: - print(e) - return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}") + # 对保存的文件进行向量化 + if to_vector_store: + result = update_docs( + knowledge_base_name=knowledge_base_name, + file_names=file_names, + override_custom_docs=True, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + zh_title_enhance=zh_title_enhance, + docs=docs, + not_refresh_vs_cache=True, + ) + failed_files.update(result.data["failed_files"]) + if not not_refresh_vs_cache: + kb.save_vector_store() - try: - kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) - except Exception as e: - print(e) - return BaseResponse(code=500, msg=f"{kb_file.filename} 文件向量化失败,报错信息为: {e}") - - return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}") + return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files}) -async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]), - doc_name: str = Body(..., examples=["file_name.md"]), - delete_content: bool = Body(False), - not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), - ) -> BaseResponse: +def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]), + file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]), + delete_content: bool = Body(False), + not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), + ) -> BaseResponse: if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -98,24 +182,36 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]), if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - if not kb.exist_doc(doc_name): - return BaseResponse(code=404, msg=f"未找到文件 {doc_name}") + failed_files = {} + for file_name in file_names: + if not kb.exist_doc(file_name): + failed_files[file_name] = f"未找到文件 {file_name}" - try: - kb_file = KnowledgeFile(filename=doc_name, - knowledge_base_name=knowledge_base_name) - kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=not_refresh_vs_cache) - except Exception as e: - print(e) - return BaseResponse(code=500, msg=f"{kb_file.filename} 文件删除失败,错误信息:{e}") + try: + kb_file = KnowledgeFile(filename=file_name, + knowledge_base_name=knowledge_base_name) + kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=True) + except Exception as e: + msg = f"{file_name} 文件删除失败,错误信息:{e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) + failed_files[file_name] = msg - return BaseResponse(code=200, msg=f"{kb_file.filename} 文件删除成功") + if not not_refresh_vs_cache: + kb.save_vector_store() + + return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failed_files}) -async def update_doc( - knowledge_base_name: str = Body(..., examples=["samples"]), - file_name: str = Body(..., examples=["file_name"]), - not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), +def update_docs( + knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), + file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=["file_name"]), + chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"), + chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), + zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), + override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"), + docs: Json = Body({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]), + not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), ) -> BaseResponse: ''' 更新知识库文档 @@ -127,22 +223,62 @@ async def update_doc( if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") - try: - kb_file = KnowledgeFile(filename=file_name, - knowledge_base_name=knowledge_base_name) - if os.path.exists(kb_file.filepath): - kb.update_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) - return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}") - except Exception as e: - print(e) - return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败,错误信息是:{e}") + failed_files = {} + kb_files = [] - return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败") + # 生成需要加载docs的文件列表 + for file_name in file_names: + file_detail= get_file_detail(kb_name=knowledge_base_name, filename=file_name) + # 如果该文件之前使用了自定义docs,则根据参数决定略过或覆盖 + if file_detail.get("custom_docs") and not override_custom_docs: + continue + if file_name not in docs: + try: + kb_files.append(KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name)) + except Exception as e: + msg = f"加载文档 {file_name} 时出错:{e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) + failed_files[file_name] = msg + + # 从文件生成docs,并进行向量化。 + # 这里利用了KnowledgeFile的缓存功能,在多线程中加载Document,然后传给KnowledgeFile + for status, result in files2docs_in_thread(kb_files, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + zh_title_enhance=zh_title_enhance): + if status: + kb_name, file_name, new_docs = result + kb_file = KnowledgeFile(filename=file_name, + knowledge_base_name=knowledge_base_name) + kb_file.splited_docs = new_docs + kb.update_doc(kb_file, not_refresh_vs_cache=True) + else: + kb_name, file_name, error = result + failed_files[file_name] = error + + # 将自定义的docs进行向量化 + for file_name, v in docs.items(): + try: + v = [x if isinstance(x, Document) else Document(**x) for x in v] + kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name) + kb.update_doc(kb_file, docs=v, not_refresh_vs_cache=True) + except Exception as e: + msg = f"为 {file_name} 添加自定义docs时出错:{e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) + failed_files[file_name] = msg + + if not not_refresh_vs_cache: + kb.save_vector_store() + + return BaseResponse(code=200, msg=f"更新文档完成", data={"failed_files": failed_files}) -async def download_doc( - knowledge_base_name: str = Query(..., examples=["samples"]), - file_name: str = Query(..., examples=["test.txt"]), +def download_doc( + knowledge_base_name: str = Query(...,description="知识库名称", examples=["samples"]), + file_name: str = Query(...,description="文件名称", examples=["test.txt"]), + preview: bool = Query(False, description="是:浏览器内预览;否:下载"), ): ''' 下载知识库文档 @@ -154,6 +290,11 @@ async def download_doc( if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") + if preview: + content_disposition_type = "inline" + else: + content_disposition_type = None + try: kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name) @@ -162,20 +303,27 @@ async def download_doc( return FileResponse( path=kb_file.filepath, filename=kb_file.filename, - media_type="multipart/form-data") + media_type="multipart/form-data", + content_disposition_type=content_disposition_type, + ) except Exception as e: - print(e) - return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败,错误信息是:{e}") + msg = f"{kb_file.filename} 读取文件失败,错误信息是:{e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) + return BaseResponse(code=500, msg=msg) return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败") -async def recreate_vector_store( - knowledge_base_name: str = Body(..., examples=["samples"]), - allow_empty_kb: bool = Body(True), - vs_type: str = Body(DEFAULT_VS_TYPE), - embed_model: str = Body(EMBEDDING_MODEL), - ): +def recreate_vector_store( + knowledge_base_name: str = Body(..., examples=["samples"]), + allow_empty_kb: bool = Body(True), + vs_type: str = Body(DEFAULT_VS_TYPE), + embed_model: str = Body(EMBEDDING_MODEL), + chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"), + chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), + zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), +): ''' recreate vector store from the content. this is usefull when user can copy files to content folder directly instead of upload through network. @@ -190,27 +338,33 @@ async def recreate_vector_store( else: kb.create_kb() kb.clear_vs() - docs = list_files_from_folder(knowledge_base_name) - for i, doc in enumerate(docs): - try: - kb_file = KnowledgeFile(doc, knowledge_base_name) + files = list_files_from_folder(knowledge_base_name) + kb_files = [(file, knowledge_base_name) for file in files] + i = 0 + for status, result in files2docs_in_thread(kb_files, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + zh_title_enhance=zh_title_enhance): + if status: + kb_name, file_name, docs = result + kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name) + kb_file.splited_docs = docs yield json.dumps({ "code": 200, - "msg": f"({i + 1} / {len(docs)}): {doc}", - "total": len(docs), + "msg": f"({i + 1} / {len(files)}): {file_name}", + "total": len(files), "finished": i, - "doc": doc, + "doc": file_name, }, ensure_ascii=False) - if i == len(docs) - 1: - not_refresh_vs_cache = False - else: - not_refresh_vs_cache = True - kb.add_doc(kb_file, not_refresh_vs_cache=not_refresh_vs_cache) - except Exception as e: - print(e) + kb.add_doc(kb_file, not_refresh_vs_cache=True) + else: + kb_name, file_name, error = result + msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。" + logger.error(msg) yield json.dumps({ "code": 500, - "msg": f"添加文件‘{doc}’到知识库‘{knowledge_base_name}’时出错:{e}。已跳过。", + "msg": msg, }) + i += 1 return StreamingResponse(output(), media_type="text/event-stream") diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index ca0919e0..c97f8cce 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -25,7 +25,6 @@ from server.knowledge_base.utils import ( list_kbs_from_folder, list_files_from_folder, ) from server.utils import embedding_device -from typing import List, Union, Dict from typing import List, Union, Dict, Optional @@ -51,6 +50,12 @@ class KBService(ABC): def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings: return load_embeddings(self.embed_model, embed_device) + def save_vector_store(self): + ''' + 保存向量库:FAISS保存到磁盘,milvus保存到数据库。PGVector暂未支持 + ''' + pass + def create_kb(self): """ 创建知识库 @@ -84,6 +89,8 @@ class KBService(ABC): """ if docs: custom_docs = True + for doc in docs: + doc.metadata.setdefault("source", kb_file.filepath) else: docs = kb_file.file2text() custom_docs = False @@ -137,7 +144,6 @@ class KBService(ABC): docs = self.do_search(query, top_k, score_threshold, embeddings) return docs - # TODO: milvus/pg需要实现该方法 def get_doc_by_id(self, id: str) -> Optional[Document]: return None diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 15cc790b..6e20acf6 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -3,61 +3,16 @@ import shutil from configs.model_config import ( KB_ROOT_PATH, - CACHED_VS_NUM, - EMBEDDING_MODEL, - SCORE_THRESHOLD + SCORE_THRESHOLD, + logger, log_verbose, ) from server.knowledge_base.kb_service.base import KBService, SupportedVSType -from functools import lru_cache -from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile -from langchain.vectorstores import FAISS +from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss +from server.knowledge_base.utils import KnowledgeFile from langchain.embeddings.base import Embeddings from typing import List, Dict, Optional from langchain.docstore.document import Document -from server.utils import torch_gc, embedding_device - - -_VECTOR_STORE_TICKS = {} - - -@lru_cache(CACHED_VS_NUM) -def load_faiss_vector_store( - knowledge_base_name: str, - embed_model: str = EMBEDDING_MODEL, - embed_device: str = embedding_device(), - embeddings: Embeddings = None, - tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed. -) -> FAISS: - print(f"loading vector store in '{knowledge_base_name}'.") - vs_path = get_vs_path(knowledge_base_name) - if embeddings is None: - embeddings = load_embeddings(embed_model, embed_device) - - if not os.path.exists(vs_path): - os.makedirs(vs_path) - - if "index.faiss" in os.listdir(vs_path): - search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True) - else: - # create an empty vector store - doc = Document(page_content="init", metadata={}) - search_index = FAISS.from_documents([doc], embeddings, normalize_L2=True) - ids = [k for k, v in search_index.docstore._dict.items()] - search_index.delete(ids) - search_index.save_local(vs_path) - - if tick == 0: # vector store is loaded first time - _VECTOR_STORE_TICKS[knowledge_base_name] = 0 - - return search_index - - -def refresh_vs_cache(kb_name: str): - """ - make vector store cache refreshed when next loading - """ - _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1 - print(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}") +from server.utils import torch_gc class FaissKBService(KBService): @@ -73,24 +28,15 @@ class FaissKBService(KBService): def get_kb_path(self): return os.path.join(KB_ROOT_PATH, self.kb_name) - def load_vector_store(self) -> FAISS: - return load_faiss_vector_store( - knowledge_base_name=self.kb_name, - embed_model=self.embed_model, - tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0), - ) + def load_vector_store(self) -> ThreadSafeFaiss: + return kb_faiss_pool.load_vector_store(kb_name=self.kb_name, embed_model=self.embed_model) - def save_vector_store(self, vector_store: FAISS = None): - vector_store = vector_store or self.load_vector_store() - vector_store.save_local(self.vs_path) - return vector_store - - def refresh_vs_cache(self): - refresh_vs_cache(self.kb_name) + def save_vector_store(self): + self.load_vector_store().save(self.vs_path) def get_doc_by_id(self, id: str) -> Optional[Document]: - vector_store = self.load_vector_store() - return vector_store.docstore._dict.get(id) + with self.load_vector_store().acquire() as vs: + return vs.docstore._dict.get(id) def do_init(self): self.kb_path = self.get_kb_path() @@ -111,43 +57,38 @@ class FaissKBService(KBService): score_threshold: float = SCORE_THRESHOLD, embeddings: Embeddings = None, ) -> List[Document]: - search_index = self.load_vector_store() - docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold) + with self.load_vector_store().acquire() as vs: + docs = vs.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold) return docs def do_add_doc(self, docs: List[Document], **kwargs, ) -> List[Dict]: - vector_store = self.load_vector_store() - ids = vector_store.add_documents(docs) + with self.load_vector_store().acquire() as vs: + ids = vs.add_documents(docs) + if not kwargs.get("not_refresh_vs_cache"): + vs.save_local(self.vs_path) doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] torch_gc() - if not kwargs.get("not_refresh_vs_cache"): - vector_store.save_local(self.vs_path) - self.refresh_vs_cache() return doc_infos def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): - vector_store = self.load_vector_store() - - ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath] - if len(ids) == 0: - return None - - vector_store.delete(ids) - if not kwargs.get("not_refresh_vs_cache"): - vector_store.save_local(self.vs_path) - self.refresh_vs_cache() - - return vector_store + with self.load_vector_store().acquire() as vs: + ids = [k for k, v in vs.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath] + if len(ids) > 0: + vs.delete(ids) + if not kwargs.get("not_refresh_vs_cache"): + vs.save_local(self.vs_path) + return ids def do_clear_vs(self): + with kb_faiss_pool.atomic: + kb_faiss_pool.pop(self.kb_name) shutil.rmtree(self.vs_path) os.makedirs(self.vs_path) - self.refresh_vs_cache() def exist_doc(self, file_name: str): if super().exist_doc(file_name): @@ -165,4 +106,4 @@ if __name__ == '__main__': faissService.add_doc(KnowledgeFile("README.md", "test")) faissService.delete_doc(KnowledgeFile("README.md", "test")) faissService.do_drop_kb() - print(faissService.search_docs("如何启动api服务")) \ No newline at end of file + print(faissService.search_docs("如何启动api服务")) diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index 444765f6..5ca425b5 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -22,6 +22,10 @@ class MilvusKBService(KBService): from pymilvus import Collection return Collection(milvus_name) + def save_vector_store(self): + if self.milvus.col: + self.milvus.col.flush() + def get_doc_by_id(self, id: str) -> Optional[Document]: if self.milvus.col: data_list = self.milvus.col.query(expr=f'pk == {id}', output_fields=["*"]) @@ -56,6 +60,7 @@ class MilvusKBService(KBService): def do_drop_kb(self): if self.milvus.col: + self.milvus.col.release() self.milvus.col.drop() def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings): @@ -63,6 +68,15 @@ class MilvusKBService(KBService): return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k)) def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: + # TODO: workaround for bug #10492 in langchain + for doc in docs: + for k, v in doc.metadata.items(): + doc.metadata[k] = str(v) + for field in self.milvus.fields: + doc.metadata.setdefault(field, "") + doc.metadata.pop(self.milvus._text_field, None) + doc.metadata.pop(self.milvus._vector_field, None) + ids = self.milvus.add_documents(docs) doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] return doc_infos @@ -76,7 +90,8 @@ class MilvusKBService(KBService): def do_clear_vs(self): if self.milvus.col: - self.milvus.col.drop() + self.do_drop_kb() + self.do_init() if __name__ == '__main__': diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index e6381fac..fa832ab5 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -26,6 +26,7 @@ class PGKBService(KBService): collection_name=self.kb_name, distance_strategy=DistanceStrategy.EUCLIDEAN, connection_string=kbs_config.get("pg").get("connection_uri")) + def get_doc_by_id(self, id: str) -> Optional[Document]: with self.pg_vector.connect() as connect: stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id=:id") @@ -77,6 +78,7 @@ class PGKBService(KBService): def do_clear_vs(self): self.pg_vector.delete_collection() + self.pg_vector.create_collection() if __name__ == '__main__': diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py index c11c00d2..b2073a10 100644 --- a/server/knowledge_base/migrate.py +++ b/server/knowledge_base/migrate.py @@ -1,19 +1,15 @@ -from configs.model_config import EMBEDDING_MODEL, DEFAULT_VS_TYPE +from configs.model_config import (EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE, + logger, log_verbose) from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder, - list_files_from_folder, run_in_thread_pool, - files2docs_in_thread, + list_files_from_folder,files2docs_in_thread, KnowledgeFile,) from server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType from server.db.repository.knowledge_file_repository import add_file_to_db from server.db.base import Base, engine import os -from concurrent.futures import ThreadPoolExecutor from typing import Literal, Any, List -pool = ThreadPoolExecutor(os.cpu_count()) - - def create_tables(): Base.metadata.create_all(bind=engine) @@ -30,7 +26,9 @@ def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]: kb_file = KnowledgeFile(filename=file, knowledge_base_name=kb_name) kb_files.append(kb_file) except Exception as e: - print(f"{e},已跳过") + msg = f"{e},已跳过" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) return kb_files @@ -39,6 +37,9 @@ def folder2db( mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"], vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE, embed_model: str = EMBEDDING_MODEL, + chunk_size: int = -1, + chunk_overlap: int = -1, + zh_title_enhance: bool = ZH_TITLE_ENHANCE, ): ''' use existed files in local folder to populate database and/or vector store. @@ -59,7 +60,10 @@ def folder2db( print(f"清理后,知识库 {kb_name} 中共有 {files_count} 个文档。") kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name)) - for success, result in files2docs_in_thread(kb_files, pool=pool): + for success, result in files2docs_in_thread(kb_files, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + zh_title_enhance=zh_title_enhance): if success: _, filename, docs = result print(f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档") @@ -67,10 +71,7 @@ def folder2db( kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True) else: print(result) - - if kb.vs_type() == SupportedVSType.FAISS: - kb.save_vector_store() - kb.refresh_vs_cache() + kb.save_vector_store() elif mode == "fill_info_only": files = list_files_from_folder(kb_name) kb_files = file_to_kbfile(kb_name, files) @@ -84,17 +85,17 @@ def folder2db( for kb_file in kb_files: kb.update_doc(kb_file, not_refresh_vs_cache=True) - - if kb.vs_type() == SupportedVSType.FAISS: - kb.save_vector_store() - kb.refresh_vs_cache() + kb.save_vector_store() elif mode == "increament": db_files = kb.list_files() folder_files = list_files_from_folder(kb_name) files = list(set(folder_files) - set(db_files)) kb_files = file_to_kbfile(kb_name, files) - for success, result in files2docs_in_thread(kb_files, pool=pool): + for success, result in files2docs_in_thread(kb_files, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + zh_title_enhance=zh_title_enhance): if success: _, filename, docs = result print(f"正在将 {kb_name}/{filename} 添加到向量库") @@ -102,10 +103,7 @@ def folder2db( kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True) else: print(result) - - if kb.vs_type() == SupportedVSType.FAISS: - kb.save_vector_store() - kb.refresh_vs_cache() + kb.save_vector_store() else: print(f"unspported migrate mode: {mode}") @@ -135,9 +133,7 @@ def prune_db_files(kb_name: str): kb_files = file_to_kbfile(kb_name, files) for kb_file in kb_files: kb.delete_doc(kb_file, not_refresh_vs_cache=True) - if kb.vs_type() == SupportedVSType.FAISS: - kb.save_vector_store() - kb.refresh_vs_cache() + kb.save_vector_store() return kb_files def prune_folder_files(kb_name: str): diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index fe7fce1c..ab4c2f95 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -1,39 +1,33 @@ import os -from langchain.embeddings.huggingface import HuggingFaceEmbeddings -from langchain.embeddings.openai import OpenAIEmbeddings -from langchain.embeddings import HuggingFaceBgeEmbeddings + +from transformers import AutoTokenizer + from configs.model_config import ( - embedding_model_dict, + EMBEDDING_MODEL, KB_ROOT_PATH, CHUNK_SIZE, OVERLAP_SIZE, - ZH_TITLE_ENHANCE + ZH_TITLE_ENHANCE, + logger, + log_verbose, + text_splitter_dict, + llm_model_dict, + LLM_MODEL, + TEXT_SPLITTER ) -from functools import lru_cache import importlib from text_splitter import zh_title_enhance import langchain.document_loaders from langchain.docstore.document import Document +from langchain.text_splitter import TextSplitter from pathlib import Path import json -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import ThreadPoolExecutor +from server.utils import run_in_thread_pool, embedding_device +import io from typing import List, Union, Callable, Dict, Optional, Tuple, Generator -# make HuggingFaceEmbeddings hashable -def _embeddings_hash(self): - if isinstance(self, HuggingFaceEmbeddings): - return hash(self.model_name) - elif isinstance(self, HuggingFaceBgeEmbeddings): - return hash(self.model_name) - elif isinstance(self, OpenAIEmbeddings): - return hash(self.model) - -HuggingFaceEmbeddings.__hash__ = _embeddings_hash -OpenAIEmbeddings.__hash__ = _embeddings_hash -HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash - - def validate_kb_name(knowledge_base_id: str) -> bool: # 检查是否包含预期外的字符或路径攻击关键字 if "../" in knowledge_base_id: @@ -68,19 +62,12 @@ def list_files_from_folder(kb_name: str): if os.path.isfile(os.path.join(doc_path, file))] -@lru_cache(1) -def load_embeddings(model: str, device: str): - if model == "text-embedding-ada-002": # openai text-embedding-ada-002 - embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE) - elif 'bge-' in model: - embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model], - model_kwargs={'device': device}, - query_instruction="为这个句子生成表示以用于检索相关文章:") - if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding - embeddings.query_instruction = "" - else: - embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device}) - return embeddings +def load_embeddings(model: str = EMBEDDING_MODEL, device: str = embedding_device()): + ''' + 从缓存中加载embeddings,可以避免多线程时竞争加载。 + ''' + from server.knowledge_base.kb_cache.base import embeddings_pool + return embeddings_pool.load_embeddings(model=model, device=device) LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'], @@ -99,16 +86,16 @@ SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist] class CustomJSONLoader(langchain.document_loaders.JSONLoader): ''' - langchain的JSONLoader需要jq,在win上使用不便,进行替代。 + langchain的JSONLoader需要jq,在win上使用不便,进行替代。针对langchain==0.0.286 ''' def __init__( - self, - file_path: Union[str, Path], - content_key: Optional[str] = None, - metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None, - text_content: bool = True, - json_lines: bool = False, + self, + file_path: Union[str, Path], + content_key: Optional[str] = None, + metadata_func: Optional[Callable[[Dict, Dict], Dict]] = None, + text_content: bool = True, + json_lines: bool = False, ): """Initialize the JSONLoader. @@ -130,21 +117,6 @@ class CustomJSONLoader(langchain.document_loaders.JSONLoader): self._text_content = text_content self._json_lines = json_lines - # TODO: langchain's JSONLoader.load has a encoding bug, raise gbk encoding error on windows. - # This is a workaround for langchain==0.0.266. I have make a pr(#9785) to langchain, it should be deleted after langchain upgraded. - def load(self) -> List[Document]: - """Load and return documents from the JSON file.""" - docs: List[Document] = [] - if self._json_lines: - with self.file_path.open(encoding="utf-8") as f: - for line in f: - line = line.strip() - if line: - self._parse(line, docs) - else: - self._parse(self.file_path.read_text(encoding="utf-8"), docs) - return docs - def _parse(self, content: str, docs: List[Document]) -> None: """Convert given content to documents.""" data = json.loads(content) @@ -154,13 +126,14 @@ class CustomJSONLoader(langchain.document_loaders.JSONLoader): # and prevent the user from getting a cryptic error later on. if self._content_key is not None: self._validate_content_key(data) + if self._metadata_func is not None: + self._validate_metadata_func(data) for i, sample in enumerate(data, len(docs) + 1): - metadata = dict( - source=str(self.file_path), - seq_num=i, + text = self._get_text(sample=sample) + metadata = self._get_metadata( + sample=sample, source=str(self.file_path), seq_num=i ) - text = self._get_text(sample=sample, metadata=metadata) docs.append(Document(page_content=text, metadata=metadata)) @@ -173,12 +146,124 @@ def get_LoaderClass(file_extension): return LoaderClass +# 把一些向量化共用逻辑从KnowledgeFile抽取出来,等langchain支持内存文件的时候,可以将非磁盘文件向量化 +def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.StringIO, io.BytesIO]): + ''' + 根据loader_name和文件路径或内容返回文档加载器。 + ''' + try: + if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]: + document_loaders_module = importlib.import_module('document_loaders') + else: + document_loaders_module = importlib.import_module('langchain.document_loaders') + DocumentLoader = getattr(document_loaders_module, loader_name) + except Exception as e: + msg = f"为文件{file_path_or_content}查找加载器{loader_name}时出错:{e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) + document_loaders_module = importlib.import_module('langchain.document_loaders') + DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader") + + if loader_name == "UnstructuredFileLoader": + loader = DocumentLoader(file_path_or_content, autodetect_encoding=True) + elif loader_name == "CSVLoader": + loader = DocumentLoader(file_path_or_content, encoding="utf-8") + elif loader_name == "JSONLoader": + loader = DocumentLoader(file_path_or_content, jq_schema=".", text_content=False) + elif loader_name == "CustomJSONLoader": + loader = DocumentLoader(file_path_or_content, text_content=False) + elif loader_name == "UnstructuredMarkdownLoader": + loader = DocumentLoader(file_path_or_content, mode="elements") + elif loader_name == "UnstructuredHTMLLoader": + loader = DocumentLoader(file_path_or_content, mode="elements") + else: + loader = DocumentLoader(file_path_or_content) + return loader + + +def make_text_splitter( + splitter_name: str = TEXT_SPLITTER, + chunk_size: int = CHUNK_SIZE, + chunk_overlap: int = OVERLAP_SIZE, +): + """ + 根据参数获取特定的分词器 + """ + splitter_name = splitter_name or "SpacyTextSplitter" + try: + if splitter_name == "MarkdownHeaderTextSplitter": # MarkdownHeaderTextSplitter特殊判定 + headers_to_split_on = text_splitter_dict[splitter_name]['headers_to_split_on'] + text_splitter = langchain.text_splitter.MarkdownHeaderTextSplitter( + headers_to_split_on=headers_to_split_on) + else: + + try: ## 优先使用用户自定义的text_splitter + text_splitter_module = importlib.import_module('text_splitter') + TextSplitter = getattr(text_splitter_module, splitter_name) + except: ## 否则使用langchain的text_splitter + text_splitter_module = importlib.import_module('langchain.text_splitter') + TextSplitter = getattr(text_splitter_module, splitter_name) + + if text_splitter_dict[splitter_name]["source"] == "tiktoken": ## 从tiktoken加载 + try: + text_splitter = TextSplitter.from_tiktoken_encoder( + encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"], + pipeline="zh_core_web_sm", + chunk_size=chunk_size, + chunk_overlap=chunk_overlap + ) + except: + text_splitter = TextSplitter.from_tiktoken_encoder( + encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"], + chunk_size=chunk_size, + chunk_overlap=chunk_overlap + ) + elif text_splitter_dict[splitter_name]["source"] == "huggingface": ## 从huggingface加载 + if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "": + text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \ + llm_model_dict[LLM_MODEL]["local_model_path"] + + if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2": + from transformers import GPT2TokenizerFast + from langchain.text_splitter import CharacterTextSplitter + tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") + else: ## 字符长度加载 + tokenizer = AutoTokenizer.from_pretrained( + text_splitter_dict[splitter_name]["tokenizer_name_or_path"], + trust_remote_code=True) + text_splitter = TextSplitter.from_huggingface_tokenizer( + tokenizer=tokenizer, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap + ) + else: + try: + text_splitter = TextSplitter( + pipeline="zh_core_web_sm", + chunk_size=chunk_size, + chunk_overlap=chunk_overlap + ) + except: + text_splitter = TextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap + ) + except Exception as e: + print(e) + text_splitter_module = importlib.import_module('langchain.text_splitter') + TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter") + text_splitter = TextSplitter(chunk_size=250, chunk_overlap=50) + return text_splitter + class KnowledgeFile: def __init__( self, filename: str, knowledge_base_name: str ): + ''' + 对应知识库目录中的文件,必须是磁盘上存在的才能进行向量化等操作。 + ''' self.kb_name = knowledge_base_name self.filename = filename self.ext = os.path.splitext(filename)[-1].lower() @@ -186,75 +271,67 @@ class KnowledgeFile: raise ValueError(f"暂未支持的文件格式 {self.ext}") self.filepath = get_file_path(knowledge_base_name, filename) self.docs = None + self.splited_docs = None self.document_loader_name = get_LoaderClass(self.ext) + self.text_splitter_name = TEXT_SPLITTER - # TODO: 增加依据文件格式匹配text_splitter - self.text_splitter_name = None + def file2docs(self, refresh: bool=False): + if self.docs is None or refresh: + logger.info(f"{self.document_loader_name} used for {self.filepath}") + loader = get_loader(self.document_loader_name, self.filepath) + self.docs = loader.load() + return self.docs - def file2text(self, using_zh_title_enhance=ZH_TITLE_ENHANCE, refresh: bool = False): - if self.docs is not None and not refresh: - return self.docs - - print(f"{self.document_loader_name} used for {self.filepath}") - try: - if self.document_loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader"]: - document_loaders_module = importlib.import_module('document_loaders') + def docs2texts( + self, + docs: List[Document] = None, + zh_title_enhance: bool = ZH_TITLE_ENHANCE, + refresh: bool = False, + chunk_size: int = CHUNK_SIZE, + chunk_overlap: int = OVERLAP_SIZE, + text_splitter: TextSplitter = None, + ): + docs = docs or self.file2docs(refresh=refresh) + if not docs: + return [] + if self.ext not in [".csv"]: + if text_splitter is None: + text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap) + if self.text_splitter_name == "MarkdownHeaderTextSplitter": + docs = text_splitter.split_text(docs[0].page_content) + for doc in docs: + # 如果文档有元数据 + if doc.metadata: + doc.metadata["source"] = os.path.basename(self.filepath) else: - document_loaders_module = importlib.import_module('langchain.document_loaders') - DocumentLoader = getattr(document_loaders_module, self.document_loader_name) - except Exception as e: - print(e) - document_loaders_module = importlib.import_module('langchain.document_loaders') - DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader") - if self.document_loader_name == "UnstructuredFileLoader": - loader = DocumentLoader(self.filepath, autodetect_encoding=True) - elif self.document_loader_name == "CSVLoader": - loader = DocumentLoader(self.filepath, encoding="utf-8") - elif self.document_loader_name == "JSONLoader": - loader = DocumentLoader(self.filepath, jq_schema=".", text_content=False) - elif self.document_loader_name == "CustomJSONLoader": - loader = DocumentLoader(self.filepath, text_content=False) - elif self.document_loader_name == "UnstructuredMarkdownLoader": - loader = DocumentLoader(self.filepath, mode="elements") - elif self.document_loader_name == "UnstructuredHTMLLoader": - loader = DocumentLoader(self.filepath, mode="elements") - else: - loader = DocumentLoader(self.filepath) + docs = text_splitter.split_documents(docs) - if self.ext in ".csv": - docs = loader.load() - else: - try: - if self.text_splitter_name is None: - text_splitter_module = importlib.import_module('langchain.text_splitter') - TextSplitter = getattr(text_splitter_module, "SpacyTextSplitter") - text_splitter = TextSplitter( - pipeline="zh_core_web_sm", - chunk_size=CHUNK_SIZE, - chunk_overlap=OVERLAP_SIZE, - ) - self.text_splitter_name = "SpacyTextSplitter" - else: - text_splitter_module = importlib.import_module('langchain.text_splitter') - TextSplitter = getattr(text_splitter_module, self.text_splitter_name) - text_splitter = TextSplitter( - chunk_size=CHUNK_SIZE, - chunk_overlap=OVERLAP_SIZE) - except Exception as e: - print(e) - text_splitter_module = importlib.import_module('langchain.text_splitter') - TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter") - text_splitter = TextSplitter( - chunk_size=CHUNK_SIZE, - chunk_overlap=OVERLAP_SIZE, - ) - - docs = loader.load_and_split(text_splitter) - print(docs[0]) - if using_zh_title_enhance: + print(f"文档切分示例:{docs[0]}") + if zh_title_enhance: docs = zh_title_enhance(docs) - self.docs = docs - return docs + self.splited_docs = docs + return self.splited_docs + + def file2text( + self, + zh_title_enhance: bool = ZH_TITLE_ENHANCE, + refresh: bool = False, + chunk_size: int = CHUNK_SIZE, + chunk_overlap: int = OVERLAP_SIZE, + text_splitter: TextSplitter = None, + ): + if self.splited_docs is None or refresh: + docs = self.file2docs() + self.splited_docs = self.docs2texts(docs=docs, + zh_title_enhance=zh_title_enhance, + refresh=refresh, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + text_splitter=text_splitter) + return self.splited_docs + + def file_exist(self): + return os.path.isfile(self.filepath) def get_mtime(self): return os.path.getmtime(self.filepath) @@ -263,53 +340,54 @@ class KnowledgeFile: return os.path.getsize(self.filepath) -def run_in_thread_pool( - func: Callable, - params: List[Dict] = [], - pool: ThreadPoolExecutor = None, -) -> Generator: - ''' - 在线程池中批量运行任务,并将运行结果以生成器的形式返回。 - 请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。 - ''' - tasks = [] - if pool is None: - pool = ThreadPoolExecutor() - - for kwargs in params: - thread = pool.submit(func, **kwargs) - tasks.append(thread) - - for obj in as_completed(tasks): - yield obj.result() - - def files2docs_in_thread( - files: List[Union[KnowledgeFile, Tuple[str, str], Dict]], - pool: ThreadPoolExecutor = None, + files: List[Union[KnowledgeFile, Tuple[str, str], Dict]], + chunk_size: int = CHUNK_SIZE, + chunk_overlap: int = OVERLAP_SIZE, + zh_title_enhance: bool = ZH_TITLE_ENHANCE, + pool: ThreadPoolExecutor = None, ) -> Generator: ''' - 利用多线程批量将文件转化成langchain Document. - 生成器返回值为{(kb_name, file_name): docs} + 利用多线程批量将磁盘文件转化成langchain Document. + 如果传入参数是Tuple,形式为(filename, kb_name) + 生成器返回值为 status, (kb_name, file_name, docs | error) ''' - def task(*, file: KnowledgeFile, **kwargs) -> Dict[Tuple[str, str], List[Document]]: + def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]: try: return True, (file.kb_name, file.filename, file.file2text(**kwargs)) except Exception as e: - return False, e + msg = f"从文件 {file.kb_name}/{file.filename} 加载文档时出错:{e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) + return False, (file.kb_name, file.filename, msg) kwargs_list = [] for i, file in enumerate(files): kwargs = {} if isinstance(file, tuple) and len(file) >= 2: - files[i] = KnowledgeFile(filename=file[0], knowledge_base_name=file[1]) + file = KnowledgeFile(filename=file[0], knowledge_base_name=file[1]) elif isinstance(file, dict): filename = file.pop("filename") kb_name = file.pop("kb_name") - files[i] = KnowledgeFile(filename=filename, knowledge_base_name=kb_name) kwargs = file + file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name) kwargs["file"] = file + kwargs["chunk_size"] = chunk_size + kwargs["chunk_overlap"] = chunk_overlap + kwargs["zh_title_enhance"] = zh_title_enhance kwargs_list.append(kwargs) - - for result in run_in_thread_pool(func=task, params=kwargs_list, pool=pool): + + for result in run_in_thread_pool(func=file2docs, params=kwargs_list, pool=pool): yield result + + +if __name__ == "__main__": + from pprint import pprint + + kb_file = KnowledgeFile(filename="test.txt", knowledge_base_name="samples") + # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter" + docs = kb_file.file2docs() + pprint(docs[-1]) + + docs = kb_file.file2text() + pprint(docs[-1]) \ No newline at end of file diff --git a/server/llm_api.py b/server/llm_api.py index d9667e4f..5843e89a 100644 --- a/server/llm_api.py +++ b/server/llm_api.py @@ -1,279 +1,71 @@ -from multiprocessing import Process, Queue -import multiprocessing as mp -import sys -import os - -sys.path.append(os.path.dirname(os.path.dirname(__file__))) -from configs.model_config import llm_model_dict, LLM_MODEL, LOG_PATH, logger -from server.utils import MakeFastAPIOffline, set_httpx_timeout, llm_device +from fastapi import Body +from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT +from server.utils import BaseResponse, fschat_controller_address +import httpx -host_ip = "0.0.0.0" -controller_port = 20001 -model_worker_port = 20002 -openai_api_port = 8888 -base_url = "http://127.0.0.1:{}" +def list_llm_models( + controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]), + placeholder: str = Body(None, description="该参数未使用,占位用"), +) -> BaseResponse: + ''' + 从fastchat controller获取已加载模型列表 + ''' + try: + controller_address = controller_address or fschat_controller_address() + r = httpx.post(controller_address + "/list_models") + return BaseResponse(data=r.json()["models"]) + except Exception as e: + logger.error(f'{e.__class__.__name__}: {e}', + exc_info=e if log_verbose else None) + return BaseResponse( + code=500, + data=[], + msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}") -def create_controller_app( - dispatch_method="shortest_queue", +def stop_llm_model( + model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]), + controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]) +) -> BaseResponse: + ''' + 向fastchat controller请求停止某个LLM模型。 + 注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。 + ''' + try: + controller_address = controller_address or fschat_controller_address() + r = httpx.post( + controller_address + "/release_worker", + json={"model_name": model_name}, + ) + return r.json() + except Exception as e: + logger.error(f'{e.__class__.__name__}: {e}', + exc_info=e if log_verbose else None) + return BaseResponse( + code=500, + msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}") + + +def change_llm_model( + model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]), + new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]), + controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]) ): - import fastchat.constants - fastchat.constants.LOGDIR = LOG_PATH - from fastchat.serve.controller import app, Controller - - controller = Controller(dispatch_method) - sys.modules["fastchat.serve.controller"].controller = controller - - MakeFastAPIOffline(app) - app.title = "FastChat Controller" - return app - - -def create_model_worker_app( - worker_address=base_url.format(model_worker_port), - controller_address=base_url.format(controller_port), - model_path=llm_model_dict[LLM_MODEL].get("local_model_path"), - device=llm_device(), - gpus=None, - max_gpu_memory="20GiB", - load_8bit=False, - cpu_offloading=None, - gptq_ckpt=None, - gptq_wbits=16, - gptq_groupsize=-1, - gptq_act_order=False, - awq_ckpt=None, - awq_wbits=16, - awq_groupsize=-1, - model_names=[LLM_MODEL], - num_gpus=1, # not in fastchat - conv_template=None, - limit_worker_concurrency=5, - stream_interval=2, - no_register=False, -): - import fastchat.constants - fastchat.constants.LOGDIR = LOG_PATH - from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id - import argparse - import threading - import fastchat.serve.model_worker - - # workaround to make program exit with Ctrl+c - # it should be deleted after pr is merged by fastchat - def _new_init_heart_beat(self): - self.register_to_controller() - self.heart_beat_thread = threading.Thread( - target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True, + ''' + 向fastchat controller请求切换LLM模型。 + ''' + try: + controller_address = controller_address or fschat_controller_address() + r = httpx.post( + controller_address + "/release_worker", + json={"model_name": model_name, "new_model_name": new_model_name}, + timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model ) - self.heart_beat_thread.start() - ModelWorker.init_heart_beat = _new_init_heart_beat - - parser = argparse.ArgumentParser() - args = parser.parse_args() - args.model_path = model_path - args.model_names = model_names - args.device = device - args.load_8bit = load_8bit - args.gptq_ckpt = gptq_ckpt - args.gptq_wbits = gptq_wbits - args.gptq_groupsize = gptq_groupsize - args.gptq_act_order = gptq_act_order - args.awq_ckpt = awq_ckpt - args.awq_wbits = awq_wbits - args.awq_groupsize = awq_groupsize - args.gpus = gpus - args.num_gpus = num_gpus - args.max_gpu_memory = max_gpu_memory - args.cpu_offloading = cpu_offloading - args.worker_address = worker_address - args.controller_address = controller_address - args.conv_template = conv_template - args.limit_worker_concurrency = limit_worker_concurrency - args.stream_interval = stream_interval - args.no_register = no_register - - if args.gpus: - if len(args.gpus.split(",")) < args.num_gpus: - raise ValueError( - f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" - ) - os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus - - if gpus and num_gpus is None: - num_gpus = len(gpus.split(',')) - args.num_gpus = num_gpus - - gptq_config = GptqConfig( - ckpt=gptq_ckpt or model_path, - wbits=args.gptq_wbits, - groupsize=args.gptq_groupsize, - act_order=args.gptq_act_order, - ) - awq_config = AWQConfig( - ckpt=args.awq_ckpt or args.model_path, - wbits=args.awq_wbits, - groupsize=args.awq_groupsize, - ) - - # torch.multiprocessing.set_start_method('spawn') - worker = ModelWorker( - controller_addr=args.controller_address, - worker_addr=args.worker_address, - worker_id=worker_id, - model_path=args.model_path, - model_names=args.model_names, - limit_worker_concurrency=args.limit_worker_concurrency, - no_register=args.no_register, - device=args.device, - num_gpus=args.num_gpus, - max_gpu_memory=args.max_gpu_memory, - load_8bit=args.load_8bit, - cpu_offloading=args.cpu_offloading, - gptq_config=gptq_config, - awq_config=awq_config, - stream_interval=args.stream_interval, - conv_template=args.conv_template, - ) - - sys.modules["fastchat.serve.model_worker"].worker = worker - sys.modules["fastchat.serve.model_worker"].args = args - sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config - - MakeFastAPIOffline(app) - app.title = f"FastChat LLM Server ({LLM_MODEL})" - return app - - -def create_openai_api_app( - controller_address=base_url.format(controller_port), - api_keys=[], -): - import fastchat.constants - fastchat.constants.LOGDIR = LOG_PATH - from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings - - app.add_middleware( - CORSMiddleware, - allow_credentials=True, - allow_origins=["*"], - allow_methods=["*"], - allow_headers=["*"], - ) - - app_settings.controller_address = controller_address - app_settings.api_keys = api_keys - - MakeFastAPIOffline(app) - app.title = "FastChat OpeanAI API Server" - return app - - -def run_controller(q): - import uvicorn - app = create_controller_app() - - @app.on_event("startup") - async def on_startup(): - set_httpx_timeout() - q.put(1) - - uvicorn.run(app, host=host_ip, port=controller_port) - - -def run_model_worker(q, *args, **kwargs): - import uvicorn - app = create_model_worker_app(*args, **kwargs) - - @app.on_event("startup") - async def on_startup(): - set_httpx_timeout() - while True: - no = q.get() - if no != 1: - q.put(no) - else: - break - q.put(2) - - uvicorn.run(app, host=host_ip, port=model_worker_port) - - -def run_openai_api(q): - import uvicorn - app = create_openai_api_app() - - @app.on_event("startup") - async def on_startup(): - set_httpx_timeout() - while True: - no = q.get() - if no != 2: - q.put(no) - else: - break - q.put(3) - - uvicorn.run(app, host=host_ip, port=openai_api_port) - - -if __name__ == "__main__": - mp.set_start_method("spawn") - queue = Queue() - logger.info(llm_model_dict[LLM_MODEL]) - model_path = llm_model_dict[LLM_MODEL]["local_model_path"] - - logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}") - - if not model_path: - logger.error("local_model_path 不能为空") - else: - controller_process = Process( - target=run_controller, - name=f"controller({os.getpid()})", - args=(queue,), - daemon=True, - ) - controller_process.start() - - model_worker_process = Process( - target=run_model_worker, - name=f"model_worker({os.getpid()})", - args=(queue,), - # kwargs={"load_8bit": True}, - daemon=True, - ) - model_worker_process.start() - - openai_api_process = Process( - target=run_openai_api, - name=f"openai_api({os.getpid()})", - args=(queue,), - daemon=True, - ) - openai_api_process.start() - - try: - model_worker_process.join() - controller_process.join() - openai_api_process.join() - except KeyboardInterrupt: - model_worker_process.terminate() - controller_process.terminate() - openai_api_process.terminate() - -# 服务启动后接口调用示例: -# import openai -# openai.api_key = "EMPTY" # Not support yet -# openai.api_base = "http://localhost:8888/v1" - -# model = "chatglm2-6b" - -# # create a chat completion -# completion = openai.ChatCompletion.create( -# model=model, -# messages=[{"role": "user", "content": "Hello! What is your name?"}] -# ) -# # print the completion -# print(completion.choices[0].message.content) + return r.json() + except Exception as e: + logger.error(f'{e.__class__.__name__}: {e}', + exc_info=e if log_verbose else None) + return BaseResponse( + code=500, + msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}") diff --git a/server/model_workers/SparkApi.py b/server/model_workers/SparkApi.py new file mode 100644 index 00000000..e1dce6a0 --- /dev/null +++ b/server/model_workers/SparkApi.py @@ -0,0 +1,79 @@ +import base64 +import datetime +import hashlib +import hmac +from urllib.parse import urlparse +from datetime import datetime +from time import mktime +from urllib.parse import urlencode +from wsgiref.handlers import format_date_time + + +class Ws_Param(object): + # 初始化 + def __init__(self, APPID, APIKey, APISecret, Spark_url): + self.APPID = APPID + self.APIKey = APIKey + self.APISecret = APISecret + self.host = urlparse(Spark_url).netloc + self.path = urlparse(Spark_url).path + self.Spark_url = Spark_url + + # 生成url + def create_url(self): + # 生成RFC1123格式的时间戳 + now = datetime.now() + date = format_date_time(mktime(now.timetuple())) + + # 拼接字符串 + signature_origin = "host: " + self.host + "\n" + signature_origin += "date: " + date + "\n" + signature_origin += "GET " + self.path + " HTTP/1.1" + + # 进行hmac-sha256进行加密 + signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'), + digestmod=hashlib.sha256).digest() + + signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8') + + authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"' + + authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') + + # 将请求的鉴权参数组合为字典 + v = { + "authorization": authorization, + "date": date, + "host": self.host + } + # 拼接鉴权参数,生成url + url = self.Spark_url + '?' + urlencode(v) + # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 + return url + + +def gen_params(appid, domain,question, temperature): + """ + 通过appid和用户的提问来生成请参数 + """ + data = { + "header": { + "app_id": appid, + "uid": "1234" + }, + "parameter": { + "chat": { + "domain": domain, + "random_threshold": 0.5, + "max_tokens": 2048, + "auditing": "default", + "temperature": temperature, + } + }, + "payload": { + "message": { + "text": question + } + } + } + return data diff --git a/server/model_workers/__init__.py b/server/model_workers/__init__.py index 932c2f3b..a3a162f3 100644 --- a/server/model_workers/__init__.py +++ b/server/model_workers/__init__.py @@ -1 +1,4 @@ from .zhipu import ChatGLMWorker +from .minimax import MiniMaxWorker +from .xinghuo import XingHuoWorker +from .qianfan import QianFanWorker diff --git a/server/model_workers/base.py b/server/model_workers/base.py index b72f6839..df5fbfcc 100644 --- a/server/model_workers/base.py +++ b/server/model_workers/base.py @@ -69,3 +69,28 @@ class ApiModelWorker(BaseModelWorker): target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True, ) self.heart_beat_thread.start() + + # help methods + def get_config(self): + from server.utils import get_model_worker_config + return get_model_worker_config(self.model_names[0]) + + def prompt_to_messages(self, prompt: str) -> List[Dict]: + ''' + 将prompt字符串拆分成messages. + ''' + result = [] + user_role = self.conv.roles[0] + ai_role = self.conv.roles[1] + user_start = user_role + ":" + ai_start = ai_role + ":" + for msg in prompt.split(self.conv.sep)[1:-1]: + if msg.startswith(user_start): + if content := msg[len(user_start):].strip(): + result.append({"role": user_role, "content": content}) + elif msg.startswith(ai_start): + if content := msg[len(ai_start):].strip(): + result.append({"role": ai_role, "content": content}) + else: + raise RuntimeError(f"unknow role in msg: {msg}") + return result diff --git a/server/model_workers/minimax.py b/server/model_workers/minimax.py new file mode 100644 index 00000000..c772c0dc --- /dev/null +++ b/server/model_workers/minimax.py @@ -0,0 +1,100 @@ +from server.model_workers.base import ApiModelWorker +from fastchat import conversation as conv +import sys +import json +import httpx +from pprint import pprint +from typing import List, Dict + + +class MiniMaxWorker(ApiModelWorker): + BASE_URL = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}' + + def __init__( + self, + *, + model_names: List[str] = ["minimax-api"], + controller_addr: str, + worker_addr: str, + **kwargs, + ): + kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) + kwargs.setdefault("context_len", 16384) + super().__init__(**kwargs) + + # TODO: 确认模板是否需要修改 + self.conv = conv.Conversation( + name=self.model_names[0], + system_message="", + messages=[], + roles=["USER", "BOT"], + sep="\n### ", + stop_str="###", + ) + + def prompt_to_messages(self, prompt: str) -> List[Dict]: + result = super().prompt_to_messages(prompt) + messages = [{"sender_type": x["role"], "text": x["content"]} for x in result] + return messages + + def generate_stream_gate(self, params): + # 按照官网推荐,直接调用abab 5.5模型 + # TODO: 支持指定回复要求,支持指定用户名称、AI名称 + + super().generate_stream_gate(params) + config = self.get_config() + group_id = config.get("group_id") + api_key = config.get("api_key") + + pro = "_pro" if config.get("is_pro") else "" + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + data = { + "model": "abab5.5-chat", + "stream": True, + "tokens_to_generate": 1024, # TODO: 1024为官网默认值 + "mask_sensitive_info": True, + "messages": self.prompt_to_messages(params["prompt"]), + "temperature": params.get("temperature"), + "top_p": params.get("top_p"), + "bot_setting": [], + } + print("request data sent to minimax:") + pprint(data) + response = httpx.stream("POST", + self.BASE_URL.format(pro=pro, group_id=group_id), + headers=headers, + json=data) + with response as r: + text = "" + for e in r.iter_text(): + if e.startswith("data: "): # 真是优秀的返回 + data = json.loads(e[6:]) + if not data.get("usage"): + if choices := data.get("choices"): + chunk = choices[0].get("delta", "").strip() + if chunk: + print(chunk) + text += chunk + yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0" + + def get_embeddings(self, params): + # TODO: 支持embeddings + print("embedding") + print(params) + + +if __name__ == "__main__": + import uvicorn + from server.utils import MakeFastAPIOffline + from fastchat.serve.model_worker import app + + worker = MiniMaxWorker( + controller_addr="http://127.0.0.1:20001", + worker_addr="http://127.0.0.1:20004", + ) + sys.modules["fastchat.serve.model_worker"].worker = worker + MakeFastAPIOffline(app) + uvicorn.run(app, port=20003) diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py new file mode 100644 index 00000000..8a593a7e --- /dev/null +++ b/server/model_workers/qianfan.py @@ -0,0 +1,172 @@ +from server.model_workers.base import ApiModelWorker +from configs.model_config import TEMPERATURE +from fastchat import conversation as conv +import sys +import json +import httpx +from cachetools import cached, TTLCache +from server.utils import get_model_worker_config +from typing import List, Literal, Dict + + +MODEL_VERSIONS = { + "ernie-bot": "completions", + "ernie-bot-turbo": "eb-instant", + "bloomz-7b": "bloomz_7b1", + "qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed", + "llama2-7b-chat": "llama_2_7b", + "llama2-13b-chat": "llama_2_13b", + "llama2-70b-chat": "llama_2_70b", + "qianfan-llama2-ch-7b": "qianfan_chinese_llama_2_7b", + "chatglm2-6b-32k": "chatglm2_6b_32k", + "aquilachat-7b": "aquilachat_7b", + # "linly-llama2-ch-7b": "", # 暂未发布 + # "linly-llama2-ch-13b": "", # 暂未发布 + # "chatglm2-6b": "", # 暂未发布 + # "chatglm2-6b-int4": "", # 暂未发布 + # "falcon-7b": "", # 暂未发布 + # "falcon-180b-chat": "", # 暂未发布 + # "falcon-40b": "", # 暂未发布 + # "rwkv4-world": "", # 暂未发布 + # "rwkv5-world": "", # 暂未发布 + # "rwkv4-pile-14b": "", # 暂未发布 + # "rwkv4-raven-14b": "", # 暂未发布 + # "open-llama-7b": "", # 暂未发布 + # "dolly-12b": "", # 暂未发布 + # "mpt-7b-instruct": "", # 暂未发布 + # "mpt-30b-instruct": "", # 暂未发布 + # "OA-Pythia-12B-SFT-4": "", # 暂未发布 + # "xverse-13b": "", # 暂未发布 + + # # 以下为企业测试,需要单独申请 + # "flan-ul2": "", + # "Cerebras-GPT-6.7B": "" + # "Pythia-6.9B": "" +} + + +@cached(TTLCache(1, 1800)) # 经过测试,缓存的token可以使用,目前每30分钟刷新一次 +def get_baidu_access_token(api_key: str, secret_key: str) -> str: + """ + 使用 AK,SK 生成鉴权签名(Access Token) + :return: access_token,或是None(如果错误) + """ + url = "https://aip.baidubce.com/oauth/2.0/token" + params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key} + try: + return httpx.get(url, params=params).json().get("access_token") + except Exception as e: + print(f"failed to get token from baidu: {e}") + + +def request_qianfan_api( + messages: List[Dict[str, str]], + temperature: float = TEMPERATURE, + model_name: str = "qianfan-api", + version: str = None, +) -> Dict: + BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat'\ + '/{model_version}?access_token={access_token}' + config = get_model_worker_config(model_name) + version = version or config.get("version") + version_url = config.get("version_url") + access_token = get_baidu_access_token(config.get("api_key"), config.get("secret_key")) + if not access_token: + raise RuntimeError(f"failed to get access token. have you set the correct api_key and secret key?") + + url = BASE_URL.format( + model_version=version_url or MODEL_VERSIONS[version], + access_token=access_token, + ) + payload = { + "messages": messages, + "temperature": temperature, + "stream": True + } + headers = { + 'Content-Type': 'application/json', + 'Accept': 'application/json', + } + + with httpx.stream("POST", url, headers=headers, json=payload) as response: + for line in response.iter_lines(): + if not line.strip(): + continue + if line.startswith("data: "): + line = line[6:] + resp = json.loads(line) + yield resp + + +class QianFanWorker(ApiModelWorker): + """ + 百度千帆 + """ + def __init__( + self, + *, + version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot", + model_names: List[str] = ["ernie-api"], + controller_addr: str, + worker_addr: str, + **kwargs, + ): + kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) + kwargs.setdefault("context_len", 16384) + super().__init__(**kwargs) + + # TODO: 确认模板是否需要修改 + self.conv = conv.Conversation( + name=self.model_names[0], + system_message="", + messages=[], + roles=["user", "assistant"], + sep="\n### ", + stop_str="###", + ) + + config = self.get_config() + self.version = version + self.api_key = config.get("api_key") + self.secret_key = config.get("secret_key") + + def generate_stream_gate(self, params): + messages = self.prompt_to_messages(params["prompt"]) + text="" + for resp in request_qianfan_api(messages, + temperature=params.get("temperature"), + model_name=self.model_names[0]): + if "result" in resp.keys(): + text += resp["result"] + yield json.dumps({ + "error_code": 0, + "text": text + }, + ensure_ascii=False + ).encode() + b"\0" + else: + yield json.dumps({ + "error_code": resp["error_code"], + "text": resp["error_msg"] + }, + ensure_ascii=False + ).encode() + b"\0" + + def get_embeddings(self, params): + # TODO: 支持embeddings + print("embedding") + print(params) + + +if __name__ == "__main__": + import uvicorn + from server.utils import MakeFastAPIOffline + from fastchat.serve.model_worker import app + + worker = QianFanWorker( + controller_addr="http://127.0.0.1:20001", + worker_addr="http://127.0.0.1:20006", + ) + sys.modules["fastchat.serve.model_worker"].worker = worker + MakeFastAPIOffline(app) + uvicorn.run(app, port=20006) \ No newline at end of file diff --git a/server/model_workers/xinghuo.py b/server/model_workers/xinghuo.py new file mode 100644 index 00000000..499e8bc1 --- /dev/null +++ b/server/model_workers/xinghuo.py @@ -0,0 +1,101 @@ +from server.model_workers.base import ApiModelWorker +from fastchat import conversation as conv +import sys +import json +from server.model_workers import SparkApi +import websockets +from server.utils import iter_over_async, asyncio +from typing import List + + +async def request(appid, api_key, api_secret, Spark_url,domain, question, temperature): + # print("星火:") + wsParam = SparkApi.Ws_Param(appid, api_key, api_secret, Spark_url) + wsUrl = wsParam.create_url() + data = SparkApi.gen_params(appid, domain, question, temperature) + async with websockets.connect(wsUrl) as ws: + await ws.send(json.dumps(data, ensure_ascii=False)) + finish = False + while not finish: + chunk = await ws.recv() + response = json.loads(chunk) + if response.get("header", {}).get("status") == 2: + finish = True + if text := response.get("payload", {}).get("choices", {}).get("text"): + yield text[0]["content"] + + +class XingHuoWorker(ApiModelWorker): + def __init__( + self, + *, + model_names: List[str] = ["xinghuo-api"], + controller_addr: str, + worker_addr: str, + **kwargs, + ): + kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) + kwargs.setdefault("context_len", 8192) + super().__init__(**kwargs) + + # TODO: 确认模板是否需要修改 + self.conv = conv.Conversation( + name=self.model_names[0], + system_message="", + messages=[], + roles=["user", "assistant"], + sep="\n### ", + stop_str="###", + ) + + def generate_stream_gate(self, params): + # TODO: 当前每次对话都要重新连接websocket,确认是否可以保持连接 + + super().generate_stream_gate(params) + config = self.get_config() + appid = config.get("APPID") + api_secret = config.get("APISecret") + api_key = config.get("api_key") + + if config.get("is_v2"): + domain = "generalv2" # v2.0版本 + Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" # v2.0环境的地址 + else: + domain = "general" # v1.5版本 + Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat" # v1.5环境的地址 + + question = self.prompt_to_messages(params["prompt"]) + text = "" + + try: + loop = asyncio.get_event_loop() + except: + loop = asyncio.new_event_loop() + + for chunk in iter_over_async( + request(appid, api_key, api_secret, Spark_url, domain, question, params.get("temperature")), + loop=loop, + ): + if chunk: + print(chunk) + text += chunk + yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0" + + def get_embeddings(self, params): + # TODO: 支持embeddings + print("embedding") + print(params) + + +if __name__ == "__main__": + import uvicorn + from server.utils import MakeFastAPIOffline + from fastchat.serve.model_worker import app + + worker = XingHuoWorker( + controller_addr="http://127.0.0.1:20001", + worker_addr="http://127.0.0.1:20005", + ) + sys.modules["fastchat.serve.model_worker"].worker = worker + MakeFastAPIOffline(app) + uvicorn.run(app, port=20005) diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py index 4e4e15e0..f835ac06 100644 --- a/server/model_workers/zhipu.py +++ b/server/model_workers/zhipu.py @@ -1,4 +1,3 @@ -import zhipuai from server.model_workers.base import ApiModelWorker from fastchat import conversation as conv import sys @@ -13,7 +12,7 @@ class ChatGLMWorker(ApiModelWorker): def __init__( self, *, - model_names: List[str] = ["chatglm-api"], + model_names: List[str] = ["zhipu-api"], version: Literal["chatglm_pro", "chatglm_std", "chatglm_lite"] = "chatglm_std", controller_addr: str, worker_addr: str, @@ -26,7 +25,7 @@ class ChatGLMWorker(ApiModelWorker): # 这里的是chatglm api的模板,其它API的conv_template需要定制 self.conv = conv.Conversation( - name="chatglm-api", + name=self.model_names[0], system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。", messages=[], roles=["Human", "Assistant"], @@ -35,11 +34,11 @@ class ChatGLMWorker(ApiModelWorker): ) def generate_stream_gate(self, params): - # TODO: 支持stream参数,维护request_id,传过来的prompt也有问题 - from server.utils import get_model_worker_config + # TODO: 维护request_id + import zhipuai super().generate_stream_gate(params) - zhipuai.api_key = get_model_worker_config("chatglm-api").get("api_key") + zhipuai.api_key = self.get_config().get("api_key") response = zhipuai.model_api.sse_invoke( model=self.version, diff --git a/server/utils.py b/server/utils.py index 0e53e3df..48e74353 100644 --- a/server/utils.py +++ b/server/utils.py @@ -4,11 +4,14 @@ from typing import List from fastapi import FastAPI from pathlib import Path import asyncio -from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE +from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE, logger, log_verbose from configs.server_config import FSCHAT_MODEL_WORKERS import os -from server import model_workers -from typing import Literal, Optional, Any +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Literal, Optional, Callable, Generator, Dict, Any + + +thread_pool = ThreadPoolExecutor(os.cpu_count()) class BaseResponse(BaseModel): @@ -82,9 +85,10 @@ def torch_gc(): from torch.mps import empty_cache empty_cache() except Exception as e: - print(e) - print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。") - + msg=("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本," + "以支持及时清理 torch 产生的内存占用。") + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) def run_async(cor): ''' @@ -200,6 +204,7 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict: 优先级:FSCHAT_MODEL_WORKERS[model_name] > llm_model_dict[model_name] > FSCHAT_MODEL_WORKERS["default"] ''' from configs.server_config import FSCHAT_MODEL_WORKERS + from server import model_workers from configs.model_config import llm_model_dict config = FSCHAT_MODEL_WORKERS.get("default", {}).copy() @@ -213,7 +218,9 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict: try: config["worker_class"] = getattr(model_workers, provider) except Exception as e: - print(f"在线模型 ‘{model_name}’ 的provider没有正确配置") + msg = f"在线模型 ‘{model_name}’ 的provider没有正确配置" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) config["device"] = llm_device(config.get("device") or LLM_DEVICE) return config @@ -305,3 +312,24 @@ def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", " if device not in ["cuda", "mps", "cpu"]: device = detect_device() return device + + +def run_in_thread_pool( + func: Callable, + params: List[Dict] = [], + pool: ThreadPoolExecutor = None, +) -> Generator: + ''' + 在线程池中批量运行任务,并将运行结果以生成器的形式返回。 + 请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。 + ''' + tasks = [] + pool = pool or thread_pool + + for kwargs in params: + thread = pool.submit(func, **kwargs) + tasks.append(thread) + + for obj in as_completed(tasks): + yield obj.result() + diff --git a/startup.py b/startup.py index 591e3b10..b3094445 100644 --- a/startup.py +++ b/startup.py @@ -3,7 +3,8 @@ import multiprocessing as mp import os import subprocess import sys -from multiprocessing import Process, Queue +from multiprocessing import Process +from datetime import datetime from pprint import pprint # 设置numexpr最大线程数,默认为CPU核心数 @@ -17,9 +18,9 @@ except: sys.path.append(os.path.dirname(os.path.dirname(__file__))) from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \ - logger + logger, log_verbose, TEXT_SPLITTER from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER, - FSCHAT_OPENAI_API, ) + FSCHAT_OPENAI_API, HTTPX_DEFAULT_TIMEOUT) from server.utils import (fschat_controller_address, fschat_model_worker_address, fschat_openai_api_address, set_httpx_timeout, get_model_worker_config, get_all_model_worker_configs, @@ -47,7 +48,7 @@ def create_controller_app( return app -def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse.ArgumentParser, FastAPI]: +def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI: import fastchat.constants fastchat.constants.LOGDIR = LOG_PATH from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger @@ -87,6 +88,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse args.limit_worker_concurrency = 5 args.stream_interval = 2 args.no_register = False + args.embed_in_truncate = False for k, v in kwargs.items(): setattr(args, k, v) @@ -147,6 +149,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> Tuple[argparse awq_config=awq_config, stream_interval=args.stream_interval, conv_template=args.conv_template, + embed_in_truncate=args.embed_in_truncate, ) sys.modules["fastchat.serve.model_worker"].args = args sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config @@ -188,29 +191,15 @@ def create_openai_api_app( return app -def _set_app_seq(app: FastAPI, q: Queue, run_seq: int): - if q is None or not isinstance(run_seq, int): - return - - if run_seq == 1: - @app.on_event("startup") - async def on_startup(): - set_httpx_timeout() - q.put(run_seq) - elif run_seq > 1: - @app.on_event("startup") - async def on_startup(): - set_httpx_timeout() - while True: - no = q.get() - if no != run_seq - 1: - q.put(no) - else: - break - q.put(run_seq) +def _set_app_event(app: FastAPI, started_event: mp.Event = None): + @app.on_event("startup") + async def on_startup(): + set_httpx_timeout() + if started_event is not None: + started_event.set() -def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Event = None): +def run_controller(log_level: str = "INFO", started_event: mp.Event = None): import uvicorn import httpx from fastapi import Body @@ -221,12 +210,7 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Ev dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"), log_level=log_level, ) - _set_app_seq(app, q, run_seq) - - @app.on_event("startup") - def on_startup(): - if e is not None: - e.set() + _set_app_event(app, started_event) # add interface to release and load model worker @app.post("/release_worker") @@ -266,7 +250,7 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Ev return {"code": 500, "msg": msg} if new_model_name: - timer = 300 # wait 5 minutes for new model_worker register + timer = HTTPX_DEFAULT_TIMEOUT * 2 # wait for new model_worker register while timer > 0: models = app._controller.list_models() if new_model_name in models: @@ -299,9 +283,9 @@ def run_controller(q: Queue, run_seq: int = 1, log_level: str = "INFO", e: mp.Ev def run_model_worker( model_name: str = LLM_MODEL, controller_address: str = "", - q: Queue = None, - run_seq: int = 2, log_level: str = "INFO", + q: mp.Queue = None, + started_event: mp.Event = None, ): import uvicorn from fastapi import Body @@ -317,7 +301,7 @@ def run_model_worker( kwargs["model_path"] = model_path app = create_model_worker_app(log_level=log_level, **kwargs) - _set_app_seq(app, q, run_seq) + _set_app_event(app, started_event) if log_level == "ERROR": sys.stdout = sys.__stdout__ sys.stderr = sys.__stderr__ @@ -325,29 +309,29 @@ def run_model_worker( # add interface to release and load model @app.post("/release") def release_model( - new_model_name: str = Body(None, description="释放后加载该模型"), - keep_origin: bool = Body(False, description="不释放原模型,加载新模型") + new_model_name: str = Body(None, description="释放后加载该模型"), + keep_origin: bool = Body(False, description="不释放原模型,加载新模型") ) -> Dict: if keep_origin: if new_model_name: - q.put(["start", new_model_name]) + q.put([model_name, "start", new_model_name]) else: if new_model_name: - q.put(["replace", new_model_name]) + q.put([model_name, "replace", new_model_name]) else: - q.put(["stop"]) + q.put([model_name, "stop", None]) return {"code": 200, "msg": "done"} uvicorn.run(app, host=host, port=port, log_level=log_level.lower()) -def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"): +def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None): import uvicorn import sys controller_addr = fschat_controller_address() app = create_openai_api_app(controller_addr, log_level=log_level) # TODO: not support keys yet. - _set_app_seq(app, q, run_seq) + _set_app_event(app, started_event) host = FSCHAT_OPENAI_API["host"] port = FSCHAT_OPENAI_API["port"] @@ -357,12 +341,12 @@ def run_openai_api(q: Queue, run_seq: int = 3, log_level: str = "INFO"): uvicorn.run(app, host=host, port=port) -def run_api_server(q: Queue, run_seq: int = 4): +def run_api_server(started_event: mp.Event = None): from server.api import create_app import uvicorn app = create_app() - _set_app_seq(app, q, run_seq) + _set_app_event(app, started_event) host = API_SERVER["host"] port = API_SERVER["port"] @@ -370,21 +354,19 @@ def run_api_server(q: Queue, run_seq: int = 4): uvicorn.run(app, host=host, port=port) -def run_webui(q: Queue, run_seq: int = 5): +def run_webui(started_event: mp.Event = None): host = WEBUI_SERVER["host"] port = WEBUI_SERVER["port"] - if q is not None and isinstance(run_seq, int): - while True: - no = q.get() - if no != run_seq - 1: - q.put(no) - else: - break - q.put(run_seq) p = subprocess.Popen(["streamlit", "run", "webui.py", "--server.address", host, - "--server.port", str(port)]) + "--server.port", str(port), + "--theme.base", "light", + "--theme.primaryColor", "#165dff", + "--theme.secondaryBackgroundColor", "#f5f5f5", + "--theme.textColor", "#000000", + ]) + started_event.set() p.wait() @@ -427,8 +409,9 @@ def parse_args() -> argparse.ArgumentParser: "-n", "--model-name", type=str, - default=LLM_MODEL, - help="specify model name for model worker.", + nargs="+", + default=[LLM_MODEL], + help="specify model name for model worker. add addition names with space seperated to start multiple model workers.", dest="model_name", ) parser.add_argument( @@ -483,11 +466,15 @@ def dump_server_info(after_start=False, args=None): print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}") print("\n") - model = LLM_MODEL + models = [LLM_MODEL] if args and args.model_name: - model = args.model_name - print(f"当前LLM模型:{model} @ {llm_device()}") - pprint(llm_model_dict[model]) + models = args.model_name + + print(f"当前使用的分词器:{TEXT_SPLITTER}") + print(f"当前启动的LLM模型:{models} @ {llm_device()}") + + for model in models: + pprint(llm_model_dict[model]) print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}") if after_start: @@ -554,12 +541,12 @@ async def start_main_server(): logger.info(f"正在启动服务:") logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}") - processes = {"online-api": []} + processes = {"online_api": {}, "model_worker": {}} def process_count(): - return len(processes) + len(processes["online-api"]) - 1 + return len(processes) + len(processes["online_api"]) + len(processes["model_worker"]) - 2 - if args.quiet: + if args.quiet or not log_verbose: log_level = "ERROR" else: log_level = "INFO" @@ -569,63 +556,73 @@ async def start_main_server(): process = Process( target=run_controller, name=f"controller", - args=(queue, process_count() + 1, log_level, controller_started), + kwargs=dict(log_level=log_level, started_event=controller_started), daemon=True, ) - processes["controller"] = process process = Process( target=run_openai_api, name=f"openai_api", - args=(queue, process_count() + 1), daemon=True, ) processes["openai_api"] = process + model_worker_started = [] if args.model_worker: - config = get_model_worker_config(args.model_name) - if not config.get("online_api"): - process = Process( - target=run_model_worker, - name=f"model_worker - {args.model_name}", - args=(args.model_name, args.controller_address, queue, process_count() + 1, log_level), - daemon=True, - ) - - processes["model_worker"] = process + for model_name in args.model_name: + config = get_model_worker_config(model_name) + if not config.get("online_api"): + e = manager.Event() + model_worker_started.append(e) + process = Process( + target=run_model_worker, + name=f"model_worker - {model_name}", + kwargs=dict(model_name=model_name, + controller_address=args.controller_address, + log_level=log_level, + q=queue, + started_event=e), + daemon=True, + ) + processes["model_worker"][model_name] = process if args.api_worker: configs = get_all_model_worker_configs() for model_name, config in configs.items(): if config.get("online_api") and config.get("worker_class"): + e = manager.Event() + model_worker_started.append(e) process = Process( target=run_model_worker, - name=f"model_worker - {model_name}", - args=(model_name, args.controller_address, queue, process_count() + 1, log_level), + name=f"api_worker - {model_name}", + kwargs=dict(model_name=model_name, + controller_address=args.controller_address, + log_level=log_level, + q=queue, + started_event=e), daemon=True, ) + processes["online_api"][model_name] = process - processes["online-api"].append(process) - + api_started = manager.Event() if args.api: process = Process( target=run_api_server, name=f"API Server", - args=(queue, process_count() + 1), + kwargs=dict(started_event=api_started), daemon=True, ) - processes["api"] = process + webui_started = manager.Event() if args.webui: process = Process( target=run_webui, name=f"WEBUI Server", - args=(queue, process_count() + 1), + kwargs=dict(started_event=webui_started), daemon=True, ) - processes["webui"] = process if process_count() == 0: @@ -636,60 +633,106 @@ async def start_main_server(): if p:= processes.get("controller"): p.start() p.name = f"{p.name} ({p.pid})" - controller_started.wait() + controller_started.wait() # 等待controller启动完成 if p:= processes.get("openai_api"): p.start() p.name = f"{p.name} ({p.pid})" - if p:= processes.get("model_worker"): + for n, p in processes.get("model_worker", {}).items(): p.start() p.name = f"{p.name} ({p.pid})" - for p in processes.get("online-api", []): + for n, p in processes.get("online_api", []).items(): p.start() p.name = f"{p.name} ({p.pid})" + # 等待所有model_worker启动完成 + for e in model_worker_started: + e.wait() + if p:= processes.get("api"): p.start() p.name = f"{p.name} ({p.pid})" + api_started.wait() # 等待api.py启动完成 if p:= processes.get("webui"): p.start() p.name = f"{p.name} ({p.pid})" + webui_started.wait() # 等待webui.py启动完成 + + dump_server_info(after_start=True, args=args) while True: - no = queue.get() - if no == process_count(): - time.sleep(0.5) - dump_server_info(after_start=True, args=args) - break - else: - queue.put(no) + cmd = queue.get() # 收到切换模型的消息 + e = manager.Event() + if isinstance(cmd, list): + model_name, cmd, new_model_name = cmd + if cmd == "start": # 运行新模型 + logger.info(f"准备启动新模型进程:{new_model_name}") + process = Process( + target=run_model_worker, + name=f"model_worker - {new_model_name}", + kwargs=dict(model_name=new_model_name, + controller_address=args.controller_address, + log_level=log_level, + q=queue, + started_event=e), + daemon=True, + ) + process.start() + process.name = f"{process.name} ({process.pid})" + processes["model_worker"][new_model_name] = process + e.wait() + logger.info(f"成功启动新模型进程:{new_model_name}") + elif cmd == "stop": + if process := processes["model_worker"].get(model_name): + time.sleep(1) + process.terminate() + process.join() + logger.info(f"停止模型进程:{model_name}") + else: + logger.error(f"未找到模型进程:{model_name}") + elif cmd == "replace": + if process := processes["model_worker"].pop(model_name, None): + logger.info(f"停止模型进程:{model_name}") + start_time = datetime.now() + time.sleep(1) + process.terminate() + process.join() + process = Process( + target=run_model_worker, + name=f"model_worker - {new_model_name}", + kwargs=dict(model_name=new_model_name, + controller_address=args.controller_address, + log_level=log_level, + q=queue, + started_event=e), + daemon=True, + ) + process.start() + process.name = f"{process.name} ({process.pid})" + processes["model_worker"][new_model_name] = process + e.wait() + timing = datetime.now() - start_time + logger.info(f"成功启动新模型进程:{new_model_name}。用时:{timing}。") + else: + logger.error(f"未找到模型进程:{model_name}") - if model_worker_process := processes.get("model_worker"): - model_worker_process.join() - for process in processes.get("online-api", []): - process.join() - for name, process in processes.items(): - if name not in ["model_worker", "online-api"]: - if isinstance(p, list): - for work_process in p: - work_process.join() - else: - process.join() + + # for process in processes.get("model_worker", {}).values(): + # process.join() + # for process in processes.get("online_api", {}).values(): + # process.join() + + # for name, process in processes.items(): + # if name not in ["model_worker", "online_api"]: + # if isinstance(p, dict): + # for work_process in p.values(): + # work_process.join() + # else: + # process.join() except Exception as e: - # if model_worker_process := processes.pop("model_worker", None): - # model_worker_process.terminate() - # for process in processes.pop("online-api", []): - # process.terminate() - # for process in processes.values(): - # - # if isinstance(process, list): - # for work_process in process: - # work_process.terminate() - # else: - # process.terminate() logger.error(e) logger.warning("Caught KeyboardInterrupt! Setting stop event...") finally: @@ -702,10 +745,9 @@ async def start_main_server(): # Queues and other inter-process communication primitives can break when # process is killed, but we don't care here - if isinstance(p, list): - for process in p: + if isinstance(p, dict): + for process in p.values(): process.kill() - else: p.kill() diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py index 51bbac19..ed4e8b21 100644 --- a/tests/api/test_kb_api.py +++ b/tests/api/test_kb_api.py @@ -7,19 +7,23 @@ root_path = Path(__file__).parent.parent.parent sys.path.append(str(root_path)) from server.utils import api_address from configs.model_config import VECTOR_SEARCH_TOP_K -from server.knowledge_base.utils import get_kb_path +from server.knowledge_base.utils import get_kb_path, get_file_path from pprint import pprint api_base_url = api_address() + kb = "kb_for_api_test" test_files = { + "FAQ.MD": str(root_path / "docs" / "FAQ.MD"), "README.MD": str(root_path / "README.MD"), - "FAQ.MD": str(root_path / "docs" / "FAQ.MD") + "test.txt": get_file_path("samples", "test.txt"), } +print("\n\n直接url访问\n") + def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"): if not Path(get_kb_path(kb)).exists(): @@ -78,37 +82,36 @@ def test_list_kbs(api="/knowledge_base/list_knowledge_bases"): assert kb in data["data"] -def test_upload_doc(api="/knowledge_base/upload_doc"): +def test_upload_docs(api="/knowledge_base/upload_docs"): url = api_base_url + api - for name, path in test_files.items(): - print(f"\n上传知识文件: {name}") - data = {"knowledge_base_name": kb, "override": True} - files = {"file": (name, open(path, "rb"))} - r = requests.post(url, data=data, files=files) - data = r.json() - pprint(data) - assert data["code"] == 200 - assert data["msg"] == f"成功上传文件 {name}" + files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()] - for name, path in test_files.items(): - print(f"\n尝试重新上传知识文件: {name}, 不覆盖") - data = {"knowledge_base_name": kb, "override": False} - files = {"file": (name, open(path, "rb"))} - r = requests.post(url, data=data, files=files) - data = r.json() - pprint(data) - assert data["code"] == 404 - assert data["msg"] == f"文件 {name} 已存在。" + print(f"\n上传知识文件") + data = {"knowledge_base_name": kb, "override": True} + r = requests.post(url, data=data, files=files) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert len(data["data"]["failed_files"]) == 0 - for name, path in test_files.items(): - print(f"\n尝试重新上传知识文件: {name}, 覆盖") - data = {"knowledge_base_name": kb, "override": True} - files = {"file": (name, open(path, "rb"))} - r = requests.post(url, data=data, files=files) - data = r.json() - pprint(data) - assert data["code"] == 200 - assert data["msg"] == f"成功上传文件 {name}" + print(f"\n尝试重新上传知识文件, 不覆盖") + data = {"knowledge_base_name": kb, "override": False} + files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()] + r = requests.post(url, data=data, files=files) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert len(data["data"]["failed_files"]) == len(test_files) + + print(f"\n尝试重新上传知识文件, 覆盖,自定义docs") + docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]} + data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)} + files = [("files", (name, open(path, "rb"))) for name, path in test_files.items()] + r = requests.post(url, data=data, files=files) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert len(data["data"]["failed_files"]) == 0 def test_list_files(api="/knowledge_base/list_files"): @@ -134,26 +137,26 @@ def test_search_docs(api="/knowledge_base/search_docs"): assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K -def test_update_doc(api="/knowledge_base/update_doc"): +def test_update_docs(api="/knowledge_base/update_docs"): url = api_base_url + api - for name, path in test_files.items(): - print(f"\n更新知识文件: {name}") - r = requests.post(url, json={"knowledge_base_name": kb, "file_name": name}) - data = r.json() - pprint(data) - assert data["code"] == 200 - assert data["msg"] == f"成功更新文件 {name}" + + print(f"\n更新知识文件") + r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)}) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert len(data["data"]["failed_files"]) == 0 -def test_delete_doc(api="/knowledge_base/delete_doc"): +def test_delete_docs(api="/knowledge_base/delete_docs"): url = api_base_url + api - for name, path in test_files.items(): - print(f"\n删除知识文件: {name}") - r = requests.post(url, json={"knowledge_base_name": kb, "doc_name": name}) - data = r.json() - pprint(data) - assert data["code"] == 200 - assert data["msg"] == f"{name} 文件删除成功" + + print(f"\n删除知识文件") + r = requests.post(url, json={"knowledge_base_name": kb, "file_names": list(test_files)}) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert len(data["data"]["failed_files"]) == 0 url = api_base_url + "/knowledge_base/search_docs" query = "介绍一下langchain-chatchat项目" diff --git a/tests/api/test_kb_api_request.py b/tests/api/test_kb_api_request.py new file mode 100644 index 00000000..86455282 --- /dev/null +++ b/tests/api/test_kb_api_request.py @@ -0,0 +1,161 @@ +import requests +import json +import sys +from pathlib import Path + +root_path = Path(__file__).parent.parent.parent +sys.path.append(str(root_path)) +from server.utils import api_address +from configs.model_config import VECTOR_SEARCH_TOP_K +from server.knowledge_base.utils import get_kb_path, get_file_path +from webui_pages.utils import ApiRequest + +from pprint import pprint + + +api_base_url = api_address() +api: ApiRequest = ApiRequest(api_base_url, no_remote_api=False) + + +kb = "kb_for_api_test" +test_files = { + "FAQ.MD": str(root_path / "docs" / "FAQ.MD"), + "README.MD": str(root_path / "README.MD"), + "test.txt": get_file_path("samples", "test.txt"), +} + +print("\n\nApiRquest调用\n") + + +def test_delete_kb_before(): + if not Path(get_kb_path(kb)).exists(): + return + + data = api.delete_knowledge_base(kb) + pprint(data) + assert data["code"] == 200 + assert isinstance(data["data"], list) and len(data["data"]) > 0 + assert kb not in data["data"] + + +def test_create_kb(): + print(f"\n尝试用空名称创建知识库:") + data = api.create_knowledge_base(" ") + pprint(data) + assert data["code"] == 404 + assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称" + + print(f"\n创建新知识库: {kb}") + data = api.create_knowledge_base(kb) + pprint(data) + assert data["code"] == 200 + assert data["msg"] == f"已新增知识库 {kb}" + + print(f"\n尝试创建同名知识库: {kb}") + data = api.create_knowledge_base(kb) + pprint(data) + assert data["code"] == 404 + assert data["msg"] == f"已存在同名知识库 {kb}" + + +def test_list_kbs(): + data = api.list_knowledge_bases() + pprint(data) + assert isinstance(data, list) and len(data) > 0 + assert kb in data + + +def test_upload_docs(): + files = list(test_files.values()) + + print(f"\n上传知识文件") + data = {"knowledge_base_name": kb, "override": True} + data = api.upload_kb_docs(files, **data) + pprint(data) + assert data["code"] == 200 + assert len(data["data"]["failed_files"]) == 0 + + print(f"\n尝试重新上传知识文件, 不覆盖") + data = {"knowledge_base_name": kb, "override": False} + data = api.upload_kb_docs(files, **data) + pprint(data) + assert data["code"] == 200 + assert len(data["data"]["failed_files"]) == len(test_files) + + print(f"\n尝试重新上传知识文件, 覆盖,自定义docs") + docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]} + data = {"knowledge_base_name": kb, "override": True, "docs": docs} + data = api.upload_kb_docs(files, **data) + pprint(data) + assert data["code"] == 200 + assert len(data["data"]["failed_files"]) == 0 + + +def test_list_files(): + print("\n获取知识库中文件列表:") + data = api.list_kb_docs(knowledge_base_name=kb) + pprint(data) + assert isinstance(data, list) + for name in test_files: + assert name in data + + +def test_search_docs(): + query = "介绍一下langchain-chatchat项目" + print("\n检索知识库:") + print(query) + data = api.search_kb_docs(query, kb) + pprint(data) + assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K + + +def test_update_docs(): + print(f"\n更新知识文件") + data = api.update_kb_docs(knowledge_base_name=kb, file_names=list(test_files)) + pprint(data) + assert data["code"] == 200 + assert len(data["data"]["failed_files"]) == 0 + + +def test_delete_docs(): + print(f"\n删除知识文件") + data = api.delete_kb_docs(knowledge_base_name=kb, file_names=list(test_files)) + pprint(data) + assert data["code"] == 200 + assert len(data["data"]["failed_files"]) == 0 + + query = "介绍一下langchain-chatchat项目" + print("\n尝试检索删除后的检索知识库:") + print(query) + data = api.search_kb_docs(query, kb) + pprint(data) + assert isinstance(data, list) and len(data) == 0 + + +def test_recreate_vs(): + print("\n重建知识库:") + r = api.recreate_vector_store(kb) + for data in r: + assert isinstance(data, dict) + assert data["code"] == 200 + print(data["msg"]) + + query = "本项目支持哪些文件格式?" + print("\n尝试检索重建后的检索知识库:") + print(query) + data = api.search_kb_docs(query, kb) + pprint(data) + assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K + + +def test_delete_kb_after(): + print("\n删除知识库") + data = api.delete_knowledge_base(kb) + pprint(data) + + # check kb not exists anymore + print("\n获取知识库列表:") + data = api.list_knowledge_bases() + pprint(data) + assert isinstance(data, list) and len(data) > 0 + assert kb not in data diff --git a/tests/api/test_llm_api.py b/tests/api/test_llm_api.py index f348fe74..af5ced8f 100644 --- a/tests/api/test_llm_api.py +++ b/tests/api/test_llm_api.py @@ -5,8 +5,9 @@ from pathlib import Path root_path = Path(__file__).parent.parent.parent sys.path.append(str(root_path)) -from configs.server_config import api_address, FSCHAT_MODEL_WORKERS +from configs.server_config import FSCHAT_MODEL_WORKERS from configs.model_config import LLM_MODEL, llm_model_dict +from server.utils import api_address, get_model_worker_config from pprint import pprint import random @@ -64,7 +65,8 @@ def test_change_model(api="/llm_model/change"): assert len(availabel_new_models) > 0 print(availabel_new_models) - model_name = random.choice(running_models) + local_models = [x for x in running_models if not get_model_worker_config(x).get("online_api")] + model_name = random.choice(local_models) new_model_name = random.choice(availabel_new_models) print(f"\n尝试将模型从 {model_name} 切换到 {new_model_name}") r = requests.post(url, json={"model_name": model_name, "new_model_name": new_model_name}) diff --git a/tests/api/test_stream_chat_api.py b/tests/api/test_stream_chat_api.py index 4c2d5faf..14314853 100644 --- a/tests/api/test_stream_chat_api.py +++ b/tests/api/test_stream_chat_api.py @@ -43,11 +43,11 @@ data = { "content": "你好,我是人工智能大模型" } ], - "stream": True + "stream": True, + "temperature": 0.7, } - def test_chat_fastchat(api="/chat/fastchat"): url = f"{api_base_url}{api}" data2 = { @@ -89,14 +89,12 @@ def test_knowledge_chat(api="/chat/knowledge_base_chat"): response = requests.post(url, headers=headers, json=data, stream=True) print("\n") print("=" * 30 + api + " output" + "="*30) - first = True for line in response.iter_content(None, decode_unicode=True): data = json.loads(line) - if first: - for doc in data["docs"]: - print(doc) - first = False - print(data["answer"], end="", flush=True) + if "anser" in data: + print(data["answer"], end="", flush=True) + assert "docs" in data and len(data["docs"]) > 0 + pprint(data["docs"]) assert response.status_code == 200 @@ -117,14 +115,11 @@ def test_search_engine_chat(api="/chat/search_engine_chat"): print("\n") print("=" * 30 + api + " by {se} output" + "="*30) - first = True for line in response.iter_content(None, decode_unicode=True): data = json.loads(line) - assert "docs" in data and len(data["docs"]) > 0 - if first: - for doc in data.get("docs", []): - print(doc) - first = False - print(data["answer"], end="", flush=True) + if "answer" in data: + print(data["answer"], end="", flush=True) + assert "docs" in data and len(data["docs"]) > 0 + pprint(data["docs"]) assert response.status_code == 200 diff --git a/tests/custom_splitter/test_different_splitter.py b/tests/custom_splitter/test_different_splitter.py new file mode 100644 index 00000000..fea597e7 --- /dev/null +++ b/tests/custom_splitter/test_different_splitter.py @@ -0,0 +1,53 @@ +import os + +from transformers import AutoTokenizer +import sys + +sys.path.append("../..") +from configs.model_config import ( + CHUNK_SIZE, + OVERLAP_SIZE +) + +from server.knowledge_base.utils import make_text_splitter + +def text(splitter_name): + from langchain import document_loaders + + # 使用DocumentLoader读取文件 + filepath = "../../knowledge_base/samples/content/test.txt" + loader = document_loaders.UnstructuredFileLoader(filepath, autodetect_encoding=True) + docs = loader.load() + text_splitter = make_text_splitter(splitter_name, CHUNK_SIZE, OVERLAP_SIZE) + if splitter_name == "MarkdownHeaderTextSplitter": + docs = text_splitter.split_text(docs[0].page_content) + for doc in docs: + if doc.metadata: + doc.metadata["source"] = os.path.basename(filepath) + else: + docs = text_splitter.split_documents(docs) + for doc in docs: + print(doc) + return docs + + + + +import pytest +from langchain.docstore.document import Document + +@pytest.mark.parametrize("splitter_name", + [ + "ChineseRecursiveTextSplitter", + "SpacyTextSplitter", + "RecursiveCharacterTextSplitter", + "MarkdownHeaderTextSplitter" + ]) +def test_different_splitter(splitter_name): + try: + docs = text(splitter_name) + assert isinstance(docs, list) + if len(docs)>0: + assert isinstance(docs[0], Document) + except Exception as e: + pytest.fail(f"test_different_splitter failed with {splitter_name}, error: {str(e)}") diff --git a/tests/document_loader/test_imgloader.py b/tests/document_loader/test_imgloader.py index 8bba7da9..92460cb4 100644 --- a/tests/document_loader/test_imgloader.py +++ b/tests/document_loader/test_imgloader.py @@ -6,14 +6,14 @@ sys.path.append(str(root_path)) from pprint import pprint test_files = { - "ocr_test.pdf": str(root_path / "tests" / "samples" / "ocr_test.pdf"), + "ocr_test.jpg": str(root_path / "tests" / "samples" / "ocr_test.jpg"), } -def test_rapidocrpdfloader(): - pdf_path = test_files["ocr_test.pdf"] - from document_loaders import RapidOCRPDFLoader +def test_rapidocrloader(): + img_path = test_files["ocr_test.jpg"] + from document_loaders import RapidOCRLoader - loader = RapidOCRPDFLoader(pdf_path) + loader = RapidOCRLoader(img_path) docs = loader.load() pprint(docs) assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str) diff --git a/tests/document_loader/test_pdfloader.py b/tests/document_loader/test_pdfloader.py index 92460cb4..8bba7da9 100644 --- a/tests/document_loader/test_pdfloader.py +++ b/tests/document_loader/test_pdfloader.py @@ -6,14 +6,14 @@ sys.path.append(str(root_path)) from pprint import pprint test_files = { - "ocr_test.jpg": str(root_path / "tests" / "samples" / "ocr_test.jpg"), + "ocr_test.pdf": str(root_path / "tests" / "samples" / "ocr_test.pdf"), } -def test_rapidocrloader(): - img_path = test_files["ocr_test.jpg"] - from document_loaders import RapidOCRLoader +def test_rapidocrpdfloader(): + pdf_path = test_files["ocr_test.pdf"] + from document_loaders import RapidOCRPDFLoader - loader = RapidOCRLoader(img_path) + loader = RapidOCRPDFLoader(pdf_path) docs = loader.load() pprint(docs) assert isinstance(docs, list) and len(docs) > 0 and isinstance(docs[0].page_content, str) diff --git a/tests/online_api/test_qianfan.py b/tests/online_api/test_qianfan.py new file mode 100644 index 00000000..0e8a9487 --- /dev/null +++ b/tests/online_api/test_qianfan.py @@ -0,0 +1,20 @@ +import sys +from pathlib import Path +root_path = Path(__file__).parent.parent.parent +sys.path.append(str(root_path)) + +from server.model_workers.qianfan import request_qianfan_api, MODEL_VERSIONS +from pprint import pprint +import pytest + + +@pytest.mark.parametrize("version", MODEL_VERSIONS.keys()) +def test_qianfan(version): + messages = [{"role": "user", "content": "你好"}] + print("\n" + version + "\n") + i = 1 + for x in request_qianfan_api(messages, version=version): + pprint(x) + assert isinstance(x, dict) + assert "error_code" not in x + i += 1 diff --git a/text_splitter/__init__.py b/text_splitter/__init__.py index 8f13f168..dc064120 100644 --- a/text_splitter/__init__.py +++ b/text_splitter/__init__.py @@ -1,3 +1,4 @@ from .chinese_text_splitter import ChineseTextSplitter from .ali_text_splitter import AliTextSplitter from .zh_title_enhance import zh_title_enhance +from .chinese_recursive_text_splitter import ChineseRecursiveTextSplitter \ No newline at end of file diff --git a/text_splitter/chinese_recursive_text_splitter.py b/text_splitter/chinese_recursive_text_splitter.py new file mode 100644 index 00000000..70b4b29c --- /dev/null +++ b/text_splitter/chinese_recursive_text_splitter.py @@ -0,0 +1,104 @@ +import re +from typing import List, Optional, Any +from langchain.text_splitter import RecursiveCharacterTextSplitter +import logging + +logger = logging.getLogger(__name__) + + +def _split_text_with_regex_from_end( + text: str, separator: str, keep_separator: bool +) -> List[str]: + # Now that we have the separator, split the text + if separator: + if keep_separator: + # The parentheses in the pattern keep the delimiters in the result. + _splits = re.split(f"({separator})", text) + splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])] + if len(_splits) % 2 == 1: + splits += _splits[-1:] + # splits = [_splits[0]] + splits + else: + splits = re.split(separator, text) + else: + splits = list(text) + return [s for s in splits if s != ""] + + +class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter): + def __init__( + self, + separators: Optional[List[str]] = None, + keep_separator: bool = True, + is_separator_regex: bool = True, + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__(keep_separator=keep_separator, **kwargs) + self._separators = separators or [ + "\n\n", + "\n", + "。|!|?", + "\.\s|\!\s|\?\s", + ";|;\s", + ",|,\s" + ] + self._is_separator_regex = is_separator_regex + + def _split_text(self, text: str, separators: List[str]) -> List[str]: + """Split incoming text and return chunks.""" + final_chunks = [] + # Get appropriate separator to use + separator = separators[-1] + new_separators = [] + for i, _s in enumerate(separators): + _separator = _s if self._is_separator_regex else re.escape(_s) + if _s == "": + separator = _s + break + if re.search(_separator, text): + separator = _s + new_separators = separators[i + 1:] + break + + _separator = separator if self._is_separator_regex else re.escape(separator) + splits = _split_text_with_regex_from_end(text, _separator, self._keep_separator) + + # Now go merging things, recursively splitting longer texts. + _good_splits = [] + _separator = "" if self._keep_separator else separator + for s in splits: + if self._length_function(s) < self._chunk_size: + _good_splits.append(s) + else: + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + _good_splits = [] + if not new_separators: + final_chunks.append(s) + else: + other_info = self._split_text(s, new_separators) + final_chunks.extend(other_info) + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + return [re.sub(r"\n{2,}", "\n", chunk.strip()) for chunk in final_chunks if chunk.strip()!=""] + + +if __name__ == "__main__": + text_splitter = ChineseRecursiveTextSplitter( + keep_separator=True, + is_separator_regex=True, + chunk_size=50, + chunk_overlap=0 + ) + ls = [ + """中国对外贸易形势报告(75页)。前 10 个月,一般贸易进出口 19.5 万亿元,增长 25.1%, 比整体进出口增速高出 2.9 个百分点,占进出口总额的 61.7%,较去年同期提升 1.6 个百分点。其中,一般贸易出口 10.6 万亿元,增长 25.3%,占出口总额的 60.9%,提升 1.5 个百分点;进口8.9万亿元,增长24.9%,占进口总额的62.7%, 提升 1.8 个百分点。加工贸易进出口 6.8 万亿元,增长 11.8%, 占进出口总额的 21.5%,减少 2.0 个百分点。其中,出口增 长 10.4%,占出口总额的 24.3%,减少 2.6 个百分点;进口增 长 14.2%,占进口总额的 18.0%,减少 1.2 个百分点。此外, 以保税物流方式进出口 3.96 万亿元,增长 27.9%。其中,出 口 1.47 万亿元,增长 38.9%;进口 2.49 万亿元,增长 22.2%。前三季度,中国服务贸易继续保持快速增长态势。服务 进出口总额 37834.3 亿元,增长 11.6%;其中服务出口 17820.9 亿元,增长 27.3%;进口 20013.4 亿元,增长 0.5%,进口增 速实现了疫情以来的首次转正。服务出口增幅大于进口 26.8 个百分点,带动服务贸易逆差下降 62.9%至 2192.5 亿元。服 务贸易结构持续优化,知识密集型服务进出口 16917.7 亿元, 增长 13.3%,占服务进出口总额的比重达到 44.7%,提升 0.7 个百分点。 二、中国对外贸易发展环境分析和展望 全球疫情起伏反复,经济复苏分化加剧,大宗商品价格 上涨、能源紧缺、运力紧张及发达经济体政策调整外溢等风 险交织叠加。同时也要看到,我国经济长期向好的趋势没有 改变,外贸企业韧性和活力不断增强,新业态新模式加快发 展,创新转型步伐提速。产业链供应链面临挑战。美欧等加快出台制造业回迁计 划,加速产业链供应链本土布局,跨国公司调整产业链供应 链,全球双链面临新一轮重构,区域化、近岸化、本土化、 短链化趋势凸显。疫苗供应不足,制造业“缺芯”、物流受限、 运价高企,全球产业链供应链面临压力。 全球通胀持续高位运行。能源价格上涨加大主要经济体 的通胀压力,增加全球经济复苏的不确定性。世界银行今年 10 月发布《大宗商品市场展望》指出,能源价格在 2021 年 大涨逾 80%,并且仍将在 2022 年小幅上涨。IMF 指出,全 球通胀上行风险加剧,通胀前景存在巨大不确定性。""", + ] + # text = """""" + for inum, text in enumerate(ls): + print(inum) + chunks = text_splitter.split_text(text) + for chunk in chunks: + print(chunk) diff --git a/text_splitter/chinese_text_splitter.py b/text_splitter/chinese_text_splitter.py index d6294ae8..4107b25f 100644 --- a/text_splitter/chinese_text_splitter.py +++ b/text_splitter/chinese_text_splitter.py @@ -1,11 +1,10 @@ from langchain.text_splitter import CharacterTextSplitter import re from typing import List -from configs.model_config import CHUNK_SIZE class ChineseTextSplitter(CharacterTextSplitter): - def __init__(self, pdf: bool = False, sentence_size: int = CHUNK_SIZE, **kwargs): + def __init__(self, pdf: bool = False, sentence_size: int = 250, **kwargs): super().__init__(**kwargs) self.pdf = pdf self.sentence_size = sentence_size diff --git a/webui.py b/webui.py index 0cda9ebc..2750c477 100644 --- a/webui.py +++ b/webui.py @@ -1,9 +1,3 @@ -# 运行方式: -# 1. 安装必要的包:pip install streamlit-option-menu streamlit-chatbox>=1.1.6 -# 2. 运行本机fastchat服务:python server\llm_api.py 或者 运行对应的sh文件 -# 3. 运行API服务器:python server/api.py。如果使用api = ApiRequest(no_remote_api=True),该步可以跳过。 -# 4. 运行WEB UI:streamlit run webui.py --server.port 7860 - import streamlit as st from webui_pages.utils import * from streamlit_option_menu import option_menu diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index afefe32a..4b347df0 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -5,7 +5,7 @@ from streamlit_chatbox import * from datetime import datetime from server.chat.search_engine_chat import SEARCH_ENGINES import os -from configs.model_config import llm_model_dict, LLM_MODEL +from configs.model_config import LLM_MODEL, TEMPERATURE from server.utils import get_model_worker_config from typing import List, Dict @@ -55,17 +55,20 @@ def dialogue_page(api: ApiRequest): st.toast(text) # sac.alert(text, description="descp", type="success", closable=True, banner=True) - dialogue_mode = st.selectbox("请选择对话模式", + dialogue_mode = st.selectbox("请选择对话模式:", ["LLM 对话", "知识库问答", "搜索引擎问答", ], + index=1, on_change=on_mode_change, key="dialogue_mode", ) def on_llm_change(): - st.session_state["prev_llm_model"] = llm_model + config = get_model_worker_config(llm_model) + if not config.get("online_api"): # 只有本地model_worker可以切换模型 + st.session_state["prev_llm_model"] = llm_model def llm_model_format_func(x): if x in running_models: @@ -78,10 +81,8 @@ def dialogue_page(api: ApiRequest): if x in config_models: config_models.remove(x) llm_models = running_models + config_models - if "prev_llm_model" not in st.session_state: - index = llm_models.index(LLM_MODEL) - else: - index = 0 + cur_model = st.session_state.get("cur_llm_model", LLM_MODEL) + index = llm_models.index(cur_model) llm_model = st.selectbox("选择LLM模型:", llm_models, index, @@ -91,10 +92,11 @@ def dialogue_page(api: ApiRequest): ) if (st.session_state.get("prev_llm_model") != llm_model and not get_model_worker_config(llm_model).get("online_api")): - with st.spinner(f"正在加载模型: {llm_model}"): + with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"): r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model) - st.session_state["prev_llm_model"] = llm_model + st.session_state["cur_llm_model"] = llm_model + temperature = st.slider("Temperature:", 0.0, 1.0, TEMPERATURE, 0.05) history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN) def on_kb_change(): @@ -110,7 +112,7 @@ def dialogue_page(api: ApiRequest): key="selected_kb", ) kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K) - score_threshold = st.number_input("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01) + score_threshold = st.slider("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01) # chunk_content = st.checkbox("关联上下文", False, disabled=True) # chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True) elif dialogue_mode == "搜索引擎问答": @@ -135,7 +137,7 @@ def dialogue_page(api: ApiRequest): if dialogue_mode == "LLM 对话": chat_box.ai_say("正在思考...") text = "" - r = api.chat_chat(prompt, history=history, model=llm_model) + r = api.chat_chat(prompt, history=history, model=llm_model, temperature=temperature) for t in r: if error_msg := check_error_msg(t): # check whether error occured st.error(error_msg) @@ -150,28 +152,38 @@ def dialogue_page(api: ApiRequest): Markdown("...", in_expander=True, title="知识库匹配结果"), ]) text = "" - for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history, model=llm_model): - if error_msg := check_error_msg(d): # check whether error occured + for d in api.knowledge_base_chat(prompt, + knowledge_base_name=selected_kb, + top_k=kb_top_k, + score_threshold=score_threshold, + history=history, + model=llm_model, + temperature=temperature): + if error_msg := check_error_msg(d): # check whether error occured st.error(error_msg) - else: - text += d["answer"] + elif chunk := d.get("answer"): + text += chunk chat_box.update_msg(text, 0) - chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False) chat_box.update_msg(text, 0, streaming=False) + chat_box.update_msg("\n\n".join(d.get("docs", [])), 1, streaming=False) elif dialogue_mode == "搜索引擎问答": chat_box.ai_say([ f"正在执行 `{search_engine}` 搜索...", Markdown("...", in_expander=True, title="网络搜索结果"), ]) text = "" - for d in api.search_engine_chat(prompt, search_engine, se_top_k, model=llm_model): - if error_msg := check_error_msg(d): # check whether error occured + for d in api.search_engine_chat(prompt, + search_engine_name=search_engine, + top_k=se_top_k, + model=llm_model, + temperature=temperature): + if error_msg := check_error_msg(d): # check whether error occured st.error(error_msg) - else: - text += d["answer"] + elif chunk := d.get("answer"): + text += chunk chat_box.update_msg(text, 0) - chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False) chat_box.update_msg(text, 0, streaming=False) + chat_box.update_msg("\n\n".join(d.get("docs", [])), 1, streaming=False) now = datetime.now() with st.sidebar: diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index 29a63225..c71da7e4 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -6,7 +6,9 @@ import pandas as pd from server.knowledge_base.utils import get_file_path, LOADER_DICT from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details from typing import Literal, Dict, Tuple -from configs.model_config import embedding_model_dict, kbs_config, EMBEDDING_MODEL, DEFAULT_VS_TYPE +from configs.model_config import (embedding_model_dict, kbs_config, + EMBEDDING_MODEL, DEFAULT_VS_TYPE, + CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE) import os import time @@ -125,28 +127,41 @@ def knowledge_base_page(api: ApiRequest): elif selected_kb: kb = selected_kb + # 上传文件 - # sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True) - files = st.file_uploader("上传知识文件", + files = st.file_uploader("上传知识文件:", [i for ls in LOADER_DICT.values() for i in ls], accept_multiple_files=True, ) + + # with st.sidebar: + with st.expander( + "文件处理配置", + expanded=True, + ): + cols = st.columns(3) + chunk_size = cols[0].number_input("单段文本最大长度:", 1, 1000, CHUNK_SIZE) + chunk_overlap = cols[1].number_input("相邻文本重合长度:", 0, chunk_size, OVERLAP_SIZE) + cols[2].write("") + cols[2].write("") + zh_title_enhance = cols[2].checkbox("开启中文标题加强", ZH_TITLE_ENHANCE) + if st.button( "添加文件到知识库", - # help="请先上传文件,再点击添加", # use_container_width=True, disabled=len(files) == 0, ): - data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files] - data[-1]["not_refresh_vs_cache"]=False - for k in data: - ret = api.upload_kb_doc(**k) - if msg := check_success_msg(ret): - st.toast(msg, icon="✔") - elif msg := check_error_msg(ret): - st.toast(msg, icon="✖") - st.session_state.files = [] + ret = api.upload_kb_docs(files, + knowledge_base_name=kb, + override=True, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + zh_title_enhance=zh_title_enhance) + if msg := check_success_msg(ret): + st.toast(msg, icon="✔") + elif msg := check_error_msg(ret): + st.toast(msg, icon="✖") st.divider() @@ -160,7 +175,7 @@ def knowledge_base_page(api: ApiRequest): st.info("知识库中包含源文件与向量库,请从下表中选择文件后操作") doc_details.drop(columns=["kb_name"], inplace=True) doc_details = doc_details[[ - "No", "file_name", "document_loader", "docs_count", "in_folder", "in_db", + "No", "file_name", "document_loader", "text_splitter", "docs_count", "in_folder", "in_db", ]] # doc_details["in_folder"] = doc_details["in_folder"].replace(True, "✓").replace(False, "×") # doc_details["in_db"] = doc_details["in_db"].replace(True, "✓").replace(False, "×") @@ -173,7 +188,7 @@ def knowledge_base_page(api: ApiRequest): # ("file_version", "文档版本"): {}, ("document_loader", "文档加载器"): {}, ("docs_count", "文档数量"): {}, - # ("text_splitter", "分词器"): {}, + ("text_splitter", "分词器"): {}, # ("create_time", "创建时间"): {}, ("in_folder", "源文件"): {"cellRenderer": cell_renderer}, ("in_db", "向量库"): {"cellRenderer": cell_renderer}, @@ -218,8 +233,12 @@ def knowledge_base_page(api: ApiRequest): disabled=not file_exists(kb, selected_rows)[0], use_container_width=True, ): - for row in selected_rows: - api.update_kb_doc(kb, row["file_name"]) + file_names = [row["file_name"] for row in selected_rows] + api.update_kb_docs(kb, + file_names=file_names, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + zh_title_enhance=zh_title_enhance) st.experimental_rerun() # 将文件从向量库中删除,但不删除文件本身。 @@ -228,8 +247,8 @@ def knowledge_base_page(api: ApiRequest): disabled=not (selected_rows and selected_rows[0]["in_db"]), use_container_width=True, ): - for row in selected_rows: - api.delete_kb_doc(kb, row["file_name"]) + file_names = [row["file_name"] for row in selected_rows] + api.delete_kb_docs(kb, file_names=file_names) st.experimental_rerun() if cols[3].button( @@ -237,9 +256,8 @@ def knowledge_base_page(api: ApiRequest): type="primary", use_container_width=True, ): - for row in selected_rows: - ret = api.delete_kb_doc(kb, row["file_name"], True) - st.toast(ret.get("msg", " ")) + file_names = [row["file_name"] for row in selected_rows] + api.delete_kb_docs(kb, file_names=file_names, delete_content=True) st.experimental_rerun() st.divider() @@ -255,7 +273,10 @@ def knowledge_base_page(api: ApiRequest): with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"): empty = st.empty() empty.progress(0.0, "") - for d in api.recreate_vector_store(kb): + for d in api.recreate_vector_store(kb, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + zh_title_enhance=zh_title_enhance): if msg := check_error_msg(d): st.toast(msg) else: diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 08511044..26e53206 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -8,10 +8,14 @@ from configs.model_config import ( LLM_MODEL, llm_model_dict, HISTORY_LEN, + TEMPERATURE, SCORE_THRESHOLD, + CHUNK_SIZE, + OVERLAP_SIZE, + ZH_TITLE_ENHANCE, VECTOR_SEARCH_TOP_K, SEARCH_ENGINE_TOP_K, - logger, + logger, log_verbose, ) from configs.server_config import HTTPX_DEFAULT_TIMEOUT import httpx @@ -20,10 +24,9 @@ from server.chat.openai_chat import OpenAiChatMsgIn from fastapi.responses import StreamingResponse import contextlib import json +import os from io import BytesIO -from server.db.repository.knowledge_base_repository import get_kb_detail -from server.db.repository.knowledge_file_repository import get_file_detail -from server.utils import run_async, iter_over_async, set_httpx_timeout +from server.utils import run_async, iter_over_async, set_httpx_timeout, api_address from configs.model_config import NLTK_DATA_PATH import nltk @@ -43,13 +46,15 @@ class ApiRequest: ''' def __init__( self, - base_url: str = "http://127.0.0.1:7861", + base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT, no_remote_api: bool = False, # call api view function directly ): self.base_url = base_url self.timeout = timeout self.no_remote_api = no_remote_api + if no_remote_api: + logger.warn("将来可能取消对no_remote_api的支持,更新版本时请注意。") def _parse_url(self, url: str) -> str: if (not url.startswith("http") @@ -78,7 +83,9 @@ class ApiRequest: else: return httpx.get(url, params=params, **kwargs) except Exception as e: - logger.error(e) + msg = f"error when get {url}: {e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) retry -= 1 async def aget( @@ -99,7 +106,9 @@ class ApiRequest: else: return await client.get(url, params=params, **kwargs) except Exception as e: - logger.error(e) + msg = f"error when aget {url}: {e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) retry -= 1 def post( @@ -121,7 +130,9 @@ class ApiRequest: else: return httpx.post(url, data=data, json=json, **kwargs) except Exception as e: - logger.error(e) + msg = f"error when post {url}: {e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) retry -= 1 async def apost( @@ -143,7 +154,9 @@ class ApiRequest: else: return await client.post(url, data=data, json=json, **kwargs) except Exception as e: - logger.error(e) + msg = f"error when apost {url}: {e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) retry -= 1 def delete( @@ -164,7 +177,9 @@ class ApiRequest: else: return httpx.delete(url, data=data, json=json, **kwargs) except Exception as e: - logger.error(e) + msg = f"error when delete {url}: {e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) retry -= 1 async def adelete( @@ -186,7 +201,9 @@ class ApiRequest: else: return await client.delete(url, data=data, json=json, **kwargs) except Exception as e: - logger.error(e) + msg = f"error when adelete {url}: {e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) retry -= 1 def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False): @@ -197,7 +214,7 @@ class ApiRequest: loop = asyncio.get_event_loop() except: loop = asyncio.new_event_loop() - + try: for chunk in iter_over_async(response.body_iterator, loop): if as_json and chunk: @@ -205,7 +222,9 @@ class ApiRequest: elif chunk.strip(): yield chunk except Exception as e: - logger.error(e) + msg = f"error when run fastapi router: {e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) def _httpx_stream2generator( self, @@ -226,23 +245,26 @@ class ApiRequest: pprint(data, depth=1) yield data except Exception as e: - logger.error(f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。") + msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) else: print(chunk, end="", flush=True) yield chunk except httpx.ConnectError as e: - msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。" + msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。({e})" + logger.error(msg) logger.error(msg) - logger.error(e) yield {"code": 500, "msg": msg} except httpx.ReadTimeout as e: - msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 启动 API 服务或 Web UI')" + msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 启动 API 服务或 Web UI')。({e})" logger.error(msg) - logger.error(e) yield {"code": 500, "msg": msg} except Exception as e: - logger.error(e) - yield {"code": 500, "msg": str(e)} + msg = f"API通信遇到错误:{e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) + yield {"code": 500, "msg": msg} # 对话相关操作 @@ -251,7 +273,7 @@ class ApiRequest: messages: List[Dict], stream: bool = True, model: str = LLM_MODEL, - temperature: float = 0.7, + temperature: float = TEMPERATURE, max_tokens: int = 1024, # todo:根据message内容自动计算max_tokens no_remote_api: bool = None, **kwargs: Any, @@ -272,7 +294,7 @@ class ApiRequest: if no_remote_api: from server.chat.openai_chat import openai_chat - response = openai_chat(msg) + response = run_async(openai_chat(msg)) return self._fastapi_stream2generator(response) else: data = msg.dict(exclude_unset=True, exclude_none=True) @@ -282,7 +304,7 @@ class ApiRequest: response = self.post( "/chat/fastchat", json=data, - stream=stream, + stream=True, ) return self._httpx_stream2generator(response) @@ -292,6 +314,7 @@ class ApiRequest: history: List[Dict] = [], stream: bool = True, model: str = LLM_MODEL, + temperature: float = TEMPERATURE, no_remote_api: bool = None, ): ''' @@ -305,6 +328,7 @@ class ApiRequest: "history": history, "stream": stream, "model_name": model, + "temperature": temperature, } print(f"received input message:") @@ -312,7 +336,7 @@ class ApiRequest: if no_remote_api: from server.chat.chat import chat - response = chat(**data) + response = run_async(chat(**data)) return self._fastapi_stream2generator(response) else: response = self.post("/chat/chat", json=data, stream=True) @@ -327,6 +351,7 @@ class ApiRequest: history: List[Dict] = [], stream: bool = True, model: str = LLM_MODEL, + temperature: float = TEMPERATURE, no_remote_api: bool = None, ): ''' @@ -343,6 +368,7 @@ class ApiRequest: "history": history, "stream": stream, "model_name": model, + "temperature": temperature, "local_doc_url": no_remote_api, } @@ -351,7 +377,7 @@ class ApiRequest: if no_remote_api: from server.chat.knowledge_base_chat import knowledge_base_chat - response = knowledge_base_chat(**data) + response = run_async(knowledge_base_chat(**data)) return self._fastapi_stream2generator(response, as_json=True) else: response = self.post( @@ -368,6 +394,7 @@ class ApiRequest: top_k: int = SEARCH_ENGINE_TOP_K, stream: bool = True, model: str = LLM_MODEL, + temperature: float = TEMPERATURE, no_remote_api: bool = None, ): ''' @@ -382,6 +409,7 @@ class ApiRequest: "top_k": top_k, "stream": stream, "model_name": model, + "temperature": temperature, } print(f"received input message:") @@ -389,7 +417,7 @@ class ApiRequest: if no_remote_api: from server.chat.search_engine_chat import search_engine_chat - response = search_engine_chat(**data) + response = run_async(search_engine_chat(**data)) return self._fastapi_stream2generator(response, as_json=True) else: response = self.post( @@ -413,8 +441,10 @@ class ApiRequest: try: return response.json() except Exception as e: - logger.error(e) - return {"code": 500, "msg": errorMsg or str(e)} + msg = "API未能返回正确的JSON。" + (errorMsg or str(e)) + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) + return {"code": 500, "msg": msg} def list_knowledge_bases( self, @@ -428,7 +458,7 @@ class ApiRequest: if no_remote_api: from server.knowledge_base.kb_api import list_kbs - response = run_async(list_kbs()) + response = list_kbs() return response.data else: response = self.get("/knowledge_base/list_knowledge_bases") @@ -456,7 +486,7 @@ class ApiRequest: if no_remote_api: from server.knowledge_base.kb_api import create_kb - response = run_async(create_kb(**data)) + response = create_kb(**data) return response.dict() else: response = self.post( @@ -478,7 +508,7 @@ class ApiRequest: if no_remote_api: from server.knowledge_base.kb_api import delete_kb - response = run_async(delete_kb(knowledge_base_name)) + response = delete_kb(knowledge_base_name) return response.dict() else: response = self.post( @@ -500,7 +530,7 @@ class ApiRequest: if no_remote_api: from server.knowledge_base.kb_doc_api import list_files - response = run_async(list_files(knowledge_base_name)) + response = list_files(knowledge_base_name) return response.data else: response = self.get( @@ -510,12 +540,48 @@ class ApiRequest: data = self._check_httpx_json_response(response) return data.get("data", []) - def upload_kb_doc( + def search_kb_docs( self, - file: Union[str, Path, bytes], + query: str, + knowledge_base_name: str, + top_k: int = VECTOR_SEARCH_TOP_K, + score_threshold: int = SCORE_THRESHOLD, + no_remote_api: bool = None, + ) -> List: + ''' + 对应api.py/knowledge_base/search_docs接口 + ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + + data = { + "query": query, + "knowledge_base_name": knowledge_base_name, + "top_k": top_k, + "score_threshold": score_threshold, + } + + if no_remote_api: + from server.knowledge_base.kb_doc_api import search_docs + return search_docs(**data) + else: + response = self.post( + "/knowledge_base/search_docs", + json=data, + ) + data = self._check_httpx_json_response(response) + return data + + def upload_kb_docs( + self, + files: List[Union[str, Path, bytes]], knowledge_base_name: str, - filename: str = None, override: bool = False, + to_vector_store: bool = True, + chunk_size=CHUNK_SIZE, + chunk_overlap=OVERLAP_SIZE, + zh_title_enhance=ZH_TITLE_ENHANCE, + docs: Dict = {}, not_refresh_vs_cache: bool = False, no_remote_api: bool = None, ): @@ -525,97 +591,122 @@ class ApiRequest: if no_remote_api is None: no_remote_api = self.no_remote_api - if isinstance(file, bytes): # raw bytes - file = BytesIO(file) - elif hasattr(file, "read"): # a file io like object - filename = filename or file.name - else: # a local path - file = Path(file).absolute().open("rb") - filename = filename or file.name + def convert_file(file, filename=None): + if isinstance(file, bytes): # raw bytes + file = BytesIO(file) + elif hasattr(file, "read"): # a file io like object + filename = filename or file.name + else: # a local path + file = Path(file).absolute().open("rb") + filename = filename or os.path.split(file.name)[-1] + return filename, file + + files = [convert_file(file) for file in files] + data={ + "knowledge_base_name": knowledge_base_name, + "override": override, + "to_vector_store": to_vector_store, + "chunk_size": chunk_size, + "chunk_overlap": chunk_overlap, + "zh_title_enhance": zh_title_enhance, + "docs": docs, + "not_refresh_vs_cache": not_refresh_vs_cache, + } if no_remote_api: - from server.knowledge_base.kb_doc_api import upload_doc + from server.knowledge_base.kb_doc_api import upload_docs from fastapi import UploadFile from tempfile import SpooledTemporaryFile - temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024) - temp_file.write(file.read()) - temp_file.seek(0) - response = run_async(upload_doc( - UploadFile(file=temp_file, filename=filename), - knowledge_base_name, - override, - )) + upload_files = [] + for filename, file in files: + temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024) + temp_file.write(file.read()) + temp_file.seek(0) + upload_files.append(UploadFile(file=temp_file, filename=filename)) + + response = upload_docs(upload_files, **data) return response.dict() else: + if isinstance(data["docs"], dict): + data["docs"] = json.dumps(data["docs"], ensure_ascii=False) response = self.post( - "/knowledge_base/upload_doc", - data={ - "knowledge_base_name": knowledge_base_name, - "override": override, - "not_refresh_vs_cache": not_refresh_vs_cache, - }, - files={"file": (filename, file)}, + "/knowledge_base/upload_docs", + data=data, + files=[("files", (filename, file)) for filename, file in files], ) return self._check_httpx_json_response(response) - def delete_kb_doc( + def delete_kb_docs( self, knowledge_base_name: str, - doc_name: str, + file_names: List[str], delete_content: bool = False, not_refresh_vs_cache: bool = False, no_remote_api: bool = None, ): ''' - 对应api.py/knowledge_base/delete_doc接口 + 对应api.py/knowledge_base/delete_docs接口 ''' if no_remote_api is None: no_remote_api = self.no_remote_api data = { "knowledge_base_name": knowledge_base_name, - "doc_name": doc_name, + "file_names": file_names, "delete_content": delete_content, "not_refresh_vs_cache": not_refresh_vs_cache, } if no_remote_api: - from server.knowledge_base.kb_doc_api import delete_doc - response = run_async(delete_doc(**data)) + from server.knowledge_base.kb_doc_api import delete_docs + response = delete_docs(**data) return response.dict() else: response = self.post( - "/knowledge_base/delete_doc", + "/knowledge_base/delete_docs", json=data, ) return self._check_httpx_json_response(response) - def update_kb_doc( + def update_kb_docs( self, knowledge_base_name: str, - file_name: str, + file_names: List[str], + override_custom_docs: bool = False, + chunk_size=CHUNK_SIZE, + chunk_overlap=OVERLAP_SIZE, + zh_title_enhance=ZH_TITLE_ENHANCE, + docs: Dict = {}, not_refresh_vs_cache: bool = False, no_remote_api: bool = None, ): ''' - 对应api.py/knowledge_base/update_doc接口 + 对应api.py/knowledge_base/update_docs接口 ''' if no_remote_api is None: no_remote_api = self.no_remote_api + data = { + "knowledge_base_name": knowledge_base_name, + "file_names": file_names, + "override_custom_docs": override_custom_docs, + "chunk_size": chunk_size, + "chunk_overlap": chunk_overlap, + "zh_title_enhance": zh_title_enhance, + "docs": docs, + "not_refresh_vs_cache": not_refresh_vs_cache, + } if no_remote_api: - from server.knowledge_base.kb_doc_api import update_doc - response = run_async(update_doc(knowledge_base_name, file_name)) + from server.knowledge_base.kb_doc_api import update_docs + response = update_docs(**data) return response.dict() else: + if isinstance(data["docs"], dict): + data["docs"] = json.dumps(data["docs"], ensure_ascii=False) response = self.post( - "/knowledge_base/update_doc", - json={ - "knowledge_base_name": knowledge_base_name, - "file_name": file_name, - "not_refresh_vs_cache": not_refresh_vs_cache, - }, + "/knowledge_base/update_docs", + json=data, ) return self._check_httpx_json_response(response) @@ -625,6 +716,9 @@ class ApiRequest: allow_empty_kb: bool = True, vs_type: str = DEFAULT_VS_TYPE, embed_model: str = EMBEDDING_MODEL, + chunk_size=CHUNK_SIZE, + chunk_overlap=OVERLAP_SIZE, + zh_title_enhance=ZH_TITLE_ENHANCE, no_remote_api: bool = None, ): ''' @@ -638,11 +732,14 @@ class ApiRequest: "allow_empty_kb": allow_empty_kb, "vs_type": vs_type, "embed_model": embed_model, + "chunk_size": chunk_size, + "chunk_overlap": chunk_overlap, + "zh_title_enhance": zh_title_enhance, } if no_remote_api: from server.knowledge_base.kb_doc_api import recreate_vector_store - response = run_async(recreate_vector_store(**data)) + response = recreate_vector_store(**data) return self._fastapi_stream2generator(response, as_json=True) else: response = self.post( @@ -653,14 +750,30 @@ class ApiRequest: ) return self._httpx_stream2generator(response, as_json=True) - def list_running_models(self, controller_address: str = None): + # LLM模型相关操作 + def list_running_models( + self, + controller_address: str = None, + no_remote_api: bool = None, + ): ''' 获取Fastchat中正运行的模型列表 ''' - r = self.post( - "/llm_model/list_models", - ) - return r.json().get("data", []) + if no_remote_api is None: + no_remote_api = self.no_remote_api + + data = { + "controller_address": controller_address, + } + if no_remote_api: + from server.llm_api import list_llm_models + return list_llm_models(**data).data + else: + r = self.post( + "/llm_model/list_models", + json=data, + ) + return r.json().get("data", []) def list_config_models(self): ''' @@ -672,30 +785,43 @@ class ApiRequest: self, model_name: str, controller_address: str = None, + no_remote_api: bool = None, ): ''' 停止某个LLM模型。 注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。 ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + data = { "model_name": model_name, "controller_address": controller_address, } - r = self.post( - "/llm_model/stop", - json=data, - ) - return r.json() - + + if no_remote_api: + from server.llm_api import stop_llm_model + return stop_llm_model(**data).dict() + else: + r = self.post( + "/llm_model/stop", + json=data, + ) + return r.json() + def change_llm_model( self, model_name: str, new_model_name: str, controller_address: str = None, + no_remote_api: bool = None, ): ''' 向fastchat controller请求切换LLM模型。 ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + if not model_name or not new_model_name: return @@ -724,12 +850,17 @@ class ApiRequest: "new_model_name": new_model_name, "controller_address": controller_address, } - r = self.post( - "/llm_model/change", - json=data, - timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model - ) - return r.json() + + if no_remote_api: + from server.llm_api import change_llm_model + return change_llm_model(**data).dict() + else: + r = self.post( + "/llm_model/change", + json=data, + timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model + ) + return r.json() def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str: