publish 0.2.10 (#2797)

New features:
- Optimize OCR of PDF files and filter out meaningless small images by @liunux4odoo #2525
- Support the Gemini online model by @yhfgyyf #2630
- Support the GLM-4 online model by @zRzRzRzRzRzRzR
- Switch Elasticsearch to HTTPS connections by @xldistance #2390
- Improve OCR of PPT and DOC knowledge-base files by @596192804 #2013
- Update the Agent conversation feature by @zRzRzRzRzRzRzR
- Obtain database connections from the connection pool when an object is created, instead of opening a new connection on every method call, by @Lijia0 #2480
- Make ChatOpenAI check whether the token count exceeds the model's context length (see the sketch after this list) by @glide-the
- Update database startup error reporting and the project milestones by @zRzRzRzRzRzRzR #2659
- Update configuration files, documentation and dependencies by @imClumsyPanda @zRzRzRzRzRzRzR
- Add a Japanese README by @eltociear #2787
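The context-length check above boils down to counting prompt tokens before the request is sent. A minimal sketch of the idea; the helper name, tokenizer choice and default context length are illustrative assumptions, not the project's actual code:

```python
import tiktoken  # assumed tokenizer; the project may count tokens differently

def exceeds_context(prompt: str, model: str = "gpt-3.5-turbo", context_len: int = 4096) -> bool:
    """Return True if the prompt alone already exceeds the model's context window."""
    enc = tiktoken.encoding_for_model(model)
    return len(enc.encode(prompt)) > context_len
```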

Fixes:
- PGVector vector store connection error after the langchain upgrade by @HALIndex #2591
- MiniMax model worker error by @xyhshen
- Elasticsearch vector search failing: add mappings to create the vector index (see the sketch below) by @MSZheng20 #2688
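For the Elasticsearch fix, the point is that kNN/vector search only works if the index is created with an explicit `dense_vector` mapping. A minimal sketch assuming the Elasticsearch 8.x Python client and a 1024-dimensional embedding model such as bge-large-zh; the index name, field names and credentials are illustrative, not the project's actual values:

```python
from elasticsearch import Elasticsearch

# HTTPS connection (see #2390); adjust host, auth and certificate settings to your deployment.
es = Elasticsearch("https://localhost:9200", basic_auth=("elastic", "password"), verify_certs=False)

mappings = {
    "properties": {
        "context": {"type": "text"},      # the chunk text
        "dense_vector": {                 # the embedding field used for kNN search
            "type": "dense_vector",
            "dims": 1024,                 # must match the embedding model's output dimension
            "index": True,
            "similarity": "cosine",
        },
    }
}
es.indices.create(index="my_knowledge_base", mappings=mappings)
```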
Author: liunux4odoo
Date: 2024-01-26 06:58:49 +08:00 (committed by GitHub)
Parent: ee6a28b565
Commit: 9c525b7fa5
69 changed files with 1300 additions and 1547 deletions

View File: README.md

@ -1,6 +1,5 @@
![](img/logo-long-chatchat-trans-v2.png) ![](img/logo-long-chatchat-trans-v2.png)
🌍 [READ THIS IN ENGLISH](README_en.md) 🌍 [READ THIS IN ENGLISH](README_en.md)
🌍 [日本語で読む](README_ja.md) 🌍 [日本語で読む](README_ja.md)
@ -8,6 +7,8 @@
基于 ChatGLM 等大语言模型与 Langchain 等应用框架实现,开源、可离线部署的检索增强生成(RAG)大模型知识库项目。 基于 ChatGLM 等大语言模型与 Langchain 等应用框架实现,开源、可离线部署的检索增强生成(RAG)大模型知识库项目。
⚠️`0.2.10`将会是`0.2.x`系列的最后一个版本,`0.2.x`系列版本将会停止更新和技术支持,全力研发具有更强应用性的 `Langchain-Chatchat 0.3.x`
--- ---
## 目录 ## 目录
@ -15,23 +16,31 @@
* [介绍](README.md#介绍) * [介绍](README.md#介绍)
* [解决的痛点](README.md#解决的痛点) * [解决的痛点](README.md#解决的痛点)
* [快速上手](README.md#快速上手) * [快速上手](README.md#快速上手)
* [1. 环境配置](README.md#1-环境配置) * [1. 环境配置](README.md#1-环境配置)
* [2. 模型下载](README.md#2-模型下载) * [2. 模型下载](README.md#2-模型下载)
* [3. 初始化知识库和配置文件](README.md#3-初始化知识库和配置文件) * [3. 初始化知识库和配置文件](README.md#3-初始化知识库和配置文件)
* [4. 一键启动](README.md#4-一键启动) * [4. 一键启动](README.md#4-一键启动)
* [5. 启动界面示例](README.md#5-启动界面示例) * [5. 启动界面示例](README.md#5-启动界面示例)
* [联系我们](README.md#联系我们) * [联系我们](README.md#联系我们)
## 介绍 ## 介绍
🤖️ 一种利用 [langchain](https://github.com/hwchase17/langchain) 思想实现的基于本地知识库的问答应用,目标期望建立一套对中文场景与开源模型支持友好、可离线运行的知识库问答解决方案。 🤖️ 一种利用 [langchain](https://github.com/langchain-ai/langchain)
思想实现的基于本地知识库的问答应用,目标期望建立一套对中文场景与开源模型支持友好、可离线运行的知识库问答解决方案。
💡 受 [GanymedeNil](https://github.com/GanymedeNil) 的项目 [document.ai](https://github.com/GanymedeNil/document.ai) 和 [AlexZhangji](https://github.com/AlexZhangji) 创建的 [ChatGLM-6B Pull Request](https://github.com/THUDM/ChatGLM-6B/pull/216) 启发,建立了全流程可使用开源模型实现的本地知识库问答应用。本项目的最新版本中通过使用 [FastChat](https://github.com/lm-sys/FastChat) 接入 Vicuna, Alpaca, LLaMA, Koala, RWKV 等模型,依托于 [langchain](https://github.com/langchain-ai/langchain) 框架支持通过基于 [FastAPI](https://github.com/tiangolo/fastapi) 提供的 API 调用服务,或使用基于 [Streamlit](https://github.com/streamlit/streamlit) 的 WebUI 进行操作。 💡 受 [GanymedeNil](https://github.com/GanymedeNil) 的项目 [document.ai](https://github.com/GanymedeNil/document.ai)
和 [AlexZhangji](https://github.com/AlexZhangji)
创建的 [ChatGLM-6B Pull Request](https://github.com/THUDM/ChatGLM-6B/pull/216)
启发,建立了全流程可使用开源模型实现的本地知识库问答应用。本项目的最新版本中通过使用 [FastChat](https://github.com/lm-sys/FastChat)
接入 Vicuna, Alpaca, LLaMA, Koala, RWKV 等模型,依托于 [langchain](https://github.com/langchain-ai/langchain)
框架支持通过基于 [FastAPI](https://github.com/tiangolo/fastapi) 提供的 API
调用服务,或使用基于 [Streamlit](https://github.com/streamlit/streamlit) 的 WebUI 进行操作。
✅ 依托于本项目支持的开源 LLM 与 Embedding 模型,本项目可实现全部使用**开源**模型**离线私有部署**。与此同时,本项目也支持 OpenAI GPT API 的调用,并将在后续持续扩充对各类模型及模型 API 的接入。 ✅ 依托于本项目支持的开源 LLM 与 Embedding 模型,本项目可实现全部使用**开源**模型**离线私有部署**。与此同时,本项目也支持
OpenAI GPT API 的调用,并将在后续持续扩充对各类模型及模型 API 的接入。
⛓️ 本项目实现原理如下图所示,过程包括加载文件 -> 读取文本 -> 文本分割 -> 文本向量化 -> 问句向量化 -> 在文本向量中匹配出与问句向量最相似的 `top k`个 -> 匹配出的文本作为上下文和问题一起添加到 `prompt`中 -> 提交给 `LLM`生成回答。 ⛓️ 本项目实现原理如下图所示,过程包括加载文件 -> 读取文本 -> 文本分割 -> 文本向量化 -> 问句向量化 ->
在文本向量中匹配出与问句向量最相似的 `top k`个 -> 匹配出的文本作为上下文和问题一起添加到 `prompt`中 -> 提交给 `LLM`生成回答。
📺 [原理介绍视频](https://www.bilibili.com/video/BV13M4y1e7cN/?share_source=copy_web&vd_source=e6c5aafe684f30fbe41925d61ca6d514) 📺 [原理介绍视频](https://www.bilibili.com/video/BV13M4y1e7cN/?share_source=copy_web&vd_source=e6c5aafe684f30fbe41925d61ca6d514)
@ -43,7 +52,8 @@
🚩 本项目未涉及微调、训练过程,但可利用微调或训练对本项目效果进行优化。 🚩 本项目未涉及微调、训练过程,但可利用微调或训练对本项目效果进行优化。
🌐 [AutoDL 镜像](https://www.codewithgpu.com/i/chatchat-space/Langchain-Chatchat/Langchain-Chatchat) 中 `v13` 版本所使用代码已更新至本项目 `v0.2.9` 版本。 🌐 [AutoDL 镜像](https://www.codewithgpu.com/i/chatchat-space/Langchain-Chatchat/Langchain-Chatchat) 中 `v13`
版本所使用代码已更新至本项目 `v0.2.9` 版本。
🐳 [Docker 镜像](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.6) 已经更新到 ```0.2.7``` 版本。 🐳 [Docker 镜像](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.6) 已经更新到 ```0.2.7``` 版本。
@ -53,7 +63,10 @@
docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.7 docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.7
``` ```
🧩 本项目有一个非常完整的 [Wiki](https://github.com/chatchat-space/Langchain-Chatchat/wiki/)，README 只是一个简单的介绍，__仅仅是入门教程，能够基础运行__。如果你想要更深入地了解本项目，或者想对本项目做出贡献，请移步 [Wiki](https://github.com/chatchat-space/Langchain-Chatchat/wiki/) 界面。
## 解决的痛点 ## 解决的痛点
@ -63,17 +76,19 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
我们支持市面上主流的本地大语言模型和Embedding模型支持开源的本地向量数据库。 我们支持市面上主流的本地大语言模型和Embedding模型支持开源的本地向量数据库。
支持列表详见[Wiki](https://github.com/chatchat-space/Langchain-Chatchat/wiki/) 支持列表详见[Wiki](https://github.com/chatchat-space/Langchain-Chatchat/wiki/)
## 快速上手 ## 快速上手
### 1. 环境配置 ### 1. 环境配置
+ 首先,确保你的机器安装了 Python 3.8 - 3.11 + 首先,确保你的机器安装了 Python 3.8 - 3.11 (我们强烈推荐使用 Python3.11)。
``` ```
$ python --version $ python --version
Python 3.11.7 Python 3.11.7
``` ```
接着,创建一个虚拟环境,并在虚拟环境内安装项目的依赖 接着,创建一个虚拟环境,并在虚拟环境内安装项目的依赖
```shell ```shell
# 拉取仓库 # 拉取仓库
@ -89,33 +104,44 @@ $ pip install -r requirements_webui.txt
# 默认依赖包括基本运行环境FAISS向量库。如果要使用 milvus/pg_vector 等向量库,请将 requirements.txt 中相应依赖取消注释再安装。 # 默认依赖包括基本运行环境FAISS向量库。如果要使用 milvus/pg_vector 等向量库,请将 requirements.txt 中相应依赖取消注释再安装。
``` ```
请注意LangChain-Chatchat `0.2.x` 系列是针对 Langchain `0.0.x` 系列版本的,如果你使用的是 Langchain `0.1.x` 系列版本,需要降级。
请注意LangChain-Chatchat `0.2.x` 系列是针对 Langchain `0.0.x` 系列版本的,如果你使用的是 Langchain `0.1.x`
系列版本,需要降级您的`Langchain`版本。
### 2 模型下载 ### 2 模型下载
如需在本地或离线环境下运行本项目,需要首先将项目所需的模型下载至本地,通常开源 LLM 与 Embedding 模型可以从 [HuggingFace](https://huggingface.co/models) 下载。 如需在本地或离线环境下运行本项目,需要首先将项目所需的模型下载至本地,通常开源 LLM 与 Embedding
模型可以从 [HuggingFace](https://huggingface.co/models) 下载。
以本项目中默认使用的 LLM 模型 [THUDM/ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b) 与 Embedding 模型 [BAAI/bge-large-zh](https://huggingface.co/BAAI/bge-large-zh) 为例: 以本项目中默认使用的 LLM 模型 [THUDM/ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b) 与 Embedding
模型 [BAAI/bge-large-zh](https://huggingface.co/BAAI/bge-large-zh) 为例:
下载模型需要先[安装 Git LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage),然后运行 下载模型需要先[安装 Git LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage)
,然后运行
```Shell ```Shell
$ git lfs install $ git lfs install
$ git clone https://huggingface.co/THUDM/chatglm3-6b $ git clone https://huggingface.co/THUDM/chatglm3-6b
$ git clone https://huggingface.co/BAAI/bge-large-zh $ git clone https://huggingface.co/BAAI/bge-large-zh
``` ```
### 3. 初始化知识库和配置文件 ### 3. 初始化知识库和配置文件
按照下列方式初始化自己的知识库和简单的复制配置文件 按照下列方式初始化自己的知识库和简单的复制配置文件
```shell ```shell
$ python copy_config_example.py $ python copy_config_example.py
$ python init_database.py --recreate-vs $ python init_database.py --recreate-vs
``` ```
### 4. 一键启动 ### 4. 一键启动
按照以下命令启动项目 按照以下命令启动项目
```shell ```shell
$ python startup.py -a $ python startup.py -a
``` ```
### 5. 启动界面示例 ### 5. 启动界面示例
如果正常启动,你将能看到以下界面 如果正常启动,你将能看到以下界面
@ -134,19 +160,32 @@ $ python startup.py -a
![](img/init_knowledge_base.jpg) ![](img/init_knowledge_base.jpg)
### 注意 ### 注意
以上方式只是为了快速上手,如果需要更多的功能和自定义启动方式 ,请参考[Wiki](https://github.com/chatchat-space/Langchain-Chatchat/wiki/) 以上方式只是为了快速上手,如果需要更多的功能和自定义启动方式
,请参考[Wiki](https://github.com/chatchat-space/Langchain-Chatchat/wiki/)
--- ---
## 项目里程碑 ## 项目里程碑
+ `2023年4月`: `Langchain-ChatGLM 0.1.0` 发布,支持基于 ChatGLM-6B 模型的本地知识库问答。
+ `2023年8月`: `Langchain-ChatGLM` 改名为 `Langchain-Chatchat`，`0.2.0` 发布，使用 `fastchat` 作为模型加载方案，支持更多的模型和数据库。
+ `2023年10月`: `Langchain-Chatchat 0.2.5` 发布,推出 Agent 内容,开源项目在`Founder Park & Zhipu AI & Zilliz`
举办的黑客马拉松获得三等奖。
+ `2023年12月`: `Langchain-Chatchat` 开源项目获得超过 **20K** stars.
+ `2024年1月`: `LangChain 0.1.x` 推出,`Langchain-Chatchat 0.2.x` 发布稳定版本`0.2.10`
后将停止更新和技术支持,全力研发具有更强应用性的 `Langchain-Chatchat 0.3.x`
+ 🔥 让我们一起期待未来 Chatchat 的故事 ···
--- ---
## 联系我们 ## 联系我们
### Telegram ### Telegram
[![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white "langchain-chatglm")](https://t.me/+RjliQ3jnJ1YyN2E9) [![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white "langchain-chatglm")](https://t.me/+RjliQ3jnJ1YyN2E9)
### 项目交流群 ### 项目交流群
@ -158,4 +197,4 @@ $ python startup.py -a
<img src="img/official_wechat_mp_account.png" alt="二维码" width="300" /> <img src="img/official_wechat_mp_account.png" alt="二维码" width="300" />
🎉 Langchain-Chatchat 项目官方公众号,欢迎扫码关注。 🎉 Langchain-Chatchat 项目官方公众号,欢迎扫码关注。

View File: README_en.md

@ -8,6 +8,10 @@
A LLM application aims to implement knowledge and search engine based QA based on Langchain and open-source or remote A LLM application aims to implement knowledge and search engine based QA based on Langchain and open-source or remote
LLM API. LLM API.
⚠️ `0.2.10` will be the last version of the `0.2.x` series. The `0.2.x` series will stop receiving updates and technical support, and the team will focus on developing `Langchain-Chatchat 0.3.x`, which has stronger applicability.
--- ---
## Table of Contents ## Table of Contents
@ -25,7 +29,8 @@ LLM API.
## Introduction ## Introduction
🤖️ A Q&A application based on local knowledge base implemented using the idea 🤖️ A Q&A application based on local knowledge base implemented using the idea
of [langchain](https://github.com/hwchase17/langchain). The goal is to build a KBQA(Knowledge based Q&A) solution that of [langchain](https://github.com/langchain-ai/langchain). The goal is to build a KBQA(Knowledge based Q&A) solution
that
is friendly to Chinese scenarios and open source models and can run both offline and online. is friendly to Chinese scenarios and open source models and can run both offline and online.
💡 Inspired by [document.ai](https://github.com/GanymedeNil/document.ai) 💡 Inspired by [document.ai](https://github.com/GanymedeNil/document.ai)
@ -56,10 +61,9 @@ The main process analysis from the aspect of document process:
🚩 The training or fine-tuning are not involved in the project, but still, one always can improve performance by do 🚩 The training or fine-tuning are not involved in the project, but still, one always can improve performance by do
these. these.
🌐 [AutoDL image](https://www.codewithgpu.com/i/chatchat-space/Langchain-Chatchat/Langchain-Chatchat) is supported, and in `v13` the code has been updated to `v0.2.9`.
🐳 [Docker image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.7) has been updated to `0.2.7`.
## Pain Points Addressed ## Pain Points Addressed
@ -99,7 +103,9 @@ $ pip install -r requirements_webui.txt
# 默认依赖包括基本运行环境FAISS向量库。如果要使用 milvus/pg_vector 等向量库,请将 requirements.txt 中相应依赖取消注释再安装。 # 默认依赖包括基本运行环境FAISS向量库。如果要使用 milvus/pg_vector 等向量库,请将 requirements.txt 中相应依赖取消注释再安装。
``` ```
Please note that the Langchain-Chatchat `0.2.x` series is built against the Langchain `0.0.x` series. If you are using a Langchain `0.1.x` version, you need to downgrade.
### Model Download ### Model Download
@ -159,15 +165,33 @@ please refer to the [Wiki](https://github.com/chatchat-space/Langchain-Chatchat/
--- ---
## Project Milestones
+ `April 2023`: `Langchain-ChatGLM 0.1.0` released, supporting local knowledge base question and answer based on the
ChatGLM-6B model.
+ `August 2023`: `Langchain-ChatGLM` was renamed to `Langchain-Chatchat`, `0.2.0` was released, using `fastchat` as the
model loading solution, supporting more models and databases.
+ `October 2023`: `Langchain-Chatchat 0.2.5` was released, Agent content was launched, and the open source project won
the third prize in the hackathon held by `Founder Park & Zhipu AI & Zilliz`.
+ `December 2023`: The `Langchain-Chatchat` open source project received more than **20K** stars.
+ `January 2024`: `LangChain 0.1.x` was launched and `Langchain-Chatchat 0.2.x` was released. After the stable
version `0.2.10` is released, updates and technical support will stop, and all efforts will go into
developing `Langchain-Chatchat 0.3.x`, which has stronger applicability.
+ 🔥 Let's look forward to the future Chatchat stories together...
---
## Contact Us ## Contact Us
### Telegram ### Telegram
[![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white "langchain-chatglm")](https://t.me/+RjliQ3jnJ1YyN2E9) [![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white "langchain-chatglm")](https://t.me/+RjliQ3jnJ1YyN2E9)
### WeChat Group、 ### WeChat Group
<img src="img/qr_code_67.jpg" alt="二维码" width="300" height="300" /> <img src="img/qr_code_87.jpg" alt="二维码" width="300" height="300" />
### WeChat Official Account ### WeChat Official Account

View File: configs/__init__.py

@ -5,4 +5,4 @@ from .server_config import *
from .prompt_config import * from .prompt_config import *
VERSION = "v0.2.9" VERSION = "v0.2.10"

View File: configs/kb_config.py

@ -21,10 +21,9 @@ OVERLAP_SIZE = 50
# 知识库匹配向量数量 # 知识库匹配向量数量
VECTOR_SEARCH_TOP_K = 3 VECTOR_SEARCH_TOP_K = 3
# 知识库匹配的距离阈值取值范围在0-1之间SCORE越小距离越小从而相关度越高 # 知识库匹配的距离阈值一般取值范围在0-1之间SCORE越小距离越小从而相关度越高。
# 取到1相当于不筛选实测bge-large的距离得分大部分在0.01-0.7之间, # 但有用户报告遇到过匹配分值超过1的情况为了兼容性默认设为1在WEBUI中调整范围为0-2
# 相似文本的得分最高在0.55左右因此建议针对bge设置得分为0.6 SCORE_THRESHOLD = 1.0
SCORE_THRESHOLD = 0.6
# 默认搜索引擎。可选bing, duckduckgo, metaphor # 默认搜索引擎。可选bing, duckduckgo, metaphor
DEFAULT_SEARCH_ENGINE = "duckduckgo" DEFAULT_SEARCH_ENGINE = "duckduckgo"
@ -49,12 +48,17 @@ BING_SUBSCRIPTION_KEY = ""
# metaphor搜索需要KEY # metaphor搜索需要KEY
METAPHOR_API_KEY = "" METAPHOR_API_KEY = ""
# 心知天气 API KEY用于天气Agent。申请https://www.seniverse.com/
SENIVERSE_API_KEY = ""
# 是否开启中文标题加强,以及标题增强的相关配置 # 是否开启中文标题加强,以及标题增强的相关配置
# 通过增加标题判断判断哪些文本为标题并在metadata中进行标记 # 通过增加标题判断判断哪些文本为标题并在metadata中进行标记
# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。 # 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
ZH_TITLE_ENHANCE = False ZH_TITLE_ENHANCE = False
# PDF OCR 控制:只对宽高超过页面一定比例(图片宽/页面宽,图片高/页面高)的图片进行 OCR。
# 这样可以避免 PDF 中一些小图片的干扰,提高非扫描版 PDF 处理速度
PDF_OCR_THRESHOLD = (0.6, 0.6)
# 每个知识库的初始化介绍用于在初始化知识库时显示和Agent调用没写则没有介绍不会被Agent调用。 # 每个知识库的初始化介绍用于在初始化知识库时显示和Agent调用没写则没有介绍不会被Agent调用。
KB_INFO = { KB_INFO = {

View File: configs/model_config.py

@ -6,9 +6,9 @@ import os
MODEL_ROOT_PATH = "" MODEL_ROOT_PATH = ""
# 选用的 Embedding 名称 # 选用的 Embedding 名称
EMBEDDING_MODEL = "bge-large-zh" EMBEDDING_MODEL = "bge-large-zh-v1.5"
# Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。 # Embedding 模型运行设备。设为 "auto" 会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。
EMBEDDING_DEVICE = "auto" EMBEDDING_DEVICE = "auto"
# 选用的reranker模型 # 选用的reranker模型
@ -26,44 +26,33 @@ EMBEDDING_MODEL_OUTPUT_PATH = "output"
# 在这里我们使用目前主流的两个离线模型其中chatglm3-6b 为默认加载模型。 # 在这里我们使用目前主流的两个离线模型其中chatglm3-6b 为默认加载模型。
# 如果你的显存不足,可使用 Qwen-1_8B-Chat, 该模型 FP16 仅需 3.8G显存。 # 如果你的显存不足,可使用 Qwen-1_8B-Chat, 该模型 FP16 仅需 3.8G显存。
# chatglm3-6b输出角色标签<|user|>及自问自答的问题详见项目wiki->常见问题->Q20. LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"]
LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"] # "Qwen-1_8B-Chat",
# AgentLM模型的名称 (可以不指定指定之后就锁定进入Agent之后的Chain的模型不指定就是LLM_MODELS[0])
Agent_MODEL = None Agent_MODEL = None
# LLM 运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。 # LLM 模型运行设备。设为"auto"会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。
LLM_DEVICE = "auto" LLM_DEVICE = "auto"
# 历史对话轮数
HISTORY_LEN = 3 HISTORY_LEN = 3
# 大模型最长支持的长度,如果不填写,则使用模型默认的最大长度,如果填写,则为用户设定的最大长度 MAX_TOKENS = 2048
MAX_TOKENS = None
# LLM通用对话参数
TEMPERATURE = 0.7 TEMPERATURE = 0.7
# TOP_P = 0.95 # ChatOpenAI暂不支持该参数
ONLINE_LLM_MODEL = { ONLINE_LLM_MODEL = {
# 线上模型。请在server_config中为每个在线API设置不同的端口
"openai-api": { "openai-api": {
"model_name": "gpt-3.5-turbo", "model_name": "gpt-4",
"api_base_url": "https://api.openai.com/v1", "api_base_url": "https://api.openai.com/v1",
"api_key": "", "api_key": "",
"openai_proxy": "", "openai_proxy": "",
}, },
# 具体注册及api key获取请前往 http://open.bigmodel.cn # 智谱AI API,具体注册及api key获取请前往 http://open.bigmodel.cn
"zhipu-api": { "zhipu-api": {
"api_key": "", "api_key": "",
"version": "chatglm_turbo", # 可选包括 "chatglm_turbo" "version": "glm-4",
"provider": "ChatGLMWorker", "provider": "ChatGLMWorker",
}, },
# 具体注册及api key获取请前往 https://api.minimax.chat/ # 具体注册及api key获取请前往 https://api.minimax.chat/
"minimax-api": { "minimax-api": {
"group_id": "", "group_id": "",
@ -72,13 +61,12 @@ ONLINE_LLM_MODEL = {
"provider": "MiniMaxWorker", "provider": "MiniMaxWorker",
}, },
# 具体注册及api key获取请前往 https://xinghuo.xfyun.cn/ # 具体注册及api key获取请前往 https://xinghuo.xfyun.cn/
"xinghuo-api": { "xinghuo-api": {
"APPID": "", "APPID": "",
"APISecret": "", "APISecret": "",
"api_key": "", "api_key": "",
"version": "v1.5", # 你使用的讯飞星火大模型版本,可选包括 "v3.0", "v1.5", "v2.0" "version": "v3.0", # 你使用的讯飞星火大模型版本,可选包括 "v3.0", "v2.0", "v1.5"
"provider": "XingHuoWorker", "provider": "XingHuoWorker",
}, },
@ -93,8 +81,8 @@ ONLINE_LLM_MODEL = {
# 火山方舟 API文档参考 https://www.volcengine.com/docs/82379 # 火山方舟 API文档参考 https://www.volcengine.com/docs/82379
"fangzhou-api": { "fangzhou-api": {
"version": "chatglm-6b-model", # 当前支持 "chatglm-6b-model" 更多的见文档模型支持列表中方舟部分。 "version": "chatglm-6b-model",
"version_url": "", # 可以不填写version直接填写在方舟申请模型发布的API地址 "version_url": "",
"api_key": "", "api_key": "",
"secret_key": "", "secret_key": "",
"provider": "FangZhouWorker", "provider": "FangZhouWorker",
@ -102,15 +90,15 @@ ONLINE_LLM_MODEL = {
# 阿里云通义千问 API文档参考 https://help.aliyun.com/zh/dashscope/developer-reference/api-details # 阿里云通义千问 API文档参考 https://help.aliyun.com/zh/dashscope/developer-reference/api-details
"qwen-api": { "qwen-api": {
"version": "qwen-turbo", # 可选包括 "qwen-turbo", "qwen-plus" "version": "qwen-max",
"api_key": "", # 请在阿里云控制台模型服务灵积API-KEY管理页面创建 "api_key": "",
"provider": "QwenWorker", "provider": "QwenWorker",
"embed_model": "text-embedding-v1" # embedding 模型名称 "embed_model": "text-embedding-v1" # embedding 模型名称
}, },
# 百川 API申请方式请参考 https://www.baichuan-ai.com/home#api-enter # 百川 API申请方式请参考 https://www.baichuan-ai.com/home#api-enter
"baichuan-api": { "baichuan-api": {
"version": "Baichuan2-53B", # 当前支持 "Baichuan2-53B" 见官方文档。 "version": "Baichuan2-53B",
"api_key": "", "api_key": "",
"secret_key": "", "secret_key": "",
"provider": "BaiChuanWorker", "provider": "BaiChuanWorker",
@ -132,6 +120,11 @@ ONLINE_LLM_MODEL = {
"secret_key": "", "secret_key": "",
"provider": "TianGongWorker", "provider": "TianGongWorker",
}, },
# Gemini API https://makersuite.google.com/app/apikey
"gemini-api": {
"api_key": "",
"provider": "GeminiWorker",
}
} }
@ -143,6 +136,7 @@ ONLINE_LLM_MODEL = {
# - GanymedeNil/text2vec-large-chinese # - GanymedeNil/text2vec-large-chinese
# - text2vec-large-chinese # - text2vec-large-chinese
# 2.2 如果以上本地路径不存在则使用huggingface模型 # 2.2 如果以上本地路径不存在则使用huggingface模型
MODEL_PATH = { MODEL_PATH = {
"embed_model": { "embed_model": {
"ernie-tiny": "nghuyong/ernie-3.0-nano-zh", "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
@ -169,55 +163,59 @@ MODEL_PATH = {
}, },
"llm_model": { "llm_model": {
# 以下部分模型并未完全测试仅根据fastchat和vllm模型的模型列表推定支持
"chatglm2-6b": "THUDM/chatglm2-6b", "chatglm2-6b": "THUDM/chatglm2-6b",
"chatglm2-6b-32k": "THUDM/chatglm2-6b-32k", "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
"chatglm3-6b": "THUDM/chatglm3-6b", "chatglm3-6b": "THUDM/chatglm3-6b",
"chatglm3-6b-32k": "THUDM/chatglm3-6b-32k", "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
"chatglm3-6b-base": "THUDM/chatglm3-6b-base",
"Qwen-1_8B": "Qwen/Qwen-1_8B", "Orion-14B-Chat": "OrionStarAI/Orion-14B-Chat",
"Orion-14B-Chat-Plugin": "OrionStarAI/Orion-14B-Chat-Plugin",
"Orion-14B-LongChat": "OrionStarAI/Orion-14B-LongChat",
"Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
"Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
"Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",
"Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat", "Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
"Qwen-1_8B-Chat-Int8": "Qwen/Qwen-1_8B-Chat-Int8",
"Qwen-1_8B-Chat-Int4": "Qwen/Qwen-1_8B-Chat-Int4",
"Qwen-7B": "Qwen/Qwen-7B",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat", "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B": "Qwen/Qwen-14B",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat", "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
"Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8",
# 在新版的transformers下需要手动修改模型的config.json文件在quantization_config字典中
# 增加`disable_exllama:true` 字段才能启动qwen的量化模型
"Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4",
"Qwen-72B": "Qwen/Qwen-72B",
"Qwen-72B-Chat": "Qwen/Qwen-72B-Chat", "Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
"Qwen-72B-Chat-Int8": "Qwen/Qwen-72B-Chat-Int8",
"Qwen-72B-Chat-Int4": "Qwen/Qwen-72B-Chat-Int4",
"baichuan2-13b": "baichuan-inc/Baichuan2-13B-Chat", "baichuan-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
"baichuan2-7b": "baichuan-inc/Baichuan2-7B-Chat",
"baichuan-7b": "baichuan-inc/Baichuan-7B",
"baichuan-13b": "baichuan-inc/Baichuan-13B",
"baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat", "baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
"baichuan2-7b-chat": "baichuan-inc/Baichuan2-7B-Chat",
"aquila-7b": "BAAI/Aquila-7B", "baichuan2-13b-chat": "baichuan-inc/Baichuan2-13B-Chat",
"aquilachat-7b": "BAAI/AquilaChat-7B",
"internlm-7b": "internlm/internlm-7b", "internlm-7b": "internlm/internlm-7b",
"internlm-chat-7b": "internlm/internlm-chat-7b", "internlm-chat-7b": "internlm/internlm-chat-7b",
"internlm2-chat-7b": "internlm/internlm2-chat-7b",
"internlm2-chat-20b": "internlm/internlm2-chat-20b",
"BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat",
"BlueLM-7B-Chat-32k": "vivo-ai/BlueLM-7B-Chat-32k",
"Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat",
"agentlm-7b": "THUDM/agentlm-7b",
"agentlm-13b": "THUDM/agentlm-13b",
"agentlm-70b": "THUDM/agentlm-70b",
"falcon-7b": "tiiuae/falcon-7b", "falcon-7b": "tiiuae/falcon-7b",
"falcon-40b": "tiiuae/falcon-40b", "falcon-40b": "tiiuae/falcon-40b",
"falcon-rw-7b": "tiiuae/falcon-rw-7b", "falcon-rw-7b": "tiiuae/falcon-rw-7b",
"aquila-7b": "BAAI/Aquila-7B",
"aquilachat-7b": "BAAI/AquilaChat-7B",
"open_llama_13b": "openlm-research/open_llama_13b",
"vicuna-13b-v1.5": "lmsys/vicuna-13b-v1.5",
"koala": "young-geng/koala",
"mpt-7b": "mosaicml/mpt-7b",
"mpt-7b-storywriter": "mosaicml/mpt-7b-storywriter",
"mpt-30b": "mosaicml/mpt-30b",
"opt-66b": "facebook/opt-66b",
"opt-iml-max-30b": "facebook/opt-iml-max-30b",
"gpt2": "gpt2", "gpt2": "gpt2",
"gpt2-xl": "gpt2-xl", "gpt2-xl": "gpt2-xl",
"gpt-j-6b": "EleutherAI/gpt-j-6b", "gpt-j-6b": "EleutherAI/gpt-j-6b",
"gpt4all-j": "nomic-ai/gpt4all-j", "gpt4all-j": "nomic-ai/gpt4all-j",
"gpt-neox-20b": "EleutherAI/gpt-neox-20b", "gpt-neox-20b": "EleutherAI/gpt-neox-20b",
@ -225,63 +223,51 @@ MODEL_PATH = {
"oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5", "oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
"dolly-v2-12b": "databricks/dolly-v2-12b", "dolly-v2-12b": "databricks/dolly-v2-12b",
"stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b", "stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b",
"Llama-2-13b-hf": "meta-llama/Llama-2-13b-hf",
"Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
"open_llama_13b": "openlm-research/open_llama_13b",
"vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3",
"koala": "young-geng/koala",
"mpt-7b": "mosaicml/mpt-7b",
"mpt-7b-storywriter": "mosaicml/mpt-7b-storywriter",
"mpt-30b": "mosaicml/mpt-30b",
"opt-66b": "facebook/opt-66b",
"opt-iml-max-30b": "facebook/opt-iml-max-30b",
"agentlm-7b": "THUDM/agentlm-7b",
"agentlm-13b": "THUDM/agentlm-13b",
"agentlm-70b": "THUDM/agentlm-70b",
"Yi-34B-Chat": "01-ai/Yi-34B-Chat",
}, },
"reranker":{
"bge-reranker-large":"BAAI/bge-reranker-large", "reranker": {
"bge-reranker-base":"BAAI/bge-reranker-base", "bge-reranker-large": "BAAI/bge-reranker-large",
#TODO 增加在线reranker如cohere "bge-reranker-base": "BAAI/bge-reranker-base",
} }
} }
# 通常情况下不需要更改以下内容 # 通常情况下不需要更改以下内容
# nltk 模型存储路径 # nltk 模型存储路径
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data") NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
# 使用VLLM可能导致模型推理能力下降无法完成Agent任务
VLLM_MODEL_DICT = { VLLM_MODEL_DICT = {
"aquila-7b": "BAAI/Aquila-7B",
"aquilachat-7b": "BAAI/AquilaChat-7B",
"baichuan-7b": "baichuan-inc/Baichuan-7B",
"baichuan-13b": "baichuan-inc/Baichuan-13B",
"baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
"chatglm2-6b": "THUDM/chatglm2-6b", "chatglm2-6b": "THUDM/chatglm2-6b",
"chatglm2-6b-32k": "THUDM/chatglm2-6b-32k", "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
"chatglm3-6b": "THUDM/chatglm3-6b", "chatglm3-6b": "THUDM/chatglm3-6b",
"chatglm3-6b-32k": "THUDM/chatglm3-6b-32k", "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
"Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
"Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
"Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",
"Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
"Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
"baichuan-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
"baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
"baichuan2-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
"baichuan2-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
"BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat", "BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat",
"BlueLM-7B-Chat-32k": "vivo-ai/BlueLM-7B-Chat-32k", "BlueLM-7B-Chat-32k": "vivo-ai/BlueLM-7B-Chat-32k",
# 注意bloom系列的tokenizer与model是分离的因此虽然vllm支持但与fschat框架不兼容
# "bloom": "bigscience/bloom",
# "bloomz": "bigscience/bloomz",
# "bloomz-560m": "bigscience/bloomz-560m",
# "bloomz-7b1": "bigscience/bloomz-7b1",
# "bloomz-1b7": "bigscience/bloomz-1b7",
"internlm-7b": "internlm/internlm-7b", "internlm-7b": "internlm/internlm-7b",
"internlm-chat-7b": "internlm/internlm-chat-7b", "internlm-chat-7b": "internlm/internlm-chat-7b",
"internlm2-chat-7b": "internlm/Models/internlm2-chat-7b",
"internlm2-chat-20b": "internlm/Models/internlm2-chat-20b",
"aquila-7b": "BAAI/Aquila-7B",
"aquilachat-7b": "BAAI/AquilaChat-7B",
"falcon-7b": "tiiuae/falcon-7b", "falcon-7b": "tiiuae/falcon-7b",
"falcon-40b": "tiiuae/falcon-40b", "falcon-40b": "tiiuae/falcon-40b",
"falcon-rw-7b": "tiiuae/falcon-rw-7b", "falcon-rw-7b": "tiiuae/falcon-rw-7b",
@ -294,8 +280,6 @@ VLLM_MODEL_DICT = {
"oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5", "oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
"dolly-v2-12b": "databricks/dolly-v2-12b", "dolly-v2-12b": "databricks/dolly-v2-12b",
"stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b", "stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b",
"Llama-2-13b-hf": "meta-llama/Llama-2-13b-hf",
"Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
"open_llama_13b": "openlm-research/open_llama_13b", "open_llama_13b": "openlm-research/open_llama_13b",
"vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3", "vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3",
"koala": "young-geng/koala", "koala": "young-geng/koala",
@ -305,37 +289,14 @@ VLLM_MODEL_DICT = {
"opt-66b": "facebook/opt-66b", "opt-66b": "facebook/opt-66b",
"opt-iml-max-30b": "facebook/opt-iml-max-30b", "opt-iml-max-30b": "facebook/opt-iml-max-30b",
"Qwen-1_8B": "Qwen/Qwen-1_8B",
"Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
"Qwen-1_8B-Chat-Int8": "Qwen/Qwen-1_8B-Chat-Int8",
"Qwen-1_8B-Chat-Int4": "Qwen/Qwen-1_8B-Chat-Int4",
"Qwen-7B": "Qwen/Qwen-7B",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B": "Qwen/Qwen-14B",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
"Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8",
"Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4",
"Qwen-72B": "Qwen/Qwen-72B",
"Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
"Qwen-72B-Chat-Int8": "Qwen/Qwen-72B-Chat-Int8",
"Qwen-72B-Chat-Int4": "Qwen/Qwen-72B-Chat-Int4",
"agentlm-7b": "THUDM/agentlm-7b",
"agentlm-13b": "THUDM/agentlm-13b",
"agentlm-70b": "THUDM/agentlm-70b",
} }
# 你认为支持Agent能力的模型可以在这里添加添加后不会出现可视化界面的警告
# 经过我们测试原生支持Agent的模型仅有以下几个
SUPPORT_AGENT_MODEL = [
    "azure-api",
    "openai-api",
    "qwen-api",
    "Qwen",
    "chatglm3",
    "xinghuo-api",
]
SUPPORT_AGENT_MODEL = [
    "openai-api",             # GPT4 模型
    "qwen-api",               # Qwen Max模型
    "zhipu-api",              # 智谱AI GLM4模型
    "Qwen",                   # 所有Qwen系列本地模型
    "chatglm3-6b",
    "internlm2-chat-20b",
    "Orion-14B-Chat-Plugin",
]

View File: configs/server_config.py

@ -40,8 +40,6 @@ FSCHAT_MODEL_WORKERS = {
"device": LLM_DEVICE, "device": LLM_DEVICE,
# False,'vllm',使用的推理加速框架,使用vllm如果出现HuggingFace通信问题参见doc/FAQ # False,'vllm',使用的推理加速框架,使用vllm如果出现HuggingFace通信问题参见doc/FAQ
# vllm对一些模型支持还不成熟暂时默认关闭 # vllm对一些模型支持还不成熟暂时默认关闭
# fschat=0.2.33的代码有bug, 如需使用源码修改fastchat.server.vllm_worker
# 将103行中sampling_params = SamplingParams的参数stop=list(stop)修改为stop= [i for i in stop if i!=""]
"infer_turbo": False, "infer_turbo": False,
# model_worker多卡加载需要配置的参数 # model_worker多卡加载需要配置的参数
@ -92,11 +90,10 @@ FSCHAT_MODEL_WORKERS = {
# 'disable_log_requests': False # 'disable_log_requests': False
}, },
# 可以如下示例方式更改默认配置
# "Qwen-1_8B-Chat": {  # 使用default中的IP和端口
#     "device": "cpu",
# },
"chatglm3-6b": {  # 使用default中的IP和端口
    "device": "cuda",
},
"Qwen-1_8B-Chat": {
    "device": "cpu",
},
"chatglm3-6b": {
    "device": "cuda",
},
@ -128,14 +125,11 @@ FSCHAT_MODEL_WORKERS = {
"tiangong-api": { "tiangong-api": {
"port": 21009, "port": 21009,
}, },
"gemini-api": {
"port": 21010,
},
} }
# fastchat multi model worker server
FSCHAT_MULTI_MODEL_WORKERS = {
# TODO:
}
# fastchat controller server
FSCHAT_CONTROLLER = { FSCHAT_CONTROLLER = {
"host": DEFAULT_BIND_HOST, "host": DEFAULT_BIND_HOST,
"port": 20001, "port": 20001,

View File: document_loaders/__init__.py

@ -1,2 +1,4 @@
from .mypdfloader import RapidOCRPDFLoader from .mypdfloader import RapidOCRPDFLoader
from .myimgloader import RapidOCRLoader from .myimgloader import RapidOCRLoader
from .mydocloader import RapidOCRDocLoader
from .mypptloader import RapidOCRPPTLoader

View File: document_loaders/mydocloader.py

@ -0,0 +1,71 @@
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from typing import List
import tqdm
class RapidOCRDocLoader(UnstructuredFileLoader):
def _get_elements(self) -> List:
def doc2text(filepath):
from docx.table import _Cell, Table
from docx.oxml.table import CT_Tbl
from docx.oxml.text.paragraph import CT_P
from docx.text.paragraph import Paragraph
from docx import Document, ImagePart
from PIL import Image
from io import BytesIO
import numpy as np
from rapidocr_onnxruntime import RapidOCR
ocr = RapidOCR()
doc = Document(filepath)
resp = ""
def iter_block_items(parent):
from docx.document import Document
if isinstance(parent, Document):
parent_elm = parent.element.body
elif isinstance(parent, _Cell):
parent_elm = parent._tc
else:
raise ValueError("RapidOCRDocLoader parse fail")
for child in parent_elm.iterchildren():
if isinstance(child, CT_P):
yield Paragraph(child, parent)
elif isinstance(child, CT_Tbl):
yield Table(child, parent)
b_unit = tqdm.tqdm(total=len(doc.paragraphs)+len(doc.tables),
desc="RapidOCRDocLoader block index: 0")
for i, block in enumerate(iter_block_items(doc)):
b_unit.set_description(
"RapidOCRDocLoader block index: {}".format(i))
b_unit.refresh()
if isinstance(block, Paragraph):
resp += block.text.strip() + "\n"
images = block._element.xpath('.//pic:pic') # 获取所有图片
for image in images:
for img_id in image.xpath('.//a:blip/@r:embed'): # 获取图片id
part = doc.part.related_parts[img_id] # 根据图片id获取对应的图片
if isinstance(part, ImagePart):
image = Image.open(BytesIO(part._blob))
result, _ = ocr(np.array(image))
if result:
ocr_result = [line[1] for line in result]
resp += "\n".join(ocr_result)
elif isinstance(block, Table):
for row in block.rows:
for cell in row.cells:
for paragraph in cell.paragraphs:
resp += paragraph.text.strip() + "\n"
b_unit.update(1)
return resp
text = doc2text(self.file_path)
from unstructured.partition.text import partition_text
return partition_text(text=text, **self.unstructured_kwargs)
if __name__ == '__main__':
loader = RapidOCRDocLoader(file_path="../tests/samples/ocr_test.docx")
docs = loader.load()
print(docs)

View File: document_loaders/mypdfloader.py

@ -1,5 +1,6 @@
from typing import List from typing import List
from langchain.document_loaders.unstructured import UnstructuredFileLoader from langchain.document_loaders.unstructured import UnstructuredFileLoader
from configs import PDF_OCR_THRESHOLD
from document_loaders.ocr import get_ocr from document_loaders.ocr import get_ocr
import tqdm import tqdm
@ -15,23 +16,25 @@ class RapidOCRPDFLoader(UnstructuredFileLoader):
b_unit = tqdm.tqdm(total=doc.page_count, desc="RapidOCRPDFLoader context page index: 0") b_unit = tqdm.tqdm(total=doc.page_count, desc="RapidOCRPDFLoader context page index: 0")
for i, page in enumerate(doc): for i, page in enumerate(doc):
# 更新描述
b_unit.set_description("RapidOCRPDFLoader context page index: {}".format(i)) b_unit.set_description("RapidOCRPDFLoader context page index: {}".format(i))
# 立即显示进度条更新结果
b_unit.refresh() b_unit.refresh()
# TODO: 依据文本与图片顺序调整处理方式
text = page.get_text("") text = page.get_text("")
resp += text + "\n" resp += text + "\n"
img_list = page.get_images() img_list = page.get_image_info(xrefs=True)
for img in img_list: for img in img_list:
pix = fitz.Pixmap(doc, img[0])
img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1)
result, _ = ocr(img_array)
if result:
    ocr_result = [line[1] for line in result]
    resp += "\n".join(ocr_result)
if xref := img.get("xref"):
    bbox = img["bbox"]
    # 检查图片尺寸是否超过设定的阈值
    if ((bbox[2] - bbox[0]) / (page.rect.width) < PDF_OCR_THRESHOLD[0]
            or (bbox[3] - bbox[1]) / (page.rect.height) < PDF_OCR_THRESHOLD[1]):
        continue
    pix = fitz.Pixmap(doc, xref)
    img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1)
    result, _ = ocr(img_array)
    if result:
        ocr_result = [line[1] for line in result]
        resp += "\n".join(ocr_result)
# 更新进度 # 更新进度
b_unit.update(1) b_unit.update(1)

View File: document_loaders/mypptloader.py

@ -0,0 +1,59 @@
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from typing import List
import tqdm
class RapidOCRPPTLoader(UnstructuredFileLoader):
def _get_elements(self) -> List:
def ppt2text(filepath):
from pptx import Presentation
from PIL import Image
import numpy as np
from io import BytesIO
from rapidocr_onnxruntime import RapidOCR
ocr = RapidOCR()
prs = Presentation(filepath)
resp = ""
def extract_text(shape):
nonlocal resp
if shape.has_text_frame:
resp += shape.text.strip() + "\n"
if shape.has_table:
for row in shape.table.rows:
for cell in row.cells:
for paragraph in cell.text_frame.paragraphs:
resp += paragraph.text.strip() + "\n"
if shape.shape_type == 13: # 13 表示图片
image = Image.open(BytesIO(shape.image.blob))
result, _ = ocr(np.array(image))
if result:
ocr_result = [line[1] for line in result]
resp += "\n".join(ocr_result)
elif shape.shape_type == 6: # 6 表示组合
for child_shape in shape.shapes:
extract_text(child_shape)
b_unit = tqdm.tqdm(total=len(prs.slides),
desc="RapidOCRPPTLoader slide index: 1")
# 遍历所有幻灯片
for slide_number, slide in enumerate(prs.slides, start=1):
b_unit.set_description(
"RapidOCRPPTLoader slide index: {}".format(slide_number))
b_unit.refresh()
sorted_shapes = sorted(slide.shapes,
key=lambda x: (x.top, x.left)) # 从上到下、从左到右遍历
for shape in sorted_shapes:
extract_text(shape)
b_unit.update(1)
return resp
text = ppt2text(self.file_path)
from unstructured.partition.text import partition_text
return partition_text(text=text, **self.unstructured_kwargs)
if __name__ == '__main__':
loader = RapidOCRPPTLoader(file_path="../tests/samples/ocr_test.pptx")
docs = loader.load()
print(docs)

View File: embeddings/add_embedding_keywords.py

@ -7,31 +7,35 @@
保存的模型的位置位于原本嵌入模型的目录下模型的名称为原模型名称+Merge_Keywords_时间戳 保存的模型的位置位于原本嵌入模型的目录下模型的名称为原模型名称+Merge_Keywords_时间戳
''' '''
import sys import sys
sys.path.append("..") sys.path.append("..")
import os
import torch
from datetime import datetime from datetime import datetime
from configs import ( from configs import (
MODEL_PATH, MODEL_PATH,
EMBEDDING_MODEL, EMBEDDING_MODEL,
EMBEDDING_KEYWORD_FILE, EMBEDDING_KEYWORD_FILE,
) )
import os
import torch
from safetensors.torch import save_model from safetensors.torch import save_model
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
from langchain_core._api import deprecated
@deprecated(
since="0.3.0",
message="自定义关键词 Langchain-Chatchat 0.3.x 重写, 0.2.x中相关功能将废弃",
removal="0.3.0"
)
def get_keyword_embedding(bert_model, tokenizer, key_words): def get_keyword_embedding(bert_model, tokenizer, key_words):
tokenizer_output = tokenizer(key_words, return_tensors="pt", padding=True, truncation=True) tokenizer_output = tokenizer(key_words, return_tensors="pt", padding=True, truncation=True)
# No need to manually convert to tensor as we've set return_tensors="pt"
input_ids = tokenizer_output['input_ids'] input_ids = tokenizer_output['input_ids']
# Remove the first and last token for each sequence in the batch
input_ids = input_ids[:, 1:-1] input_ids = input_ids[:, 1:-1]
keyword_embedding = bert_model.embeddings.word_embeddings(input_ids) keyword_embedding = bert_model.embeddings.word_embeddings(input_ids)
keyword_embedding = torch.mean(keyword_embedding, 1) keyword_embedding = torch.mean(keyword_embedding, 1)
return keyword_embedding return keyword_embedding
@ -47,14 +51,11 @@ def add_keyword_to_model(model_name=EMBEDDING_MODEL, keyword_file: str = "", out
bert_model = word_embedding_model.auto_model bert_model = word_embedding_model.auto_model
tokenizer = word_embedding_model.tokenizer tokenizer = word_embedding_model.tokenizer
key_words_embedding = get_keyword_embedding(bert_model, tokenizer, key_words) key_words_embedding = get_keyword_embedding(bert_model, tokenizer, key_words)
# key_words_embedding = st_model.encode(key_words)
embedding_weight = bert_model.embeddings.word_embeddings.weight embedding_weight = bert_model.embeddings.word_embeddings.weight
embedding_weight_len = len(embedding_weight) embedding_weight_len = len(embedding_weight)
tokenizer.add_tokens(key_words) tokenizer.add_tokens(key_words)
bert_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32) bert_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32)
# key_words_embedding_tensor = torch.from_numpy(key_words_embedding)
embedding_weight = bert_model.embeddings.word_embeddings.weight embedding_weight = bert_model.embeddings.word_embeddings.weight
with torch.no_grad(): with torch.no_grad():
embedding_weight[embedding_weight_len:embedding_weight_len + key_words_len, :] = key_words_embedding embedding_weight[embedding_weight_len:embedding_weight_len + key_words_len, :] = key_words_embedding
@ -76,46 +77,3 @@ def add_keyword_to_embedding_model(path: str = EMBEDDING_KEYWORD_FILE):
output_model_name = "{}_Merge_Keywords_{}".format(EMBEDDING_MODEL, current_time) output_model_name = "{}_Merge_Keywords_{}".format(EMBEDDING_MODEL, current_time)
output_model_path = os.path.join(model_parent_directory, output_model_name) output_model_path = os.path.join(model_parent_directory, output_model_name)
add_keyword_to_model(model_name, keyword_file, output_model_path) add_keyword_to_model(model_name, keyword_file, output_model_path)
if __name__ == '__main__':
add_keyword_to_embedding_model(EMBEDDING_KEYWORD_FILE)
# input_model_name = ""
# output_model_path = ""
# # 以下为加入关键字前后tokenizer的测试用例对比
# def print_token_ids(output, tokenizer, sentences):
# for idx, ids in enumerate(output['input_ids']):
# print(f'sentence={sentences[idx]}')
# print(f'ids={ids}')
# for id in ids:
# decoded_id = tokenizer.decode(id)
# print(f' {decoded_id}->{id}')
#
# sentences = [
# '数据科学与大数据技术',
# 'Langchain-Chatchat'
# ]
#
# st_no_keywords = SentenceTransformer(input_model_name)
# tokenizer_without_keywords = st_no_keywords.tokenizer
# print("===== tokenizer with no keywords added =====")
# output = tokenizer_without_keywords(sentences)
# print_token_ids(output, tokenizer_without_keywords, sentences)
# print(f'-------- embedding with no keywords added -----')
# embeddings = st_no_keywords.encode(sentences)
# print(embeddings)
#
# print("--------------------------------------------")
# print("--------------------------------------------")
# print("--------------------------------------------")
#
# st_with_keywords = SentenceTransformer(output_model_path)
# tokenizer_with_keywords = st_with_keywords.tokenizer
# print("===== tokenizer with keyword added =====")
# output = tokenizer_with_keywords(sentences)
# print_token_ids(output, tokenizer_with_keywords, sentences)
#
# print(f'-------- embedding with keywords added -----')
# embeddings = st_with_keywords.encode(sentences)
# print(embeddings)

Two binary image files changed (contents not shown): one 272 KiB image removed; one image updated from 218 KiB to 318 KiB.

@ -1 +1 @@
Subproject commit 2f24adb218f23eab00d7fcd7ccf5072f2f35cb3c Subproject commit 28f664aa08f8191a70339c9ecbe7a89b35a1032a

View File: requirements.txt

@ -1,77 +1,66 @@
# API requirements torch==2.1.2
torchvision==0.16.2
torch~=2.1.2 torchaudio==2.1.2
torchvision~=0.16.2
torchaudio~=2.1.2
xformers==0.0.23.post1 xformers==0.0.23.post1
transformers==4.36.2 transformers==4.36.2
sentence_transformers==2.2.2 sentence_transformers==2.2.2
langchain==0.0.354 langchain==0.0.354
langchain-experimental==0.0.47 langchain-experimental==0.0.47
pydantic==1.10.13 pydantic==1.10.13
fschat==0.2.34 fschat==0.2.35
openai~=1.7.1 openai==1.9.0
fastapi~=0.108.0 fastapi==0.109.0
sse_starlette==1.8.2 sse_starlette==1.8.2
nltk>=3.8.1 nltk==3.8.1
uvicorn>=0.24.0.post1 uvicorn==0.24.0.post1
starlette~=0.32.0 starlette==0.35.0
unstructured[all-docs]==0.11.0 unstructured[all-docs] # ==0.11.8
python-magic-bin; sys_platform == 'win32' python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19 SQLAlchemy==2.0.25
faiss-cpu~=1.7.4 # `conda install faiss-gpu -c conda-forge` if you want to accelerate with gpus faiss-cpu==1.7.4
accelerate~=0.24.1 accelerate==0.24.1
spacy~=3.7.2 spacy==3.7.2
PyMuPDF~=1.23.8 PyMuPDF==1.23.16
rapidocr_onnxruntime==1.3.8 rapidocr_onnxruntime==1.3.8
requests~=2.31.0 requests==2.31.0
pathlib~=1.0.1 pathlib==1.0.1
pytest~=7.4.3 pytest==7.4.3
numexpr~=2.8.6 # max version for py38 numexpr==2.8.6
strsimpy~=0.2.1 strsimpy==0.2.1
markdownify~=0.11.6 markdownify==0.11.6
tiktoken~=0.5.2 tiktoken==0.5.2
tqdm>=4.66.1 tqdm==4.66.1
websockets>=12.0 websockets==12.0
numpy~=1.24.4 numpy==1.24.4
pandas~=2.0.3 pandas==2.0.3
einops>=0.7.0 einops==0.7.0
transformers_stream_generator==0.0.4 transformers_stream_generator==0.0.4
vllm==0.2.6; sys_platform == "linux" vllm==0.2.7; sys_platform == "linux"
httpx[brotli,http2,socks]==0.25.2 llama-index==0.9.35
llama-index
# optional document loaders
# rapidocr_paddle[gpu]>=1.3.0.post5 # gpu accelleration for ocr of pdf and image files
jq==1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows
beautifulsoup4~=4.12.2 # for .mhtml files
pysrt~=1.1.2
# Online api libs dependencies
zhipuai==1.0.7 # zhipu
dashscope==1.13.6 # qwen
# volcengine>=1.0.119 # fangzhou
#jq==1.6.0
# beautifulsoup4==4.12.2
# pysrt==1.1.2
# dashscope==1.13.6 # qwen
# volcengine==1.0.119 # fangzhou
# uncomment libs if you want to use corresponding vector store # uncomment libs if you want to use corresponding vector store
# pymilvus>=2.3.4 # pymilvus==2.3.4
# psycopg2==2.9.9 # psycopg2==2.9.9
# pgvector>=0.2.4 # pgvector==0.2.4
#flash-attn==2.4.2 # For Orion-14B-Chat and Qwen-14B-Chat
#autoawq==0.1.8 # For Int4
#rapidocr_paddle[gpu]==1.3.11 # gpu accelleration for ocr of pdf and image files
# Agent and Search Tools arxiv==2.1.0
youtube-search==2.1.2
arxiv~=2.1.0 duckduckgo-search==3.9.9
youtube-search~=2.1.2 metaphor-python==0.1.23
duckduckgo-search~=3.9.9 streamlit==1.30.0
metaphor-python~=0.1.23 streamlit-option-menu==0.3.12
streamlit-antd-components==0.3.1
# WebUI requirements
streamlit~=1.29.0
streamlit-option-menu>=0.3.6
streamlit-chatbox==1.1.11 streamlit-chatbox==1.1.11
streamlit-modal>=0.1.0 streamlit-modal==0.1.0
streamlit-aggrid>=0.3.4.post3 streamlit-aggrid==0.3.4.post3
watchdog>=3.0.0 httpx==0.26.0
watchdog==3.0.0
jwt==1.3.1

View File: requirements_api.txt

@ -1,24 +1,23 @@
torch~=2.1.2 torch~=2.1.2
torchvision~=0.16.2 torchvision~=0.16.2
torchaudio~=2.1.2 torchaudio~=2.1.2
xformers==0.0.23.post1 xformers>=0.0.23.post1
transformers==4.36.2 transformers==4.36.2
sentence_transformers==2.2.2 sentence_transformers==2.2.2
langchain==0.0.354 langchain==0.0.354
langchain-experimental==0.0.47 langchain-experimental==0.0.47
pydantic==1.10.13 pydantic==1.10.13
fschat==0.2.34 fschat==0.2.35
openai~=1.7.1 openai~=1.9.0
fastapi~=0.108.0 fastapi~=0.109.0
sse_starlette==1.8.2 sse_starlette==1.8.2
nltk>=3.8.1 nltk>=3.8.1
uvicorn>=0.24.0.post1 uvicorn>=0.24.0.post1
starlette~=0.32.0 starlette~=0.35.0
unstructured[all-docs]==0.11.0 unstructured[all-docs]==0.11.0
python-magic-bin; sys_platform == 'win32' python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19 SQLAlchemy==2.0.19
faiss-cpu~=1.7.4 # `conda install faiss-gpu -c conda-forge` if you want to accelerate with gpus faiss-cpu~=1.7.4
accelerate~=0.24.1 accelerate~=0.24.1
spacy~=3.7.2 spacy~=3.7.2
PyMuPDF~=1.23.8 PyMuPDF~=1.23.8
@ -26,7 +25,7 @@ rapidocr_onnxruntime==1.3.8
requests~=2.31.0 requests~=2.31.0
pathlib~=1.0.1 pathlib~=1.0.1
pytest~=7.4.3 pytest~=7.4.3
numexpr~=2.8.6 # max version for py38 numexpr~=2.8.6
strsimpy~=0.2.1 strsimpy~=0.2.1
markdownify~=0.11.6 markdownify~=0.11.6
tiktoken~=0.5.2 tiktoken~=0.5.2
@ -36,31 +35,23 @@ numpy~=1.24.4
pandas~=2.0.3 pandas~=2.0.3
einops>=0.7.0 einops>=0.7.0
transformers_stream_generator==0.0.4 transformers_stream_generator==0.0.4
vllm==0.2.6; sys_platform == "linux" vllm==0.2.7; sys_platform == "linux"
httpx[brotli,http2,socks]==0.25.2 httpx==0.26.0
llama-index llama-index==0.9.35
# optional document loaders # jq==1.6.0
# beautifulsoup4~=4.12.2
# pysrt~=1.1.2
# dashscope==1.13.6
# arxiv~=2.1.0
# youtube-search~=2.1.2
# duckduckgo-search~=3.9.9
# metaphor-python~=0.1.23
# rapidocr_paddle[gpu]>=1.3.0.post5 # gpu accelleration for ocr of pdf and image files # volcengine>=1.0.119
jq==1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows
beautifulsoup4~=4.12.2 # for .mhtml files
pysrt~=1.1.2
# Online api libs dependencies
zhipuai==1.0.7 # zhipu
dashscope==1.13.6 # qwen
# volcengine>=1.0.119 # fangzhou
# uncomment libs if you want to use corresponding vector store
# pymilvus>=2.3.4 # pymilvus>=2.3.4
# psycopg2==2.9.9 # psycopg2==2.9.9
# pgvector>=0.2.4 # pgvector>=0.2.4
#flash-attn==2.4.2 # For Orion-14B-Chat and Qwen-14B-Chat
# Agent and Search Tools #autoawq==0.1.8 # For Int4
#rapidocr_paddle[gpu]==1.3.11 # gpu accelleration for ocr of pdf and image files
arxiv~=2.1.0
youtube-search~=2.1.2
duckduckgo-search~=3.9.9
metaphor-python~=0.1.23

View File

@ -1,64 +1,33 @@
# API requirements
langchain==0.0.354 langchain==0.0.354
langchain-experimental==0.0.47 langchain-experimental==0.0.47
pydantic==1.10.13 pydantic==1.10.13
fschat==0.2.34 fschat~=0.2.35
openai~=1.7.1 openai~=1.9.0
fastapi~=0.108.0 fastapi~=0.109.0
sse_starlette==1.8.2 sse_starlette~=1.8.2
nltk>=3.8.1 nltk~=3.8.1
uvicorn>=0.24.0.post1 uvicorn~=0.24.0.post1
starlette~=0.32.0 starlette~=0.35.0
unstructured[all-docs]==0.11.0 unstructured[all-docs]~=0.12.0
python-magic-bin; sys_platform == 'win32' python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19 SQLAlchemy~=2.0.25
faiss-cpu~=1.7.4 faiss-cpu~=1.7.4
accelerate~=0.24.1
spacy~=3.7.2
PyMuPDF~=1.23.16
rapidocr_onnxruntime~=1.3.8
requests~=2.31.0 requests~=2.31.0
pathlib~=1.0.1 pathlib~=1.0.1
pytest~=7.4.3 pytest~=7.4.3
numexpr~=2.8.6 # max version for py38 llama-index==0.9.35
strsimpy~=0.2.1
markdownify~=0.11.6
tiktoken~=0.5.2
tqdm>=4.66.1
websockets>=12.0
numpy~=1.24.4
pandas~=2.0.3
einops>=0.7.0
transformers_stream_generator==0.0.4
vllm==0.2.6; sys_platform == "linux"
httpx[brotli,http2,socks]==0.25.2
requests
pathlib
pytest
# Online api libs dependencies
zhipuai==1.0.7
dashscope==1.13.6 dashscope==1.13.6
# volcengine>=1.0.119
# uncomment libs if you want to use corresponding vector store
# pymilvus>=2.3.4
# psycopg2==2.9.9
# pgvector>=0.2.4
# Agent and Search Tools
arxiv~=2.1.0 arxiv~=2.1.0
youtube-search~=2.1.2 youtube-search~=2.1.2
duckduckgo-search~=3.9.9 duckduckgo-search~=3.9.9
metaphor-python~=0.1.23 metaphor-python~=0.1.23
watchdog~=3.0.0
# WebUI requirements # volcengine>=1.0.119
# pymilvus>=2.3.4
streamlit>=1.29.0 # psycopg2==2.9.9
streamlit-option-menu>=0.3.6 # pgvector>=0.2.4
streamlit-antd-components>=0.3.0
streamlit-chatbox>=1.1.11
streamlit-modal>=0.1.0
streamlit-aggrid>=0.3.4.post3
httpx[brotli,http2,socks]>=0.25.2
watchdog>=3.0.0

View File: requirements_webui.txt

@ -1,10 +1,8 @@
# WebUI requirements streamlit~=1.30.0
streamlit-option-menu~=0.3.12
streamlit>=1.29.0 streamlit-antd-components~=0.3.1
streamlit-option-menu>=0.3.6 streamlit-chatbox~=1.1.11
streamlit-antd-components>=0.3.0 streamlit-modal~=0.1.0
streamlit-chatbox>=1.1.11 streamlit-aggrid~=0.3.4.post3
streamlit-modal>=0.1.0 httpx~=0.26.0
streamlit-aggrid>=0.3.4.post3 watchdog~=3.0.0
httpx[brotli,http2,socks]>=0.25.2
watchdog>=3.0.0

View File

@ -1,22 +1,19 @@
""" """
This file is a modified version for ChatGLM3-6B the original ChatGLM3Agent.py file from the langchain repo. This file is a modified version for ChatGLM3-6B the original glm3_agent.py file from the langchain repo.
""" """
from __future__ import annotations from __future__ import annotations
import yaml
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.memory import ConversationBufferWindowMemory
from typing import Any, List, Sequence, Tuple, Optional, Union
import os
from langchain.agents.agent import Agent
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate, MessagesPlaceholder,
)
import json import json
import logging import logging
from typing import Any, List, Sequence, Tuple, Optional, Union
from pydantic.schema import model_schema
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.memory import ConversationBufferWindowMemory
from langchain.agents.agent import Agent
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
from langchain.agents.agent import AgentOutputParser from langchain.agents.agent import AgentOutputParser
from langchain.output_parsers import OutputFixingParser from langchain.output_parsers import OutputFixingParser
from langchain.pydantic_v1 import Field from langchain.pydantic_v1 import Field
@ -43,12 +40,18 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
first_index = min([text.find(token) if token in text else len(text) for token in special_tokens]) first_index = min([text.find(token) if token in text else len(text) for token in special_tokens])
text = text[:first_index] text = text[:first_index]
if "tool_call" in text: if "tool_call" in text:
tool_name_end = text.find("```") action_end = text.find("```")
tool_name = text[:tool_name_end].strip() action = text[:action_end].strip()
input_para = text.split("='")[-1].split("'")[0] params_str_start = text.find("(") + 1
params_str_end = text.rfind(")")
params_str = text[params_str_start:params_str_end]
params_pairs = [param.split("=") for param in params_str.split(",") if "=" in param]
params = {pair[0].strip(): pair[1].strip().strip("'\"") for pair in params_pairs}
action_json = { action_json = {
"action": tool_name, "action": action,
"action_input": input_para "action_input": params
} }
else: else:
action_json = { action_json = {
@ -109,10 +112,6 @@ class StructuredGLM3ChatAgent(Agent):
else: else:
return agent_scratchpad return agent_scratchpad
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
pass
@classmethod @classmethod
def _get_default_output_parser( def _get_default_output_parser(
cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
@ -121,7 +120,7 @@ class StructuredGLM3ChatAgent(Agent):
@property @property
def _stop(self) -> List[str]: def _stop(self) -> List[str]:
return ["```<observation>"] return ["<|observation|>"]
@classmethod @classmethod
def create_prompt( def create_prompt(
@ -131,44 +130,25 @@ class StructuredGLM3ChatAgent(Agent):
input_variables: Optional[List[str]] = None, input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None, memory_prompts: Optional[List[BasePromptTemplate]] = None,
) -> BasePromptTemplate: ) -> BasePromptTemplate:
def tool_config_from_file(tool_name, directory="server/agent/tools/"):
"""search tool yaml and return simplified json format"""
file_path = os.path.join(directory, f"{tool_name.lower()}.yaml")
try:
with open(file_path, 'r', encoding='utf-8') as file:
tool_config = yaml.safe_load(file)
# Simplify the structure if needed
simplified_config = {
"name": tool_config.get("name", ""),
"description": tool_config.get("description", ""),
"parameters": tool_config.get("parameters", {})
}
return simplified_config
except FileNotFoundError:
logger.error(f"File not found: {file_path}")
return None
except Exception as e:
logger.error(f"An error occurred while reading {file_path}: {e}")
return None
tools_json = [] tools_json = []
tool_names = [] tool_names = []
for tool in tools: for tool in tools:
tool_config = tool_config_from_file(tool.name) tool_schema = model_schema(tool.args_schema) if tool.args_schema else {}
if tool_config: simplified_config_langchain = {
tools_json.append(tool_config) "name": tool.name,
tool_names.append(tool.name) "description": tool.description,
"parameters": tool_schema.get("properties", {})
# Format the tools for output }
tools_json.append(simplified_config_langchain)
tool_names.append(tool.name)
formatted_tools = "\n".join([ formatted_tools = "\n".join([
f"{tool['name']}: {tool['description']}, args: {tool['parameters']}" f"{tool['name']}: {tool['description']}, args: {tool['parameters']}"
for tool in tools_json for tool in tools_json
]) ])
formatted_tools = formatted_tools.replace("'", "\\'").replace("{", "{{").replace("}", "}}") formatted_tools = formatted_tools.replace("'", "\\'").replace("{", "{{").replace("}", "}}")
template = prompt.format(tool_names=tool_names, template = prompt.format(tool_names=tool_names,
tools=formatted_tools, tools=formatted_tools,
history="{history}", history="None",
input="{input}", input="{input}",
agent_scratchpad="{agent_scratchpad}") agent_scratchpad="{agent_scratchpad}")
@ -225,7 +205,6 @@ def initialize_glm3_agent(
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
llm: BaseLanguageModel, llm: BaseLanguageModel,
prompt: str = None, prompt: str = None,
callback_manager: Optional[BaseCallbackManager] = None,
memory: Optional[ConversationBufferWindowMemory] = None, memory: Optional[ConversationBufferWindowMemory] = None,
agent_kwargs: Optional[dict] = None, agent_kwargs: Optional[dict] = None,
*, *,
@ -238,14 +217,12 @@ def initialize_glm3_agent(
llm=llm, llm=llm,
tools=tools, tools=tools,
prompt=prompt, prompt=prompt,
callback_manager=callback_manager, **agent_kwargs **agent_kwargs
) )
return AgentExecutor.from_agent_and_tools( return AgentExecutor.from_agent_and_tools(
agent=agent_obj, agent=agent_obj,
tools=tools, tools=tools,
callback_manager=callback_manager,
memory=memory, memory=memory,
tags=tags_, tags=tags_,
**kwargs, **kwargs,
) )

View File

@ -1,5 +1,3 @@
## 由于工具类无法传参,所以使用全局变量来传递模型和对应的知识库介绍
class ModelContainer: class ModelContainer:
def __init__(self): def __init__(self):
self.MODEL = None self.MODEL = None

View File: server/agent/tools/__init__.py

@ -3,7 +3,7 @@ from .search_knowledgebase_simple import search_knowledgebase_simple
from .search_knowledgebase_once import search_knowledgebase_once, KnowledgeSearchInput
from .search_knowledgebase_complex import search_knowledgebase_complex, KnowledgeSearchInput
from .calculate import calculate, CalculatorInput
-from .weather_check import weathercheck, WhetherSchema
+from .weather_check import weathercheck, WeatherInput
from .shell import shell, ShellInput
from .search_internet import search_internet, SearchInternetInput
from .wolfram import wolfram, WolframInput
@ -1,10 +0,0 @@
name: arxiv
description: A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.
parameters:
type: object
properties:
query:
type: string
description: The search query title
required:
- query
@ -1,10 +0,0 @@
name: calculate
description: Useful for when you need to answer questions about simple calculations
parameters:
type: object
properties:
query:
type: string
description: The formula to be calculated
required:
- query
@ -1,10 +0,0 @@
name: search_internet
description: Use this tool to surf internet and get information
parameters:
type: object
properties:
query:
type: string
description: Query for Internet search
required:
- query
@ -1,10 +0,0 @@
name: search_knowledgebase_complex
description: Use this tool to search local knowledgebase and get information
parameters:
type: object
properties:
query:
type: string
description: The query to be searched
required:
- query
@ -1,10 +0,0 @@
name: search_youtube
description: Use this tools to search youtube videos
parameters:
type: object
properties:
query:
type: string
description: Query for Videos search
required:
- query
@ -1,10 +0,0 @@
name: shell
description: Use Linux Shell to execute Linux commands
parameters:
type: object
properties:
query:
type: string
description: The command to execute
required:
- query
@ -1,338 +1,29 @@
from __future__ import annotations
## 单独运行的时候需要添加
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
import re
import warnings
from typing import Dict
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.pydantic_v1 import Extra, root_validator
from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
import requests
from typing import List, Any, Optional
from datetime import datetime
from langchain.prompts import PromptTemplate
from server.agent import model_container
from pydantic import BaseModel, Field
## 使用和风天气API查询天气
KEY = "ac880e5a877042809ac7ffdd19d95b0d"
# key长这样这里提供了示例的key这个key没法使用你需要自己去注册和风天气的账号然后在这里填入你的key
_PROMPT_TEMPLATE = """
用户会提出一个关于天气的问题你的目标是拆分出用户问题中的区 并按照我提供的工具回答
例如 用户提出的问题是: 上海浦东未来1小时天气情况
提取的市和区是: 上海 浦东
如果用户提出的问题是: 上海未来1小时天气情况
提取的市和区是: 上海 None
请注意以下内容:
1. 如果你没有找到区的内容,则一定要使用 None 替代否则程序无法运行
2. 如果用户没有指定市 则直接返回缺少信息
问题: ${{用户的问题}}
你的回答格式应该按照下面的内容请注意格式内的```text 等标记都必须输出这是我用来提取答案的标记
```text
${{拆分的市和区中间用空格隔开}}
```
... weathercheck( )...
```output
${{提取后的答案}}
```
答案: ${{答案}}
这是一个例子
问题: 上海浦东未来1小时天气情况
```text
上海 浦东
```
...weathercheck(上海 浦东)...
```output
预报时间: 1小时后
具体时间: 今天 18:00
温度: 24°C
天气: 多云
风向: 西南风
风速: 7
湿度: 88%
降水概率: 16%
Answer: 上海浦东一小时后的天气是多云
现在这是我的问题
问题: {question}
""" """
PROMPT = PromptTemplate( 更简单的单参数输入工具实现用于查询现在天气的情况
input_variables=["question"], """
template=_PROMPT_TEMPLATE, from pydantic import BaseModel, Field
) import requests
from configs.kb_config import SENIVERSE_API_KEY
def get_city_info(location, adm, key): def weather(location: str, api_key: str):
base_url = 'https://geoapi.qweather.com/v2/city/lookup?' url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c"
params = {'location': location, 'adm': adm, 'key': key} response = requests.get(url)
response = requests.get(base_url, params=params) if response.status_code == 200:
data = response.json() data = response.json()
return data weather = {
"temperature": data["results"][0]["now"]["temperature"],
"description": data["results"][0]["now"]["text"],
}
return weather
else:
raise Exception(
f"Failed to retrieve weather: {response.status_code}")
def format_weather_data(data, place): def weathercheck(location: str):
hourly_forecast = data['hourly'] return weather(location, SENIVERSE_API_KEY)
formatted_data = f"\n 这是查询到的关于{place}未来24小时的天气信息: \n"
for forecast in hourly_forecast:
# 将预报时间转换为datetime对象
forecast_time = datetime.strptime(forecast['fxTime'], '%Y-%m-%dT%H:%M%z')
# 获取预报时间的时区
forecast_tz = forecast_time.tzinfo
# 获取当前时间(使用预报时间的时区)
now = datetime.now(forecast_tz)
# 计算预报日期与当前日期的差值
days_diff = (forecast_time.date() - now.date()).days
if days_diff == 0:
forecast_date_str = '今天'
elif days_diff == 1:
forecast_date_str = '明天'
elif days_diff == 2:
forecast_date_str = '后天'
else:
forecast_date_str = str(days_diff) + '天后'
forecast_time_str = forecast_date_str + ' ' + forecast_time.strftime('%H:%M')
# 计算预报时间与当前时间的差值
time_diff = forecast_time - now
# 将差值转换为小时
hours_diff = time_diff.total_seconds() // 3600
if hours_diff < 1:
hours_diff_str = '1小时后'
elif hours_diff >= 24:
# 如果超过24小时转换为天数
days_diff = hours_diff // 24
hours_diff_str = str(int(days_diff)) + ''
else:
hours_diff_str = str(int(hours_diff)) + '小时'
# 将预报时间和当前时间的差值添加到输出中
formatted_data += '预报时间: ' + forecast_time_str + ' 距离现在有: ' + hours_diff_str + '\n'
formatted_data += '温度: ' + forecast['temp'] + '°C\n'
formatted_data += '天气: ' + forecast['text'] + '\n'
formatted_data += '风向: ' + forecast['windDir'] + '\n'
formatted_data += '风速: ' + forecast['windSpeed'] + '\n'
formatted_data += '湿度: ' + forecast['humidity'] + '%\n'
formatted_data += '降水概率: ' + forecast['pop'] + '%\n'
# formatted_data += '降水量: ' + forecast['precip'] + 'mm\n'
formatted_data += '\n'
return formatted_data
def get_weather(key, location_id, place): class WeatherInput(BaseModel):
url = "https://devapi.qweather.com/v7/weather/24h?" location: str = Field(description="City name,include city and county")
params = {
'location': location_id,
'key': key,
}
response = requests.get(url, params=params)
data = response.json()
return format_weather_data(data, place)
def split_query(query):
parts = query.split()
adm = parts[0]
if len(parts) == 1:
return adm, adm
location = parts[1] if parts[1] != 'None' else adm
return location, adm
def weather(query):
location, adm = split_query(query)
key = KEY
if key == "":
return "请先在代码中填入和风天气API Key"
try:
city_info = get_city_info(location=location, adm=adm, key=key)
location_id = city_info['location'][0]['id']
place = adm + "" + location + ""
weather_data = get_weather(key=key, location_id=location_id, place=place)
return weather_data + "以上是查询到的天气信息,请你查收\n"
except KeyError:
try:
city_info = get_city_info(location=adm, adm=adm, key=key)
location_id = city_info['location'][0]['id']
place = adm + ""
weather_data = get_weather(key=key, location_id=location_id, place=place)
return weather_data + "重要提醒:用户提供的市和区中,区的信息不存在,或者出现错别字,因此该信息是关于市的天气,请你查收\n"
except KeyError:
return "输入的地区不存在,无法提供天气预报"
class LLMWeatherChain(Chain):
llm_chain: LLMChain
llm: Optional[BaseLanguageModel] = None
"""[Deprecated] LLM wrapper to use."""
prompt: BasePromptTemplate = PROMPT
"""[Deprecated] Prompt to use to translate to python if necessary."""
input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
if "llm" in values:
warnings.warn(
"Directly instantiating an LLMWeatherChain with an llm is deprecated. "
"Please instantiate with llm_chain argument or using the from_llm "
"class method."
)
if "llm_chain" not in values and values["llm"] is not None:
prompt = values.get("prompt", PROMPT)
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
return values
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Expect output key.
:meta private:
"""
return [self.output_key]
def _evaluate_expression(self, expression: str) -> str:
try:
output = weather(expression)
except Exception as e:
output = "输入的信息有误,请再次尝试"
return output
def _process_llm_result(
self, llm_output: str, run_manager: CallbackManagerForChainRun
) -> Dict[str, str]:
run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
if text_match:
expression = text_match.group(1)
output = self._evaluate_expression(expression)
run_manager.on_text("\nAnswer: ", verbose=self.verbose)
run_manager.on_text(output, color="yellow", verbose=self.verbose)
answer = "Answer: " + output
elif llm_output.startswith("Answer:"):
answer = llm_output
elif "Answer:" in llm_output:
answer = "Answer: " + llm_output.split("Answer:")[-1]
else:
return {self.output_key: f"输入的格式不对: {llm_output},应该输入 (市 区)的组合"}
return {self.output_key: answer}
async def _aprocess_llm_result(
self,
llm_output: str,
run_manager: AsyncCallbackManagerForChainRun,
) -> Dict[str, str]:
await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
if text_match:
expression = text_match.group(1)
output = self._evaluate_expression(expression)
await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
await run_manager.on_text(output, color="yellow", verbose=self.verbose)
answer = "Answer: " + output
elif llm_output.startswith("Answer:"):
answer = llm_output
elif "Answer:" in llm_output:
answer = "Answer: " + llm_output.split("Answer:")[-1]
else:
raise ValueError(f"unknown format from LLM: {llm_output}")
return {self.output_key: answer}
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
_run_manager.on_text(inputs[self.input_key])
llm_output = self.llm_chain.predict(
question=inputs[self.input_key],
stop=["```output"],
callbacks=_run_manager.get_child(),
)
return self._process_llm_result(llm_output, _run_manager)
async def _acall(
self,
inputs: Dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
await _run_manager.on_text(inputs[self.input_key])
llm_output = await self.llm_chain.apredict(
question=inputs[self.input_key],
stop=["```output"],
callbacks=_run_manager.get_child(),
)
return await self._aprocess_llm_result(llm_output, _run_manager)
@property
def _chain_type(self) -> str:
return "llm_weather_chain"
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
prompt: BasePromptTemplate = PROMPT,
**kwargs: Any,
) -> LLMWeatherChain:
llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(llm_chain=llm_chain, **kwargs)
def weathercheck(query: str):
model = model_container.MODEL
llm_weather = LLMWeatherChain.from_llm(model, verbose=True, prompt=PROMPT)
ans = llm_weather.run(query)
return ans
class WhetherSchema(BaseModel):
location: str = Field(description="应该是一个地区的名称,用空格隔开,例如:上海 浦东,如果没有区的信息,可以只输入上海")
if __name__ == '__main__':
result = weathercheck("苏州姑苏区今晚热不热?")
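The old QWeather prompt-and-chain implementation above is replaced by a direct call to the Seniverse "weather now" API. A self-contained sketch of the same request the new tool makes (the key below is a placeholder; a valid SENIVERSE_API_KEY is required):

```python
# Standalone sketch of the Seniverse call used by the new weathercheck tool.
# API_KEY is a placeholder, not a working key.
import requests

API_KEY = "your-seniverse-key"
location = "北京"

url = (
    "https://api.seniverse.com/v3/weather/now.json"
    f"?key={API_KEY}&location={location}&language=zh-Hans&unit=c"
)
resp = requests.get(url, timeout=10)
resp.raise_for_status()
now = resp.json()["results"][0]["now"]
print({"temperature": now["temperature"], "description": now["text"]})
```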
@ -1,10 +0,0 @@
name: weather_check
description: Use Weather API to get weather information
parameters:
type: object
properties:
query:
type: string
description: City name,include city and county,like "厦门市思明区"
required:
- query
@ -1,10 +0,0 @@
name: wolfram
description: Useful for when you need to calculate difficult math formulas
parameters:
type: object
properties:
query:
type: string
description: The formula to be calculated
required:
- query
@ -1,8 +1,6 @@
from langchain.tools import Tool
from server.agent.tools import *
-## 请注意如果你是为了使用AgentLM在这里你应该使用英文版本。
tools = [
    Tool.from_function(
        func=calculate,
@ -20,7 +18,7 @@ tools = [
        func=weathercheck,
        name="weather_check",
        description="",
-       args_schema=WhetherSchema,
+       args_schema=WeatherInput,
    ),
    Tool.from_function(
        func=shell,
@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult
@ -1,22 +1,23 @@
from langchain.memory import ConversationBufferWindowMemory import json
import asyncio
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
from server.agent.tools_select import tools, tool_names
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
from langchain.agents import LLMSingleActionAgent, AgentExecutor
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
from fastapi import Body from fastapi import Body
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from langchain.chains import LLMChain from langchain.chains import LLMChain
from typing import AsyncIterable, Optional from langchain.memory import ConversationBufferWindowMemory
import asyncio from langchain.agents import LLMSingleActionAgent, AgentExecutor
from typing import List from typing import AsyncIterable, Optional, List
from server.chat.utils import History
import json from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from server.agent import model_container
from server.knowledge_base.kb_service.base import get_kb_details from server.knowledge_base.kb_service.base import get_kb_details
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
from server.agent.tools_select import tools, tool_names
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
from server.chat.utils import History
from server.agent import model_container
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
@ -33,7 +34,6 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default", prompt_name: str = Body("default",
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
): ):
history = [History.from_data(h) for h in history] history = [History.from_data(h) for h in history]
@ -55,12 +55,10 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
callbacks=[callback], callbacks=[callback],
) )
## 传入全局变量来实现agent调用
kb_list = {x["kb_name"]: x for x in get_kb_details()} kb_list = {x["kb_name"]: x for x in get_kb_details()}
model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()} model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()}
if Agent_MODEL: if Agent_MODEL:
## 如果有指定使用Agent模型来完成任务
model_agent = get_ChatOpenAI( model_agent = get_ChatOpenAI(
model_name=Agent_MODEL, model_name=Agent_MODEL,
temperature=temperature, temperature=temperature,
@ -79,23 +77,17 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
) )
output_parser = CustomOutputParser() output_parser = CustomOutputParser()
llm_chain = LLMChain(llm=model, prompt=prompt_template_agent) llm_chain = LLMChain(llm=model, prompt=prompt_template_agent)
# 把history转成agent的memory
memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2) memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
for message in history: for message in history:
# 检查消息的角色
if message.role == 'user': if message.role == 'user':
# 添加用户消息
memory.chat_memory.add_user_message(message.content) memory.chat_memory.add_user_message(message.content)
else: else:
# 添加AI消息
memory.chat_memory.add_ai_message(message.content) memory.chat_memory.add_ai_message(message.content)
if "chatglm3" in model_container.MODEL.model_name or "zhipu-api" in model_container.MODEL.model_name:
if "chatglm3" in model_container.MODEL.model_name:
agent_executor = initialize_glm3_agent( agent_executor = initialize_glm3_agent(
llm=model, llm=model,
tools=tools, tools=tools,
callback_manager=None, callback_manager=None,
# Langchain Prompt is not constructed directly here, it is constructed inside the GLM3 agent.
prompt=prompt_template, prompt=prompt_template,
input_variables=["input", "intermediate_steps", "history"], input_variables=["input", "intermediate_steps", "history"],
memory=memory, memory=memory,
@ -155,7 +147,6 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
answer = "" answer = ""
final_answer = "" final_answer = ""
async for chunk in callback.aiter(): async for chunk in callback.aiter():
# Use server-sent-events to stream the response
data = json.loads(chunk) data = json.loads(chunk)
if data["status"] == Status.start or data["status"] == Status.complete: if data["status"] == Status.start or data["status"] == Status.complete:
continue continue
@ -181,7 +172,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
await task await task
return EventSourceResponse(agent_chat_iterator(query=query, return EventSourceResponse(agent_chat_iterator(query=query,
history=history, history=history,
model_name=model_name, model_name=model_name,
prompt_name=prompt_name), prompt_name=prompt_name),
) )
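agent_chat converts the request history into a ConversationBufferWindowMemory before building the agent, and routes chatglm3/zhipu-api models to the GLM3-specific agent. A small sketch of just the history-to-memory step, with an assumed HISTORY_LEN of 3:

```python
# Illustrative only: shows how a role-tagged history becomes agent memory.
from langchain.memory import ConversationBufferWindowMemory

HISTORY_LEN = 3  # assumption; the real value comes from configs
history = [
    {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
    {"role": "assistant", "content": "虎头虎脑"},
]

memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
for message in history:
    if message["role"] == "user":
        memory.chat_memory.add_user_message(message["content"])
    else:
        memory.chat_memory.add_ai_message(message["content"])

print(memory.load_memory_variables({})["history"])
```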
@ -1,23 +1,23 @@
from langchain.utilities.bing_search import BingSearchAPIWrapper from langchain.utilities.bing_search import BingSearchAPIWrapper
from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY, from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE, LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE, OVERLAP_SIZE)
TEXT_SPLITTER_NAME, OVERLAP_SIZE)
from fastapi import Body
from sse_starlette import EventSourceResponse
from fastapi.concurrency import run_in_threadpool
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from langchain.chains import LLMChain from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts.chat import ChatPromptTemplate from langchain.prompts.chat import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter
from typing import List, Optional, Dict
from server.chat.utils import History
from langchain.docstore.document import Document from langchain.docstore.document import Document
from fastapi import Body
from fastapi.concurrency import run_in_threadpool
from sse_starlette import EventSourceResponse
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from server.chat.utils import History
from typing import AsyncIterable
import asyncio
import json import json
from typing import List, Optional, Dict
from strsimpy.normalized_levenshtein import NormalizedLevenshtein from strsimpy.normalized_levenshtein import NormalizedLevenshtein
from markdownify import markdownify from markdownify import markdownify
@ -38,11 +38,11 @@ def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
def metaphor_search( def metaphor_search(
text: str, text: str,
result_len: int = SEARCH_ENGINE_TOP_K, result_len: int = SEARCH_ENGINE_TOP_K,
split_result: bool = False, split_result: bool = False,
chunk_size: int = 500, chunk_size: int = 500,
chunk_overlap: int = OVERLAP_SIZE, chunk_overlap: int = OVERLAP_SIZE,
) -> List[Dict]: ) -> List[Dict]:
from metaphor_python import Metaphor from metaphor_python import Metaphor
@ -58,13 +58,13 @@ def metaphor_search(
# metaphor 返回的内容都是长文本,需要分词再检索 # metaphor 返回的内容都是长文本,需要分词再检索
if split_result: if split_result:
docs = [Document(page_content=x.extract, docs = [Document(page_content=x.extract,
metadata={"link": x.url, "title": x.title}) metadata={"link": x.url, "title": x.title})
for x in contents] for x in contents]
text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "], text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
chunk_size=chunk_size, chunk_size=chunk_size,
chunk_overlap=chunk_overlap) chunk_overlap=chunk_overlap)
splitted_docs = text_splitter.split_documents(docs) splitted_docs = text_splitter.split_documents(docs)
# 将切分好的文档放入临时向量库重新筛选出TOP_K个文档 # 将切分好的文档放入临时向量库重新筛选出TOP_K个文档
if len(splitted_docs) > result_len: if len(splitted_docs) > result_len:
normal = NormalizedLevenshtein() normal = NormalizedLevenshtein()
@ -74,13 +74,13 @@ def metaphor_search(
splitted_docs = splitted_docs[:result_len] splitted_docs = splitted_docs[:result_len]
docs = [{"snippet": x.page_content, docs = [{"snippet": x.page_content,
"link": x.metadata["link"], "link": x.metadata["link"],
"title": x.metadata["title"]} "title": x.metadata["title"]}
for x in splitted_docs] for x in splitted_docs]
else: else:
docs = [{"snippet": x.extract, docs = [{"snippet": x.extract,
"link": x.url, "link": x.url,
"title": x.title} "title": x.title}
for x in contents] for x in contents]
return docs return docs
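When split_result is enabled, metaphor_search chunks the long extracts and keeps only the pieces closest to the query, ranked by normalized Levenshtein similarity. A toy sketch of that re-ranking step (the sample strings are made up):

```python
# Illustrative re-ranking with strsimpy, mirroring the idea used above.
from strsimpy.normalized_levenshtein import NormalizedLevenshtein

normal = NormalizedLevenshtein()
query = "如何启动 api 服务"
chunks = [
    "运行 python startup.py -a 可以一键启动全部服务",
    "今天天气不错",
    "API 服务默认监听 7861 端口",
]
chunks.sort(key=lambda c: normal.similarity(query, c), reverse=True)
print(chunks[:2])  # the two chunks most similar to the query
```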
@ -113,25 +113,27 @@ async def lookup_search_engine(
docs = search_result2docs(results) docs = search_result2docs(results)
return docs return docs
async def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]), async def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]), search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"), top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
history: List[History] = Body([], history: List[History] = Body([],
description="历史对话", description="历史对话",
examples=[[ examples=[[
{"role": "user", {"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"}, "content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant", {"role": "assistant",
"content": "虎头虎脑"}]] "content": "虎头虎脑"}]]
), ),
stream: bool = Body(False, description="流式输出"), stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"), max_tokens: Optional[int] = Body(None,
prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), description="限制LLM生成Token数量默认None代表模型最大值"),
split_result: bool = Body(False, description="是否对搜索结果进行拆分主要用于metaphor搜索引擎") prompt_name: str = Body("default",
): description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
split_result: bool = Body(False,
description="是否对搜索结果进行拆分主要用于metaphor搜索引擎")
):
if search_engine_name not in SEARCH_ENGINES.keys(): if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}") return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@ -198,9 +200,9 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
await task await task
return EventSourceResponse(search_engine_chat_iterator(query=query, return EventSourceResponse(search_engine_chat_iterator(query=query,
search_engine_name=search_engine_name, search_engine_name=search_engine_name,
top_k=top_k, top_k=top_k,
history=history, history=history,
model_name=model_name, model_name=model_name,
prompt_name=prompt_name), prompt_name=prompt_name),
) )
@ -83,7 +83,7 @@ def add_file_to_db(session,
kb_file: KnowledgeFile, kb_file: KnowledgeFile,
docs_count: int = 0, docs_count: int = 0,
custom_docs: bool = False, custom_docs: bool = False,
doc_infos: List[str] = [], # 形式:[{"id": str, "metadata": dict}, ...] doc_infos: List[Dict] = [], # 形式:[{"id": str, "metadata": dict}, ...]
): ):
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first() kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
if kb: if kb:
@ -16,7 +16,6 @@ def embed_texts(
) -> BaseResponse: ) -> BaseResponse:
''' '''
对文本进行向量化返回数据格式BaseResponse(data=List[List[float]]) 对文本进行向量化返回数据格式BaseResponse(data=List[List[float]])
TODO: 也许需要加入缓存机制减少 token 消耗
''' '''
try: try:
if embed_model in list_embed_models(): # 使用本地Embeddings模型 if embed_model in list_embed_models(): # 使用本地Embeddings模型
@ -13,9 +13,9 @@ def list_kbs():
def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
vector_store_type: str = Body("faiss"), vector_store_type: str = Body("faiss"),
embed_model: str = Body(EMBEDDING_MODEL), embed_model: str = Body(EMBEDDING_MODEL),
) -> BaseResponse: ) -> BaseResponse:
# Create selected knowledge base # Create selected knowledge base
if not validate_kb_name(knowledge_base_name): if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me") return BaseResponse(code=403, msg="Don't attack me")
@ -39,8 +39,8 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
def delete_kb( def delete_kb(
knowledge_base_name: str = Body(..., examples=["samples"]) knowledge_base_name: str = Body(..., examples=["samples"])
) -> BaseResponse: ) -> BaseResponse:
# Delete selected knowledge base # Delete selected knowledge base
if not validate_kb_name(knowledge_base_name): if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me") return BaseResponse(code=403, msg="Don't attack me")
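create_kb and delete_kb keep their FastAPI Body signatures; only layout changes here. For context, a hedged example of invoking the create endpoint over HTTP; the route and port follow the project's defaults (API server on 7861) and should be verified against server/api.py for your deployment:

```python
# Hypothetical client call; route, port and payload values are example assumptions.
import requests

resp = requests.post(
    "http://127.0.0.1:7861/knowledge_base/create_knowledge_base",
    json={"knowledge_base_name": "samples", "vector_store_type": "faiss"},
)
print(resp.json())
```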
@ -55,8 +55,6 @@ class _FaissPool(CachePool):
embed_model: str = EMBEDDING_MODEL, embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(), embed_device: str = embedding_device(),
) -> FAISS: ) -> FAISS:
# TODO: 整个Embeddings加载逻辑有些混乱待清理
# create an empty vector store
embeddings = EmbeddingsFunAdapter(embed_model) embeddings = EmbeddingsFunAdapter(embed_model)
doc = Document(page_content="init", metadata={}) doc = Document(page_content="init", metadata={})
vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT") vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
@ -95,7 +95,6 @@ def _save_files_in_thread(files: List[UploadFile],
and not override and not override
and os.path.getsize(file_path) == len(file_content) and os.path.getsize(file_path) == len(file_content)
): ):
# TODO: filesize 不同后的处理
file_status = f"文件 {filename} 已存在。" file_status = f"文件 {filename} 已存在。"
logger.warn(file_status) logger.warn(file_status)
return dict(code=404, msg=file_status, data=data) return dict(code=404, msg=file_status, data=data)
@ -116,7 +115,6 @@ def _save_files_in_thread(files: List[UploadFile],
yield result yield result
# TODO: 等langchain.document_loaders支持内存文件的时候再开通
# def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"), # def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
# knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]), # knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
# override: bool = Form(False, description="覆盖已有文件"), # override: bool = Form(False, description="覆盖已有文件"),
@ -24,7 +24,7 @@ from server.knowledge_base.utils import (
list_kbs_from_folder, list_files_from_folder, list_kbs_from_folder, list_files_from_folder,
) )
from typing import List, Union, Dict, Optional from typing import List, Union, Dict, Optional, Tuple
from server.embeddings_api import embed_texts, aembed_texts, embed_documents from server.embeddings_api import embed_texts, aembed_texts, embed_documents
from server.knowledge_base.model.kb_document_model import DocumentWithVSId from server.knowledge_base.model.kb_document_model import DocumentWithVSId
@ -191,7 +191,6 @@ class KBService(ABC):
''' '''
传入参数为 {doc_id: Document, ...} 传入参数为 {doc_id: Document, ...}
如果对应 doc_id 的值为 None或其 page_content 为空则删除该文档 如果对应 doc_id 的值为 None或其 page_content 为空则删除该文档
TODO是否要支持新增 docs
''' '''
self.del_doc_by_ids(list(docs.keys())) self.del_doc_by_ids(list(docs.keys()))
docs = [] docs = []
@ -261,7 +260,7 @@ class KBService(ABC):
query: str, query: str,
top_k: int, top_k: int,
score_threshold: float, score_threshold: float,
) -> List[Document]: ) -> List[Tuple[Document, float]]:
""" """
搜索知识库子类实自己逻辑 搜索知识库子类实自己逻辑
""" """
@ -6,6 +6,7 @@ from langchain.schema import Document
from langchain.vectorstores.elasticsearch import ElasticsearchStore
from configs import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, CACHED_VS_NUM
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
+from server.knowledge_base.utils import KnowledgeFile
from server.utils import load_local_embeddings
from elasticsearch import Elasticsearch,BadRequestError
from configs import logger
@ -15,7 +16,7 @@ class ESKBService(KBService):
    def do_init(self):
        self.kb_path = self.get_kb_path(self.kb_name)
-       self.index_name = self.kb_path.split("/")[-1]
+       self.index_name = os.path.split(self.kb_path)[-1]
        self.IP = kbs_config[self.vs_type()]['host']
        self.PORT = kbs_config[self.vs_type()]['port']
        self.user = kbs_config[self.vs_type()].get("user",'')
@ -38,7 +39,16 @@ class ESKBService(KBService):
            raise e
        try:
            # 首先尝试通过es_client_python创建
-           self.es_client_python.indices.create(index=self.index_name)
+           mappings = {
+               "properties": {
+                   "dense_vector": {
+                       "type": "dense_vector",
+                       "dims": self.dims_length,
+                       "index": True
+                   }
+               }
+           }
+           self.es_client_python.indices.create(index=self.index_name, mappings=mappings)
        except BadRequestError as e:
            logger.error("创建索引失败,重新")
            logger.error(e)
@ -80,9 +90,9 @@ class ESKBService(KBService):
        except Exception as e:
            logger.error("创建索引失败...")
            logger.error(e)
            # raise e

    @staticmethod
    def get_kb_path(knowledge_base_name: str):
@ -220,7 +230,12 @@ class ESKBService(KBService):
            shutil.rmtree(self.kb_path)

+if __name__ == '__main__':
+    esKBService = ESKBService("test")
+    # esKBService.clear_vs()
+    # esKBService.create_kb()
+    esKBService.add_doc(KnowledgeFile(filename="README.md", knowledge_base_name="test"))
+    print(esKBService.search_docs("如何启动api服务"))
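The fix for ES vector retrieval (#2688) is to create the index with an explicit dense_vector mapping so that kNN search works. A rough standalone sketch using the elasticsearch-py 8.x client; the localhost URL, index name, and dims=768 are assumptions for the example:

```python
# Sketch of the mapping ESKBService now creates; adjust dims to your embedding model.
from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")
dims = 768  # must equal the embedding model's output dimension

es.indices.create(
    index="samples",
    mappings={
        "properties": {
            "dense_vector": {"type": "dense_vector", "dims": dims, "index": True}
        }
    },
)
```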
@ -7,7 +7,7 @@ from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafe
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
from server.utils import torch_gc
from langchain.docstore.document import Document
-from typing import List, Dict, Optional
+from typing import List, Dict, Optional, Tuple


class FaissKBService(KBService):
@ -61,7 +61,7 @@ class FaissKBService(KBService):
                  query: str,
                  top_k: int,
                  score_threshold: float = SCORE_THRESHOLD,
-                 ) -> List[Document]:
+                 ) -> List[Tuple[Document, float]]:
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        with self.load_vector_store().acquire() as vs:
@ -18,13 +18,10 @@ class MilvusKBService(KBService):
from pymilvus import Collection from pymilvus import Collection
return Collection(milvus_name) return Collection(milvus_name)
# def save_vector_store(self):
# if self.milvus.col:
# self.milvus.col.flush()
def get_doc_by_ids(self, ids: List[str]) -> List[Document]: def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
result = [] result = []
if self.milvus.col: if self.milvus.col:
# ids = [int(id) for id in ids] # for milvus if needed #pr 2725
data_list = self.milvus.col.query(expr=f'pk in {ids}', output_fields=["*"]) data_list = self.milvus.col.query(expr=f'pk in {ids}', output_fields=["*"])
for data in data_list: for data in data_list:
text = data.pop("text") text = data.pop("text")
@ -73,7 +70,6 @@ class MilvusKBService(KBService):
return score_threshold_process(score_threshold, top_k, docs) return score_threshold_process(score_threshold, top_k, docs)
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
# TODO: workaround for bug #10492 in langchain
for doc in docs: for doc in docs:
for k, v in doc.metadata.items(): for k, v in doc.metadata.items():
doc.metadata[k] = str(v) doc.metadata[k] = str(v)
@ -11,25 +11,27 @@ from server.knowledge_base.kb_service.base import SupportedVSType, KBService, Em
    score_threshold_process
from server.knowledge_base.utils import KnowledgeFile
import shutil
+import sqlalchemy
+from sqlalchemy.engine.base import Engine
+from sqlalchemy.orm import Session


class PGKBService(KBService):
-   pg_vector: PGVector
+   engine: Engine = sqlalchemy.create_engine(kbs_config.get("pg").get("connection_uri"), pool_size=10)

    def _load_pg_vector(self):
        self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                                  collection_name=self.kb_name,
                                  distance_strategy=DistanceStrategy.EUCLIDEAN,
+                                 connection=PGKBService.engine,
                                  connection_string=kbs_config.get("pg").get("connection_uri"))

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
-       with self.pg_vector.connect() as connect:
+       with Session(PGKBService.engine) as session:
            stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id in :ids")
            results = [Document(page_content=row[0], metadata=row[1]) for row in
-                      connect.execute(stmt, parameters={'ids': ids}).fetchall()]
+                      session.execute(stmt, {'ids': ids}).fetchall()]
        return results

-   # TODO:
    def del_doc_by_ids(self, ids: List[str]) -> bool:
        return super().del_doc_by_ids(ids)
@ -43,8 +45,8 @@ class PGKBService(KBService):
        return SupportedVSType.PG

    def do_drop_kb(self):
-       with self.pg_vector.connect() as connect:
-           connect.execute(text(f'''
+       with Session(PGKBService.engine) as session:
+           session.execute(text(f'''
                -- 删除 langchain_pg_embedding 表中关联到 langchain_pg_collection 表中 的记录
                DELETE FROM langchain_pg_embedding
                WHERE collection_id IN (
@ -53,11 +55,10 @@ class PGKBService(KBService):
                -- 删除 langchain_pg_collection 表中 记录
                DELETE FROM langchain_pg_collection WHERE name = '{self.kb_name}';
            '''))
-           connect.commit()
+           session.commit()
        shutil.rmtree(self.kb_path)

    def do_search(self, query: str, top_k: int, score_threshold: float):
-       self._load_pg_vector()
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k)
@ -69,13 +70,13 @@ class PGKBService(KBService):
        return doc_infos

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
-       with self.pg_vector.connect() as connect:
+       with Session(PGKBService.engine) as session:
            filepath = kb_file.filepath.replace('\\', '\\\\')
-           connect.execute(
+           session.execute(
                text(
                    ''' DELETE FROM langchain_pg_embedding WHERE cmetadata::jsonb @> '{"source": "filepath"}'::jsonb;'''.replace(
                        "filepath", filepath)))
-           connect.commit()
+           session.commit()

    def do_clear_vs(self):
        self.pg_vector.delete_collection()
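PGKBService now holds one shared SQLAlchemy engine with a connection pool, and each method borrows a session from it instead of opening a new connection per call (release note: 每次创建对象时从连接池获取连接). A minimal sketch of the pattern, with a placeholder connection URI:

```python
# Sketch of the shared-engine / pooled-connection pattern; the URI is a placeholder.
import sqlalchemy
from sqlalchemy import text
from sqlalchemy.orm import Session

engine = sqlalchemy.create_engine(
    "postgresql+psycopg2://user:password@localhost:5432/chatchat",
    pool_size=10,
)

with Session(engine) as session:  # borrows a pooled connection
    print(session.execute(text("SELECT 1")).scalar())
```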
@ -16,13 +16,10 @@ class ZillizKBService(KBService):
from pymilvus import Collection from pymilvus import Collection
return Collection(zilliz_name) return Collection(zilliz_name)
# def save_vector_store(self):
# if self.zilliz.col:
# self.zilliz.col.flush()
def get_doc_by_ids(self, ids: List[str]) -> List[Document]: def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
result = [] result = []
if self.zilliz.col: if self.zilliz.col:
# ids = [int(id) for id in ids] # for zilliz if needed #pr 2725
data_list = self.zilliz.col.query(expr=f'pk in {ids}', output_fields=["*"]) data_list = self.zilliz.col.query(expr=f'pk in {ids}', output_fields=["*"])
for data in data_list: for data in data_list:
text = data.pop("text") text = data.pop("text")
@ -50,8 +47,7 @@ class ZillizKBService(KBService):
def _load_zilliz(self): def _load_zilliz(self):
zilliz_args = kbs_config.get("zilliz") zilliz_args = kbs_config.get("zilliz")
self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(self.embed_model), self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(self.embed_model),
collection_name=self.kb_name, connection_args=zilliz_args) collection_name=self.kb_name, connection_args=zilliz_args)
def do_init(self): def do_init(self):
self._load_zilliz() self._load_zilliz()
@ -95,9 +91,7 @@ class ZillizKBService(KBService):
if __name__ == '__main__': if __name__ == '__main__':
from server.db.base import Base, engine from server.db.base import Base, engine
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
zillizService = ZillizKBService("test") zillizService = ZillizKBService("test")
@ -13,7 +13,6 @@ from server.db.repository.knowledge_metadata_repository import add_summary_to_db
from langchain.docstore.document import Document from langchain.docstore.document import Document
# TODO 暂不考虑文件更新,需要重新删除相关文档,再重新添加
class KBSummaryService(ABC): class KBSummaryService(ABC):
kb_name: str kb_name: str
embed_model: str embed_model: str
@ -112,12 +112,6 @@ class SummaryAdapter:
docs: List[DocumentWithVSId] = []) -> List[Document]: docs: List[DocumentWithVSId] = []) -> List[Document]:
logger.info("start summary") logger.info("start summary")
# TODO 暂不处理文档中涉及语义重复、上下文缺失、document was longer than the context length 的问题
# merge_docs = self._drop_overlap(docs)
# # 将merge_docs中的句子合并成一个文档
# text = self._join_docs(merge_docs)
# 根据段落于句子的分隔符将文档分成chunk每个chunk长度小于token_max长度
""" """
这个过程分成两个部分 这个过程分成两个部分
1. 对每个文档进行处理得到每个文档的摘要 1. 对每个文档进行处理得到每个文档的摘要
@ -91,9 +91,14 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
"JSONLoader": [".json"], "JSONLoader": [".json"],
"JSONLinesLoader": [".jsonl"], "JSONLinesLoader": [".jsonl"],
"CSVLoader": [".csv"], "CSVLoader": [".csv"],
# "FilteredCSVLoader": [".csv"], # 需要自己指定,目前还没有支持 # "FilteredCSVLoader": [".csv"], 如果使用自定义分割csv
"RapidOCRPDFLoader": [".pdf"], "RapidOCRPDFLoader": [".pdf"],
"RapidOCRDocLoader": ['.docx', '.doc'],
"RapidOCRPPTLoader": ['.ppt', '.pptx', ],
"RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'], "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
"UnstructuredFileLoader": ['.eml', '.msg', '.rst',
'.rtf', '.txt', '.xml',
'.epub', '.odt','.tsv'],
"UnstructuredEmailLoader": ['.eml', '.msg'], "UnstructuredEmailLoader": ['.eml', '.msg'],
"UnstructuredEPubLoader": ['.epub'], "UnstructuredEPubLoader": ['.epub'],
"UnstructuredExcelLoader": ['.xlsx', '.xls', '.xlsd'], "UnstructuredExcelLoader": ['.xlsx', '.xls', '.xlsd'],
@ -109,7 +114,6 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
"UnstructuredXMLLoader": ['.xml'], "UnstructuredXMLLoader": ['.xml'],
"UnstructuredPowerPointLoader": ['.ppt', '.pptx'], "UnstructuredPowerPointLoader": ['.ppt', '.pptx'],
"EverNoteLoader": ['.enex'], "EverNoteLoader": ['.enex'],
"UnstructuredFileLoader": ['.txt'],
} }
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist] SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
@ -141,15 +145,14 @@ def get_LoaderClass(file_extension):
if file_extension in extensions: if file_extension in extensions:
return LoaderClass return LoaderClass
# 把一些向量化共用逻辑从KnowledgeFile抽取出来等langchain支持内存文件的时候可以将非磁盘文件向量化
def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None): def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
''' '''
根据loader_name和文件路径或内容返回文档加载器 根据loader_name和文件路径或内容返回文档加载器
''' '''
loader_kwargs = loader_kwargs or {} loader_kwargs = loader_kwargs or {}
try: try:
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader","FilteredCSVLoader"]: if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader", "FilteredCSVLoader",
"RapidOCRDocLoader", "RapidOCRPPTLoader"]:
document_loaders_module = importlib.import_module('document_loaders') document_loaders_module = importlib.import_module('document_loaders')
else: else:
document_loaders_module = importlib.import_module('langchain.document_loaders') document_loaders_module = importlib.import_module('langchain.document_loaders')
@ -171,7 +174,6 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
if encode_detect is None: if encode_detect is None:
encode_detect = {"encoding": "utf-8"} encode_detect = {"encoding": "utf-8"}
loader_kwargs["encoding"] = encode_detect["encoding"] loader_kwargs["encoding"] = encode_detect["encoding"]
## TODO支持更多的自定义CSV读取逻辑
elif loader_name == "JSONLoader": elif loader_name == "JSONLoader":
loader_kwargs.setdefault("jq_schema", ".") loader_kwargs.setdefault("jq_schema", ".")
@ -259,6 +261,10 @@ def make_text_splitter(
text_splitter_module = importlib.import_module('langchain.text_splitter') text_splitter_module = importlib.import_module('langchain.text_splitter')
TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter") TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
text_splitter = TextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) text_splitter = TextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
# If you use SpacyTextSplitter you can use GPU to do split likes Issue #1287
# text_splitter._tokenizer.max_length = 37016792
# text_splitter._tokenizer.prefer_gpu()
return text_splitter return text_splitter
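LOADER_DICT now registers the OCR-capable DOC/PPT loaders and routes them, together with the PDF and image OCR loaders, to the project's local document_loaders package. A trimmed sketch of the extension lookup this table drives:

```python
# Trimmed copy for illustration; the real table lives in server/knowledge_base/utils.py
# and registers many more loaders.
LOADER_DICT = {
    "RapidOCRPDFLoader": [".pdf"],
    "RapidOCRDocLoader": [".docx", ".doc"],
    "RapidOCRPPTLoader": [".ppt", ".pptx"],
    "UnstructuredFileLoader": [".txt", ".rst", ".xml"],
}

def get_LoaderClass(file_extension: str):
    for loader_class, extensions in LOADER_DICT.items():
        if file_extension in extensions:
            return loader_class

print(get_LoaderClass(".docx"))  # -> RapidOCRDocLoader
```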
@ -0,0 +1,51 @@
from typing import (
TYPE_CHECKING,
Any,
Tuple
)
import sys
import logging
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
import tiktoken
class MinxChatOpenAI:
@staticmethod
def import_tiktoken() -> Any:
try:
import tiktoken
except ImportError:
raise ValueError(
"Could not import tiktoken python package. "
"This is needed in order to calculate get_token_ids. "
"Please install it with `pip install tiktoken`."
)
return tiktoken
@staticmethod
def get_encoding_model(self) -> Tuple[str, "tiktoken.Encoding"]:
tiktoken_ = MinxChatOpenAI.import_tiktoken()
if self.tiktoken_model_name is not None:
model = self.tiktoken_model_name
else:
model = self.model_name
if model == "gpt-3.5-turbo":
# gpt-3.5-turbo may change over time.
# Returning num tokens assuming gpt-3.5-turbo-0301.
model = "gpt-3.5-turbo-0301"
elif model == "gpt-4":
# gpt-4 may change over time.
# Returning num tokens assuming gpt-4-0314.
model = "gpt-4-0314"
# Returns the number of tokens used by a list of messages.
try:
encoding = tiktoken_.encoding_for_model(model)
except Exception as e:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken_.get_encoding(model)
return model, encoding
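This helper gives ChatOpenAI a tiktoken encoding so the caller can check whether a prompt exceeds the model's context length (the release-note item by @glide-the). An illustrative count built on the same lookup; the model name and context_len below are example values:

```python
# Example token-length check; falls back to cl100k_base like the helper above.
import tiktoken

model, context_len = "gpt-3.5-turbo-0301", 4096
try:
    encoding = tiktoken.encoding_for_model(model)
except KeyError:
    encoding = tiktoken.get_encoding("cl100k_base")

prompt = "你好,请介绍一下 Langchain-Chatchat。"
n_tokens = len(encoding.encode(prompt))
print(n_tokens, "tokens; over limit:", n_tokens > context_len)
```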
@ -8,3 +8,4 @@ from .qwen import QwenWorker
from .baichuan import BaiChuanWorker
from .azure import AzureWorker
from .tiangong import TianGongWorker
+from .gemini import GeminiWorker
@ -67,12 +67,10 @@ class AzureWorker(ApiModelWorker):
self.logger.error(f"请求 Azure API 时发生错误:{resp}") self.logger.error(f"请求 Azure API 时发生错误:{resp}")
def get_embeddings(self, params): def get_embeddings(self, params):
# TODO: 支持embeddings
print("embedding") print("embedding")
print(params) print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# TODO: 确认模板是否需要修改
return conv.Conversation( return conv.Conversation(
name=self.model_names[0], name=self.model_names[0],
system_message="You are a helpful, respectful and honest assistant.", system_message="You are a helpful, respectful and honest assistant.",
@ -88,12 +88,10 @@ class BaiChuanWorker(ApiModelWorker):
yield data yield data
def get_embeddings(self, params): def get_embeddings(self, params):
# TODO: 支持embeddings
print("embedding") print("embedding")
print(params) print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# TODO: 确认模板是否需要修改
return conv.Conversation( return conv.Conversation(
name=self.model_names[0], name=self.model_names[0],
system_message="", system_message="",
@ -125,8 +125,6 @@ class ApiModelWorker(BaseModelWorker):
def count_token(self, params): def count_token(self, params):
# TODO需要完善
# print("count token")
prompt = params["prompt"] prompt = params["prompt"]
return {"count": len(str(prompt)), "error_code": 0} return {"count": len(str(prompt)), "error_code": 0}
@ -12,16 +12,16 @@ class FangZhouWorker(ApiModelWorker):
""" """
def __init__( def __init__(
self, self,
*, *,
model_names: List[str] = ["fangzhou-api"], model_names: List[str] = ["fangzhou-api"],
controller_addr: str = None, controller_addr: str = None,
worker_addr: str = None, worker_addr: str = None,
version: Literal["chatglm-6b-model"] = "chatglm-6b-model", version: Literal["chatglm-6b-model"] = "chatglm-6b-model",
**kwargs, **kwargs,
): ):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384) # TODO: 不同的模型有不同的大小 kwargs.setdefault("context_len", 16384)
super().__init__(**kwargs) super().__init__(**kwargs)
self.version = version self.version = version
@ -53,15 +53,15 @@ class FangZhouWorker(ApiModelWorker):
if error := resp.error: if error := resp.error:
if error.code_n > 0: if error.code_n > 0:
data = { data = {
"error_code": error.code_n, "error_code": error.code_n,
"text": error.message, "text": error.message,
"error": { "error": {
"message": error.message, "message": error.message,
"type": "invalid_request_error", "type": "invalid_request_error",
"param": None, "param": None,
"code": None, "code": None,
}
} }
}
self.logger.error(f"请求方舟 API 时发生错误:{data}") self.logger.error(f"请求方舟 API 时发生错误:{data}")
yield data yield data
elif chunk := resp.choice.message.content: elif chunk := resp.choice.message.content:
@ -77,7 +77,6 @@ class FangZhouWorker(ApiModelWorker):
break break
def get_embeddings(self, params): def get_embeddings(self, params):
# TODO: 支持embeddings
print("embedding") print("embedding")
print(params) print(params)
@ -0,0 +1,123 @@
import sys
from fastchat.conversation import Conversation
from server.model_workers.base import *
from server.utils import get_httpx_client
from fastchat import conversation as conv
import json, httpx
from typing import List, Dict
from configs import logger, log_verbose
class GeminiWorker(ApiModelWorker):
def __init__(
self,
*,
controller_addr: str = None,
worker_addr: str = None,
model_names: List[str] = ["gemini-api"],
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 4096)
super().__init__(**kwargs)
def create_gemini_messages(self, messages) -> json:
has_history = any(msg['role'] == 'assistant' for msg in messages)
gemini_msg = []
for msg in messages:
role = msg['role']
content = msg['content']
if role == 'system':
continue
if has_history:
if role == 'assistant':
role = "model"
transformed_msg = {"role": role, "parts": [{"text": content}]}
else:
if role == 'user':
transformed_msg = {"parts": [{"text": content}]}
gemini_msg.append(transformed_msg)
msg = dict(contents=gemini_msg)
return msg
def do_chat(self, params: ApiChatParams) -> Dict:
params.load_config(self.model_names[0])
data = self.create_gemini_messages(messages=params.messages)
generationConfig = dict(
temperature=params.temperature,
topK=1,
topP=1,
maxOutputTokens=4096,
stopSequences=[]
)
data['generationConfig'] = generationConfig
url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent" + '?key=' + params.api_key
headers = {
'Content-Type': 'application/json',
}
if log_verbose:
logger.info(f'{self.__class__.__name__}:url: {url}')
logger.info(f'{self.__class__.__name__}:headers: {headers}')
logger.info(f'{self.__class__.__name__}:data: {data}')
text = ""
json_string = ""
timeout = httpx.Timeout(60.0)
client = get_httpx_client(timeout=timeout)
with client.stream("POST", url, headers=headers, json=data) as response:
for line in response.iter_lines():
line = line.strip()
if not line or "[DONE]" in line:
continue
json_string += line
try:
resp = json.loads(json_string)
if 'candidates' in resp:
for candidate in resp['candidates']:
content = candidate.get('content', {})
parts = content.get('parts', [])
for part in parts:
if 'text' in part:
text += part['text']
yield {
"error_code": 0,
"text": text
}
print(text)
except json.JSONDecodeError as e:
print("Failed to decode JSON:", e)
print("Invalid JSON string:", json_string)
def get_embeddings(self, params):
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
return conv.Conversation(
name=self.model_names[0],
system_message="You are a helpful, respectful and honest assistant.",
messages=[],
roles=["user", "assistant"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.base_model_worker import app
worker = GeminiWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:21012",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=21012)
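GeminiWorker streams from the Google generativelanguage REST endpoint, sending the contents/generationConfig payload built above. A minimal non-streaming sketch of the same request (the API key is a placeholder and must be replaced):

```python
# Minimal generateContent call mirroring the worker's payload; key is a placeholder.
import requests

GOOGLE_API_KEY = "your-api-key"
url = (
    "https://generativelanguage.googleapis.com/v1beta/models/"
    f"gemini-pro:generateContent?key={GOOGLE_API_KEY}"
)
data = {
    "contents": [{"parts": [{"text": "用一句话介绍 RAG。"}]}],
    "generationConfig": {"temperature": 0.7, "topK": 1, "topP": 1, "maxOutputTokens": 4096},
}
resp = requests.post(url, json=data, timeout=60)
for candidate in resp.json().get("candidates", []):
    for part in candidate.get("content", {}).get("parts", []):
        print(part.get("text", ""))
```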
@ -28,7 +28,7 @@ class MiniMaxWorker(ApiModelWorker):
    def validate_messages(self, messages: List[Dict]) -> List[Dict]:
        role_maps = {
-           "user": self.user_role,
+           "USER": self.user_role,
            "assistant": self.ai_role,
            "system": "system",
        }
@ -37,7 +37,6 @@ class MiniMaxWorker(ApiModelWorker):
def do_chat(self, params: ApiChatParams) -> Dict: def do_chat(self, params: ApiChatParams) -> Dict:
# 按照官网推荐直接调用abab 5.5模型 # 按照官网推荐直接调用abab 5.5模型
# TODO: 支持指定回复要求支持指定用户名称、AI名称
params.load_config(self.model_names[0]) params.load_config(self.model_names[0])
url = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}' url = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}'
@ -55,7 +54,7 @@ class MiniMaxWorker(ApiModelWorker):
"temperature": params.temperature, "temperature": params.temperature,
"top_p": params.top_p, "top_p": params.top_p,
"tokens_to_generate": params.max_tokens or 1024, "tokens_to_generate": params.max_tokens or 1024,
# TODO: 以下参数为minimax特有传入空值会出错。 # 以下参数为minimax特有传入空值会出错。
# "prompt": params.system_message or self.conv.system_message, # "prompt": params.system_message or self.conv.system_message,
# "bot_setting": [], # "bot_setting": [],
# "role_meta": params.role_meta, # "role_meta": params.role_meta,
@ -140,15 +139,13 @@ class MiniMaxWorker(ApiModelWorker):
self.logger.error(f"请求 MiniMax API 时发生错误:{data}") self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
return data return data
i += batch_size i += batch_size
return {"code": 200, "data": embeddings} return {"code": 200, "data": result}
def get_embeddings(self, params): def get_embeddings(self, params):
# TODO: 支持embeddings
print("embedding") print("embedding")
print(params) print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# TODO: 确认模板是否需要修改
return conv.Conversation( return conv.Conversation(
name=self.model_names[0], name=self.model_names[0],
system_message="你是MiniMax自主研发的大型语言模型回答问题简洁有条理。", system_message="你是MiniMax自主研发的大型语言模型回答问题简洁有条理。",
@@ -84,30 +84,6 @@ class QianFanWorker(ApiModelWorker):
    def do_chat(self, params: ApiChatParams) -> Dict:
        params.load_config(self.model_names[0])
-        # import qianfan
-        # comp = qianfan.ChatCompletion(model=params.version,
-        #                               endpoint=params.version_url,
-        #                               ak=params.api_key,
-        #                               sk=params.secret_key,)
-        # text = ""
-        # for resp in comp.do(messages=params.messages,
-        #                     temperature=params.temperature,
-        #                     top_p=params.top_p,
-        #                     stream=True):
-        #     if resp.code == 200:
-        #         if chunk := resp.body.get("result"):
-        #             text += chunk
-        #             yield {
-        #                 "error_code": 0,
-        #                 "text": text
-        #             }
-        #     else:
-        #         yield {
-        #             "error_code": resp.code,
-        #             "text": str(resp.body),
-        #         }
        BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
                   '/{model_version}?access_token={access_token}'
@@ -190,19 +166,19 @@ class QianFanWorker(ApiModelWorker):
        i = 0
        batch_size = 10
        while i < len(params.texts):
-            texts = params.texts[i:i+batch_size]
+            texts = params.texts[i:i + batch_size]
            resp = client.post(url, json={"input": texts}).json()
            if "error_code" in resp:
                data = {
                    "code": resp["error_code"],
                    "msg": resp["error_msg"],
                    "error": {
                        "message": resp["error_msg"],
                        "type": "invalid_request_error",
                        "param": None,
                        "code": None,
                    }
                }
                self.logger.error(f"请求千帆 API 时发生错误:{data}")
                return data
            else:
@@ -211,14 +187,11 @@ class QianFanWorker(ApiModelWorker):
            i += batch_size
        return {"code": 200, "data": result}

-    # TODO: qianfan支持续写模型
    def get_embeddings(self, params):
-        # TODO: 支持embeddings
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
-        # TODO: 确认模板是否需要修改
        return conv.Conversation(
            name=self.model_names[0],
            system_message="你是一个聪明的助手,请根据用户的提示来完成任务",

View File

@@ -100,12 +100,10 @@ class QwenWorker(ApiModelWorker):
        return {"code": 200, "data": result}

    def get_embeddings(self, params):
-        # TODO: 支持embeddings
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
-        # TODO: 确认模板是否需要修改
        return conv.Conversation(
            name=self.model_names[0],
            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",

View File

@@ -11,16 +11,15 @@ from typing import List, Literal, Dict
import requests


class TianGongWorker(ApiModelWorker):
    def __init__(
            self,
            *,
            controller_addr: str = None,
            worker_addr: str = None,
            model_names: List[str] = ["tiangong-api"],
            version: Literal["SkyChat-MegaVerse"] = "SkyChat-MegaVerse",
            **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 32768)
@@ -34,18 +33,18 @@ class TianGongWorker(ApiModelWorker):
        data = {
            "messages": params.messages,
            "model": "SkyChat-MegaVerse"
        }
        timestamp = str(int(time.time()))
        sign_content = params.api_key + params.secret_key + timestamp
        sign_result = hashlib.md5(sign_content.encode('utf-8')).hexdigest()
-        headers={
+        headers = {
            "app_key": params.api_key,
            "timestamp": timestamp,
            "sign": sign_result,
            "Content-Type": "application/json",
            "stream": "true"  # or change to "false" 不处理流式返回内容
        }

        # 发起请求并获取响应
        response = requests.post(url, headers=headers, json=data, stream=True)
@@ -56,27 +55,25 @@ class TianGongWorker(ApiModelWorker):
                # 处理接收到的数据
                # print(line.decode('utf-8'))
                resp = json.loads(line)
                if resp["code"] == 200:
                    text += resp['resp_data']['reply']
                    yield {
                        "error_code": 0,
                        "text": text
                    }
                else:
                    data = {
                        "error_code": resp["code"],
                        "text": resp["code_msg"]
                    }
                    self.logger.error(f"请求天工 API 时出错:{data}")
                    yield data

    def get_embeddings(self, params):
-        # TODO: 支持embeddings
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
-        # TODO: 确认模板是否需要修改
        return conv.Conversation(
            name=self.model_names[0],
            system_message="",
@@ -85,5 +82,3 @@ class TianGongWorker(ApiModelWorker):
            sep="\n### ",
            stop_str="###",
        )
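The request signing above concatenates app_key, secret_key and a second-level timestamp and hashes the result with MD5. A standalone sketch of the same scheme (the helper name is illustrative):

# Standalone sketch of the signing scheme used above (helper name is illustrative).
import hashlib
import time

def tiangong_headers(app_key: str, app_secret: str) -> dict:
    timestamp = str(int(time.time()))
    sign = hashlib.md5((app_key + app_secret + timestamp).encode("utf-8")).hexdigest()
    return {"app_key": app_key, "timestamp": timestamp, "sign": sign,
            "Content-Type": "application/json", "stream": "true"}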

View File

@@ -37,12 +37,11 @@ class XingHuoWorker(ApiModelWorker):
            **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
-        kwargs.setdefault("context_len", 8000)  # TODO: V1模型的最大长度为4000需要自行修改
+        kwargs.setdefault("context_len", 8000)
        super().__init__(**kwargs)
        self.version = version

    def do_chat(self, params: ApiChatParams) -> Dict:
-        # TODO: 当前每次对话都要重新连接websocket确认是否可以保持连接
        params.load_config(self.model_names[0])

        version_mapping = {
@@ -73,12 +72,10 @@ class XingHuoWorker(ApiModelWorker):
                yield {"error_code": 0, "text": text}

    def get_embeddings(self, params):
-        # TODO: 支持embeddings
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
-        # TODO: 确认模板是否需要修改
        return conv.Conversation(
            name=self.model_names[0],
            system_message="你是一个聪明的助手,请根据用户的提示来完成任务",

View File

@@ -4,93 +4,89 @@ from fastchat import conversation as conv
import sys
from typing import List, Dict, Iterator, Literal
from configs import logger, log_verbose
+import requests
+import jwt
+import time
+import json
+
+
+def generate_token(apikey: str, exp_seconds: int):
+    try:
+        id, secret = apikey.split(".")
+    except Exception as e:
+        raise Exception("invalid apikey", e)
+
+    payload = {
+        "api_key": id,
+        "exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
+        "timestamp": int(round(time.time() * 1000)),
+    }
+
+    return jwt.encode(
+        payload,
+        secret,
+        algorithm="HS256",
+        headers={"alg": "HS256", "sign_type": "SIGN"},
+    )


class ChatGLMWorker(ApiModelWorker):
-    DEFAULT_EMBED_MODEL = "text_embedding"
-
    def __init__(
            self,
            *,
            model_names: List[str] = ["zhipu-api"],
            controller_addr: str = None,
            worker_addr: str = None,
            version: Literal["chatglm_turbo"] = "chatglm_turbo",
            **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
-        kwargs.setdefault("context_len", 32768)
+        kwargs.setdefault("context_len", 4096)
        super().__init__(**kwargs)
        self.version = version

    def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
-        # TODO: 维护request_id
-        import zhipuai
        params.load_config(self.model_names[0])
-        zhipuai.api_key = params.api_key
-
-        if log_verbose:
-            logger.info(f'{self.__class__.__name__}:params: {params}')
-
-        response = zhipuai.model_api.sse_invoke(
-            model=params.version,
-            prompt=params.messages,
-            temperature=params.temperature,
-            top_p=params.top_p,
-            incremental=False,
-        )
-        for e in response.events():
-            if e.event == "add":
-                yield {"error_code": 0, "text": e.data}
-            elif e.event in ["error", "interrupted"]:
-                data = {
-                    "error_code": 500,
-                    "text": e.data,
-                    "error": {
-                        "message": e.data,
-                        "type": "invalid_request_error",
-                        "param": None,
-                        "code": None,
-                    }
-                }
-                self.logger.error(f"请求智谱 API 时发生错误:{data}")
-                yield data
-
-    def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
-        import zhipuai
-        params.load_config(self.model_names[0])
-        zhipuai.api_key = params.api_key
-        embeddings = []
-        try:
-            for t in params.texts:
-                response = zhipuai.model_api.invoke(model=params.embed_model or self.DEFAULT_EMBED_MODEL, prompt=t)
-                if response["code"] == 200:
-                    embeddings.append(response["data"]["embedding"])
-                else:
-                    self.logger.error(f"请求智谱 API 时发生错误:{response}")
-                    return response  # dict with code & msg
-        except Exception as e:
-            self.logger.error(f"请求智谱 API 时发生错误:{data}")
-            data = {"code": 500, "msg": f"对文本向量化时出错:{e}"}
-            return data
-
-        return {"code": 200, "data": embeddings}
+        token = generate_token(params.api_key, 60)
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {token}"
+        }
+        data = {
+            "model": params.version,
+            "messages": params.messages,
+            "max_tokens": params.max_tokens,
+            "temperature": params.temperature,
+            "stream": False
+        }
+        url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
+        response = requests.post(url, headers=headers, json=data)
+        # for chunk in response.iter_lines():
+        #     if chunk:
+        #         chunk_str = chunk.decode('utf-8')
+        #         json_start_pos = chunk_str.find('{"id"')
+        #         if json_start_pos != -1:
+        #             json_str = chunk_str[json_start_pos:]
+        #             json_data = json.loads(json_str)
+        #             for choice in json_data.get('choices', []):
+        #                 delta = choice.get('delta', {})
+        #                 content = delta.get('content', '')
+        #                 yield {"error_code": 0, "text": content}
+        ans = response.json()
+        content = ans["choices"][0]["message"]["content"]
+        yield {"error_code": 0, "text": content}

    def get_embeddings(self, params):
-        # TODO: 支持embeddings
+        # 临时解决方案不支持embedding
        print("embedding")
-        # print(params)
+        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
-        # 这里的是chatglm api的模板其它API的conv_template需要定制
        return conv.Conversation(
            name=self.model_names[0],
-            system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
+            system_message="你是智谱AI小助手请根据用户的提示来完成任务",
            messages=[],
-            roles=["Human", "Assistant", "System"],
+            roles=["user", "assistant", "system"],
            sep="\n###",
            stop_str="###",
        )
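The new do_chat above calls the GLM-4 style /api/paas/v4/chat/completions endpoint with stream=False and leaves a streaming parser commented out. A hedged sketch of what a streaming variant could look like, reusing generate_token from this file (the SSE details, i.e. the "data: " prefix and "[DONE]" sentinel, are assumptions based on the OpenAI-compatible format rather than part of this commit):

# Hedged sketch: streaming variant of the GLM-4 style call above.
import json
import requests

def stream_chat(api_key: str, model: str, messages: list):
    token = generate_token(api_key, 60)  # reuse the helper added in this file
    headers = {"Content-Type": "application/json",
               "Authorization": f"Bearer {token}"}
    data = {"model": model, "messages": messages, "stream": True}
    url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
    with requests.post(url, headers=headers, json=data, stream=True) as r:
        for line in r.iter_lines():
            if not line or not line.startswith(b"data: "):
                continue
            payload = line[6:].decode("utf-8")
            if payload.strip() == "[DONE]":
                break
            delta = json.loads(payload)["choices"][0].get("delta", {})
            if content := delta.get("content"):
                yield content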

File diff suppressed because one or more lines are too long

View File

@@ -12,10 +12,23 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
import httpx
-from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Literal,
+    Optional,
+    Callable,
+    Generator,
+    Dict,
+    Any,
+    Awaitable,
+    Union,
+    Tuple
+)
import logging
import torch
+
+from server.minx_chat_openai import MinxChatOpenAI


async def wrap_done(fn: Awaitable, event: asyncio.Event):
    """Wrap an awaitable with a event to signal when it's done or an exception is raised."""
@@ -23,7 +36,6 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
        await fn
    except Exception as e:
        logging.exception(e)
-        # TODO: handle exception
        msg = f"Caught exception: {e}"
        logger.error(f'{e.__class__.__name__}: {msg}',
                     exc_info=e if log_verbose else None)
@@ -44,7 +56,7 @@ def get_ChatOpenAI(
    config = get_model_worker_config(model_name)
    if model_name == "openai-api":
        model_name = config.get("model_name")
+    ChatOpenAI._get_encoding_model = MinxChatOpenAI.get_encoding_model
    model = ChatOpenAI(
        streaming=streaming,
        verbose=verbose,
@@ -153,6 +165,7 @@ class ChatMessage(BaseModel):
def torch_gc():
    try:
+        import torch
        if torch.cuda.is_available():
            # with torch.cuda.device(DEVICE):
            torch.cuda.empty_cache()
@@ -390,7 +403,7 @@ def fschat_controller_address() -> str:
def fschat_model_worker_address(model_name: str = LLM_MODELS[0]) -> str:
-    if model := get_model_worker_config(model_name):  # TODO: depends fastchat
+    if model := get_model_worker_config(model_name):
        host = model["host"]
        if host == "0.0.0.0":
            host = "127.0.0.1"
@@ -435,7 +448,7 @@ def get_prompt_template(type: str, name: str) -> Optional[str]:
    from configs import prompt_config
    import importlib
-    importlib.reload(prompt_config)  # TODO: 检查configs/prompt_config.py文件有修改再重新加载
+    importlib.reload(prompt_config)
    return prompt_config.PROMPT_TEMPLATES[type].get(name)
@@ -489,69 +502,36 @@ def set_httpx_config(
            no_proxy.append(host)
    os.environ["NO_PROXY"] = ",".join(no_proxy)

-    # TODO: 简单的清除系统代理不是个好的选择影响太多。似乎修改代理服务器的bypass列表更好。
-    # patch requests to use custom proxies instead of system settings
    def _get_proxies():
        return proxies

    import urllib.request
    urllib.request.getproxies = _get_proxies


-# 自动检查torch可用的设备。分布式部署时不运行LLM的机器上可以不装torch
-def is_mps_available():
-    return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
-
-def is_cuda_available():
-    return torch.cuda.is_available()
-
def detect_device() -> Literal["cuda", "mps", "cpu"]:
    try:
+        import torch
        if torch.cuda.is_available():
            return "cuda"
-        if is_mps_available():
+        if torch.backends.mps.is_available():
            return "mps"
    except:
        pass
    return "cpu"


-def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
+def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
    device = device or LLM_DEVICE
-    if device not in ["cuda", "mps", "cpu", "xpu"]:
-        logging.warning(f"device not in ['cuda', 'mps', 'cpu','xpu'], device = {device}")
+    if device not in ["cuda", "mps", "cpu"]:
        device = detect_device()
-    elif device == 'cuda' and not is_cuda_available() and is_mps_available():
-        logging.warning("cuda is not available, fallback to mps")
-        return "mps"
-    if device == 'mps' and not is_mps_available() and is_cuda_available():
-        logging.warning("mps is not available, fallback to cuda")
-        return "cuda"
-
-    # auto detect device if not specified
-    if device not in ["cuda", "mps", "cpu", "xpu"]:
-        return detect_device()
    return device


-def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
-    device = device or LLM_DEVICE
+def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
+    device = device or EMBEDDING_DEVICE
    if device not in ["cuda", "mps", "cpu"]:
-        logging.warning(f"device not in ['cuda', 'mps', 'cpu','xpu'], device = {device}")
        device = detect_device()
-    elif device == 'cuda' and not is_cuda_available() and is_mps_available():
-        logging.warning("cuda is not available, fallback to mps")
-        return "mps"
-    if device == 'mps' and not is_mps_available() and is_cuda_available():
-        logging.warning("mps is not available, fallback to cuda")
-        return "cuda"
-
-    # auto detect device if not specified
-    if device not in ["cuda", "mps", "cpu"]:
-        return detect_device()
    return device
@@ -569,7 +549,7 @@ def run_in_thread_pool(
        thread = pool.submit(func, **kwargs)
        tasks.append(thread)

-    for obj in as_completed(tasks):  # TODO: Ctrl+c无法停止
+    for obj in as_completed(tasks):
        yield obj.result()
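With the simplified device helpers above, device resolution boils down to detect_device() plus a whitelist check. A small usage sketch (it assumes a configured Langchain-Chatchat environment so that server.utils and the configs module import cleanly):

# Usage sketch for the simplified device helpers above (illustrative values).
from server.utils import detect_device, llm_device, embedding_device

print(detect_device())      # "cuda", "mps" or "cpu", depending on the local torch install
print(llm_device("auto"))   # "auto" is not in ["cuda", "mps", "cpu"], so it falls back to detect_device()
print(embedding_device())   # resolved from EMBEDDING_DEVICE the same way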

View File

@@ -6,9 +6,8 @@ import sys
from multiprocessing import Process
from datetime import datetime
from pprint import pprint
+from langchain_core._api import deprecated

# 设置numexpr最大线程数默认为CPU核心数
try:
    import numexpr
@@ -33,15 +32,18 @@ from configs import (
    HTTPX_DEFAULT_TIMEOUT,
)
from server.utils import (fschat_controller_address, fschat_model_worker_address,
-                          fschat_openai_api_address, set_httpx_config, get_httpx_client,
-                          get_model_worker_config, get_all_model_worker_configs,
+                          fschat_openai_api_address, get_httpx_client, get_model_worker_config,
                          MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
from server.knowledge_base.migrate import create_tables
import argparse
-from typing import Tuple, List, Dict
+from typing import List, Dict
from configs import VERSION

+@deprecated(
+    since="0.3.0",
+    message="模型启动功能将于 Langchain-Chatchat 0.3.x重写,支持更多模式和加速启动0.2.x中相关功能将废弃",
+    removal="0.3.0")
def create_controller_app(
        dispatch_method: str,
        log_level: str = "INFO",
@@ -88,7 +90,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
    for k, v in kwargs.items():
        setattr(args, k, v)
-    if worker_class := kwargs.get("langchain_model"):  #Langchian支持的模型不用做操作
+    if worker_class := kwargs.get("langchain_model"):  # Langchian支持的模型不用做操作
        from fastchat.serve.base_model_worker import app
        worker = ""
    # 在线模型API
@@ -107,12 +109,12 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
        import fastchat.serve.vllm_worker
        from fastchat.serve.vllm_worker import VLLMWorker, app, worker_id
        from vllm import AsyncLLMEngine
-        from vllm.engine.arg_utils import AsyncEngineArgs,EngineArgs
+        from vllm.engine.arg_utils import AsyncEngineArgs

-        args.tokenizer = args.model_path  # 如果tokenizer与model_path不一致在此处添加
+        args.tokenizer = args.model_path
        args.tokenizer_mode = 'auto'
-        args.trust_remote_code= True
-        args.download_dir= None
+        args.trust_remote_code = True
+        args.download_dir = None
        args.load_format = 'auto'
        args.dtype = 'auto'
        args.seed = 0
@@ -122,13 +124,13 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
        args.block_size = 16
        args.swap_space = 4  # GiB
        args.gpu_memory_utilization = 0.90
        args.max_num_batched_tokens = None  # 一个批次中的最大令牌tokens数量这个取决于你的显卡和大模型设置设置太大显存会不够
        args.max_num_seqs = 256
        args.disable_log_stats = False
        args.conv_template = None
        args.limit_worker_concurrency = 5
        args.no_register = False
        args.num_gpus = 1  # vllm worker的切分是tensor并行这里填写显卡的数量
        args.engine_use_ray = False
        args.disable_log_requests = False
@@ -138,10 +140,10 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
        args.quantization = None
        args.max_log_len = None
        args.tokenizer_revision = None

        # 0.2.2 vllm需要新加的参数
        args.max_paddings = 256

        if args.model_path:
            args.model = args.model_path
        if args.num_gpus > 1:
@@ -154,16 +156,16 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
        engine = AsyncLLMEngine.from_engine_args(engine_args)
        worker = VLLMWorker(
-            controller_addr = args.controller_address,
-            worker_addr = args.worker_address,
-            worker_id = worker_id,
-            model_path = args.model_path,
-            model_names = args.model_names,
-            limit_worker_concurrency = args.limit_worker_concurrency,
-            no_register = args.no_register,
-            llm_engine = engine,
-            conv_template = args.conv_template,
+            controller_addr=args.controller_address,
+            worker_addr=args.worker_address,
+            worker_id=worker_id,
+            model_path=args.model_path,
+            model_names=args.model_names,
+            limit_worker_concurrency=args.limit_worker_concurrency,
+            no_register=args.no_register,
+            llm_engine=engine,
+            conv_template=args.conv_template,
        )
        sys.modules["fastchat.serve.vllm_worker"].engine = engine
        sys.modules["fastchat.serve.vllm_worker"].worker = worker
        sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level)
@@ -171,7 +173,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
    else:
        from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id

        args.gpus = "0"  # GPU的编号,如果有多个GPU可以设置为"0,1,2,3"
        args.max_gpu_memory = "22GiB"
        args.num_gpus = 1  # model worker的切分是model并行这里填写显卡的数量
@@ -325,7 +327,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
        with get_httpx_client() as client:
            r = client.post(worker_address + "/release",
                            json={"new_model_name": new_model_name, "keep_origin": keep_origin})
            if r.status_code != 200:
                msg = f"failed to release model: {model_name}"
                logger.error(msg)
@@ -393,8 +395,8 @@ def run_model_worker(
    # add interface to release and load model
    @app.post("/release")
    def release_model(
            new_model_name: str = Body(None, description="释放后加载该模型"),
            keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
    ) -> Dict:
        if keep_origin:
            if new_model_name:
@@ -416,7 +418,7 @@ def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
    set_httpx_config()

    controller_addr = fschat_controller_address()
-    app = create_openai_api_app(controller_addr, log_level=log_level)  # TODO: not support keys yet.
+    app = create_openai_api_app(controller_addr, log_level=log_level)
    _set_app_event(app, started_event)

    host = FSCHAT_OPENAI_API["host"]
@@ -450,13 +452,13 @@ def run_webui(started_event: mp.Event = None, run_mode: str = None):
    port = WEBUI_SERVER["port"]

    cmd = ["streamlit", "run", "webui.py",
           "--server.address", host,
           "--server.port", str(port),
           "--theme.base", "light",
           "--theme.primaryColor", "#165dff",
           "--theme.secondaryBackgroundColor", "#f5f5f5",
           "--theme.textColor", "#000000",
           ]
    if run_mode == "lite":
        cmd += [
            "--",
@@ -605,8 +607,10 @@ async def start_main_server():
        Python 3.9 has `signal.strsignal(signalnum)` so this closure would not be needed.
        Also, 3.8 includes `signal.valid_signals()` that can be used to create a mapping for the same purpose.
        """

        def f(signal_received, frame):
            raise KeyboardInterrupt(f"{signalname} received")

        return f

    # This will be inherited by the child process if it is forked (not spawned)
@@ -701,8 +705,8 @@ async def start_main_server():
        for model_name in args.model_name:
            config = get_model_worker_config(model_name)
            if (config.get("online_api")
                    and config.get("worker_class")
                    and model_name in FSCHAT_MODEL_WORKERS):
                e = manager.Event()
                model_worker_started.append(e)
                process = Process(
@@ -742,12 +746,12 @@ async def start_main_server():
    else:
        try:
            # 保证任务收到SIGINT后能够正常退出
-            if p:= processes.get("controller"):
+            if p := processes.get("controller"):
                p.start()
                p.name = f"{p.name} ({p.pid})"
                controller_started.wait()  # 等待controller启动完成

-            if p:= processes.get("openai_api"):
+            if p := processes.get("openai_api"):
                p.start()
                p.name = f"{p.name} ({p.pid})"
@@ -763,24 +767,24 @@ async def start_main_server():
            for e in model_worker_started:
                e.wait()

-            if p:= processes.get("api"):
+            if p := processes.get("api"):
                p.start()
                p.name = f"{p.name} ({p.pid})"
                api_started.wait()  # 等待api.py启动完成

-            if p:= processes.get("webui"):
+            if p := processes.get("webui"):
                p.start()
                p.name = f"{p.name} ({p.pid})"
                webui_started.wait()  # 等待webui.py启动完成

            dump_server_info(after_start=True, args=args)

            while True:
                cmd = queue.get()  # 收到切换模型的消息
                e = manager.Event()
                if isinstance(cmd, list):
                    model_name, cmd, new_model_name = cmd
                    if cmd == "start":  # 运行新模型
                        logger.info(f"准备启动新模型进程:{new_model_name}")
                        process = Process(
                            target=run_model_worker,
@@ -831,7 +835,6 @@ async def start_main_server():
                    else:
                        logger.error(f"未找到模型进程:{model_name}")

            # for process in processes.get("model_worker", {}).values():
            #     process.join()
            # for process in processes.get("online_api", {}).values():
@@ -866,10 +869,9 @@ async def start_main_server():
        for p in processes.values():
            logger.info("Process status: %s", p)


if __name__ == "__main__":
-    # 确保数据库表被创建
    create_tables()
    if sys.version_info < (3, 10):
        loop = asyncio.get_event_loop()
    else:
@@ -879,16 +881,15 @@ if __name__ == "__main__":
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

-    # 同步调用协程代码
    loop.run_until_complete(start_main_server())

    # 服务启动后接口调用示例:
    # import openai
    # openai.api_key = "EMPTY" # Not support yet
    # openai.api_base = "http://localhost:8888/v1"
-    # model = "chatglm2-6b"
+    # model = "chatglm3-6b"

    # # create a chat completion
    # completion = openai.ChatCompletion.create(
BIN  tests/samples/ocr_test.docx (new file; binary content not shown)
BIN  tests/samples/ocr_test.pptx (new file; binary content not shown)

View File

@@ -6,13 +6,12 @@ from datetime import datetime
import os
import re
import time
-from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
+from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES, LLM_MODELS,
                     DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
from server.knowledge_base.utils import LOADER_DICT
import uuid
from typing import List, Dict

chat_box = ChatBox(
    assistant_avatar=os.path.join(
        "img",
@@ -127,7 +126,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
    chat_box.use_chat_name(conversation_name)
    conversation_id = st.session_state["conversation_ids"][conversation_name]

-    # TODO: 对话模型与会话绑定
    def on_mode_change():
        mode = st.session_state.dialogue_mode
        text = f"已切换到 {mode} 模式。"
@@ -138,11 +136,11 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
        st.toast(text)

    dialogue_modes = ["LLM 对话",
                      "知识库问答",
                      "文件对话",
                      "搜索引擎问答",
                      "自定义Agent问答",
                      ]
    dialogue_mode = st.selectbox("请选择对话模式:",
                                 dialogue_modes,
                                 index=0,
@@ -166,12 +164,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
        available_models = []
        config_models = api.list_config_models()
        if not is_lite:
-            for k, v in config_models.get("local", {}).items():  # 列出配置了有效本地路径的模型
+            for k, v in config_models.get("local", {}).items():
                if (v.get("model_path_exists")
                        and k not in running_models):
                    available_models.append(k)
-        for k, v in config_models.get("online", {}).items():  # 列出ONLINE_MODELS中直接访问的模型
-            if not v.get("provider") and k not in running_models:
+        for k, v in config_models.get("online", {}).items():
+            if not v.get("provider") and k not in running_models and k in LLM_MODELS:
                available_models.append(k)
        llm_models = running_models + available_models
        cur_llm_model = st.session_state.get("cur_llm_model", default_model)
@@ -250,14 +248,14 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
        elif dialogue_mode == "文件对话":
            with st.expander("文件对话配置", True):
                files = st.file_uploader("上传知识文件:",
                                         [i for ls in LOADER_DICT.values() for i in ls],
                                         accept_multiple_files=True,
                                         )
                kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)

                ## Bge 模型会超过1
                score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
-                if st.button("开始上传", disabled=len(files)==0):
+                if st.button("开始上传", disabled=len(files) == 0):
                    st.session_state["file_chat_id"] = upload_temp_docs(files, api)
        elif dialogue_mode == "搜索引擎问答":
            search_engine_list = api.list_search_engines()
@@ -279,9 +277,9 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
    chat_input_placeholder = "请输入对话内容换行请使用Shift+Enter。输入/help查看自定义命令 "

    def on_feedback(
            feedback,
            message_id: str = "",
            history_index: int = -1,
    ):
        reason = feedback["text"]
        score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index)
@@ -296,7 +294,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
    }

    if prompt := st.chat_input(chat_input_placeholder, key="prompt"):
        if parse_command(text=prompt, modal=modal):  # 用户输入自定义命令
            st.rerun()
        else:
            history = get_messages_history(history_len)
@@ -306,11 +304,11 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
            text = ""
            message_id = ""
            r = api.chat_chat(prompt,
                              history=history,
                              conversation_id=conversation_id,
                              model=llm_model,
                              prompt_name=prompt_template_name,
                              temperature=temperature)
            for t in r:
                if error_msg := check_error_msg(t):  # check whether error occured
                    st.error(error_msg)
@@ -321,12 +319,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
            metadata = {
                "message_id": message_id,
            }
            chat_box.update_msg(text, streaming=False, metadata=metadata)  # 更新最终的字符串,去除光标
            chat_box.show_feedback(**feedback_kwargs,
                                   key=message_id,
                                   on_submit=on_feedback,
                                   kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})

        elif dialogue_mode == "自定义Agent问答":
            if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
@@ -373,13 +371,13 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
            ])
            text = ""
            for d in api.knowledge_base_chat(prompt,
                                             knowledge_base_name=selected_kb,
                                             top_k=kb_top_k,
                                             score_threshold=score_threshold,
                                             history=history,
                                             model=llm_model,
                                             prompt_name=prompt_template_name,
                                             temperature=temperature):
                if error_msg := check_error_msg(d):  # check whether error occured
                    st.error(error_msg)
                elif chunk := d.get("answer"):
@@ -397,13 +395,13 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
            ])
            text = ""
            for d in api.file_chat(prompt,
                                   knowledge_id=st.session_state["file_chat_id"],
                                   top_k=kb_top_k,
                                   score_threshold=score_threshold,
                                   history=history,
                                   model=llm_model,
                                   prompt_name=prompt_template_name,
                                   temperature=temperature):
                if error_msg := check_error_msg(d):  # check whether error occured
                    st.error(error_msg)
                elif chunk := d.get("answer"):
@@ -455,4 +453,4 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
                file_name=f"{now:%Y-%m-%d %H.%M}_对话记录.md",
                mime="text/markdown",
                use_container_width=True,
            )

View File

@@ -7,15 +7,12 @@ from server.knowledge_base.utils import get_file_path, LOADER_DICT
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
from typing import Literal, Dict, Tuple
from configs import (kbs_config,
                     EMBEDDING_MODEL, DEFAULT_VS_TYPE,
                     CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
from server.utils import list_embed_models, list_online_embed_models
import os
import time

-# SENTENCE_SIZE = 100

cell_renderer = JsCode("""function(params) {if(params.value==true){return ''}else{return '×'}}""")
@@ -32,7 +29,7 @@ def config_aggrid(
    gb.configure_selection(
        selection_mode=selection_mode,
        use_checkbox=use_checkbox,
-        # pre_selected_rows=st.session_state.get("selected_rows", [0]),
+        pre_selected_rows=st.session_state.get("selected_rows", [0]),
    )
    gb.configure_pagination(
        enabled=True,
@@ -59,7 +56,8 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
    try:
        kb_list = {x["kb_name"]: x for x in get_kb_details()}
    except Exception as e:
-        st.error("获取知识库信息错误,请检查是否已按照 `README.md` 中 `4 知识库初始化与迁移` 步骤完成初始化或迁移,或是否为数据库连接错误。")
+        st.error(
+            "获取知识库信息错误,请检查是否已按照 `README.md` 中 `4 知识库初始化与迁移` 步骤完成初始化或迁移,或是否为数据库连接错误。")
        st.stop()
    kb_names = list(kb_list.keys())
@@ -150,7 +148,8 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
            [i for ls in LOADER_DICT.values() for i in ls],
            accept_multiple_files=True,
        )
-        kb_info = st.text_area("请输入知识库介绍:", value=st.session_state["selected_kb_info"], max_chars=None, key=None,
+        kb_info = st.text_area("请输入知识库介绍:", value=st.session_state["selected_kb_info"], max_chars=None,
+                               key=None,
                               help=None, on_change=None, args=None, kwargs=None)

        if kb_info != st.session_state["selected_kb_info"]:
@@ -200,8 +199,8 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
        doc_details = doc_details[[
            "No", "file_name", "document_loader", "text_splitter", "docs_count", "in_folder", "in_db",
        ]]
-        # doc_details["in_folder"] = doc_details["in_folder"].replace(True, "✓").replace(False, "×")
-        # doc_details["in_db"] = doc_details["in_db"].replace(True, "✓").replace(False, "×")
+        doc_details["in_folder"] = doc_details["in_folder"].replace(True, "").replace(False, "×")
+        doc_details["in_db"] = doc_details["in_db"].replace(True, "").replace(False, "×")
        gb = config_aggrid(
            doc_details,
            {
@@ -252,7 +251,8 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
        st.write()
        # 将文件分词并加载到向量库中
        if cols[1].button(
-                "重新添加至向量库" if selected_rows and (pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库",
+                "重新添加至向量库" if selected_rows and (
+                    pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库",
                disabled=not file_exists(kb, selected_rows)[0],
                use_container_width=True,
        ):
@@ -285,39 +285,39 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
    st.divider()

-    # cols = st.columns(3)
-    # if cols[0].button(
-    #         "依据源文件重建向量库",
-    #         # help="无需上传文件通过其它方式将文档拷贝到对应知识库content目录下点击本按钮即可重建知识库。",
-    #         use_container_width=True,
-    #         type="primary",
-    # ):
-    #     with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
-    #         empty = st.empty()
-    #         empty.progress(0.0, "")
-    #         for d in api.recreate_vector_store(kb,
-    #                                            chunk_size=chunk_size,
-    #                                            chunk_overlap=chunk_overlap,
-    #                                            zh_title_enhance=zh_title_enhance):
-    #             if msg := check_error_msg(d):
-    #                 st.toast(msg)
-    #             else:
-    #                 empty.progress(d["finished"] / d["total"], d["msg"])
-    #         st.rerun()
-    # if cols[2].button(
-    #         "删除知识库",
-    #         use_container_width=True,
-    # ):
-    #     ret = api.delete_knowledge_base(kb)
-    #     st.toast(ret.get("msg", " "))
-    #     time.sleep(1)
-    #     st.rerun()
-    # with st.sidebar:
-    #     keyword = st.text_input("查询关键字")
-    #     top_k = st.slider("匹配条数", 1, 100, 3)
+    cols = st.columns(3)
+    if cols[0].button(
+            "依据源文件重建向量库",
+            help="无需上传文件通过其它方式将文档拷贝到对应知识库content目录下点击本按钮即可重建知识库。",
+            use_container_width=True,
+            type="primary",
+    ):
+        with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
+            empty = st.empty()
+            empty.progress(0.0, "")
+            for d in api.recreate_vector_store(kb,
+                                               chunk_size=chunk_size,
+                                               chunk_overlap=chunk_overlap,
+                                               zh_title_enhance=zh_title_enhance):
+                if msg := check_error_msg(d):
+                    st.toast(msg)
+                else:
+                    empty.progress(d["finished"] / d["total"], d["msg"])
+            st.rerun()
+    if cols[2].button(
+            "删除知识库",
+            use_container_width=True,
+    ):
+        ret = api.delete_knowledge_base(kb)
+        st.toast(ret.get("msg", " "))
+        time.sleep(1)
+        st.rerun()
+    with st.sidebar:
+        keyword = st.text_input("查询关键字")
+        top_k = st.slider("匹配条数", 1, 100, 3)

    st.write("文件内文档列表。双击进行修改,在删除列填入 Y 可删除对应行。")
    docs = []
@@ -325,11 +325,12 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
    if selected_rows:
        file_name = selected_rows[0]["file_name"]
        docs = api.search_kb_docs(knowledge_base_name=selected_kb, file_name=file_name)
-        data = [{"seq": i+1, "id": x["id"], "page_content": x["page_content"], "source": x["metadata"].get("source"),
-                 "type": x["type"],
-                 "metadata": json.dumps(x["metadata"], ensure_ascii=False),
-                 "to_del": "",
-                 } for i, x in enumerate(docs)]
+        data = [
+            {"seq": i + 1, "id": x["id"], "page_content": x["page_content"], "source": x["metadata"].get("source"),
+             "type": x["type"],
+             "metadata": json.dumps(x["metadata"], ensure_ascii=False),
+             "to_del": "",
+             } for i, x in enumerate(docs)]
        df = pd.DataFrame(data)
        gb = GridOptionsBuilder.from_dataframe(df)
@@ -343,22 +344,24 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
        edit_docs = AgGrid(df, gb.build())

        if st.button("保存更改"):
-            # origin_docs = {x["id"]: {"page_content": x["page_content"], "type": x["type"], "metadata": x["metadata"]} for x in docs}
+            origin_docs = {
+                x["id"]: {"page_content": x["page_content"], "type": x["type"], "metadata": x["metadata"]} for x in
+                docs}
            changed_docs = []
            for index, row in edit_docs.data.iterrows():
-                # origin_doc = origin_docs[row["id"]]
-                # if row["page_content"] != origin_doc["page_content"]:
+                origin_doc = origin_docs[row["id"]]
+                if row["page_content"] != origin_doc["page_content"]:
                if row["to_del"] not in ["Y", "y", 1]:
                    changed_docs.append({
                        "page_content": row["page_content"],
                        "type": row["type"],
                        "metadata": json.loads(row["metadata"]),
                    })

            if changed_docs:
                if api.update_kb_docs(knowledge_base_name=selected_kb,
                                      file_names=[file_name],
                                      docs={file_name: changed_docs}):
                    st.toast("更新文档成功")
                else:
                    st.toast("更新文档失败")

View File

@@ -1,7 +1,6 @@
# 该文件封装了对api.py的请求可以被不同的webui使用
# 通过ApiRequest和AsyncApiRequest支持同步/异步调用

from typing import *
from pathlib import Path

# 此处导入的配置为发起请求如WEBUI机器上的配置主要用于为前端设置默认值。分布式部署时可以与服务器上的不同
@@ -27,7 +26,7 @@ from io import BytesIO
from server.utils import set_httpx_config, api_address, get_httpx_client

from pprint import pprint
+from langchain_core._api import deprecated

set_httpx_config()
@@ -36,10 +35,11 @@ class ApiRequest:
    '''
    api.py调用的封装同步模式,简化api调用方式
    '''

    def __init__(
            self,
            base_url: str = api_address(),
            timeout: float = HTTPX_DEFAULT_TIMEOUT,
    ):
        self.base_url = base_url
        self.timeout = timeout
@@ -55,12 +55,12 @@ class ApiRequest:
        return self._client

    def get(
            self,
            url: str,
            params: Union[Dict, List[Tuple], bytes] = None,
            retry: int = 3,
            stream: bool = False,
            **kwargs: Any,
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        while retry > 0:
            try:
@@ -75,13 +75,13 @@ class ApiRequest:
                retry -= 1

    def post(
            self,
            url: str,
            data: Dict = None,
            json: Dict = None,
            retry: int = 3,
            stream: bool = False,
            **kwargs: Any
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        while retry > 0:
            try:
@@ -97,13 +97,13 @@ class ApiRequest:
                retry -= 1

    def delete(
            self,
            url: str,
            data: Dict = None,
            json: Dict = None,
            retry: int = 3,
            stream: bool = False,
            **kwargs: Any
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        while retry > 0:
            try:
@@ -118,24 +118,25 @@ class ApiRequest:
                retry -= 1

    def _httpx_stream2generator(
            self,
            response: contextlib._GeneratorContextManager,
            as_json: bool = False,
    ):
        '''
        将httpx.stream返回的GeneratorContextManager转化为普通生成器
        '''

        async def ret_async(response, as_json):
            try:
                async with response as r:
                    async for chunk in r.aiter_text(None):
                        if not chunk:  # fastchat api yield empty bytes on start and end
                            continue
                        if as_json:
                            try:
                                if chunk.startswith("data: "):
                                    data = json.loads(chunk[6:-2])
                                elif chunk.startswith(":"):  # skip sse comment line
                                    continue
                                else:
                                    data = json.loads(chunk)
@@ -143,7 +144,7 @@ class ApiRequest:
                            except Exception as e:
                                msg = f"接口返回json错误 {chunk}’。错误信息是:{e}"
                                logger.error(f'{e.__class__.__name__}: {msg}',
                                             exc_info=e if log_verbose else None)
                        else:
                            # print(chunk, end="", flush=True)
                            yield chunk
@@ -158,20 +159,20 @@ class ApiRequest:
            except Exception as e:
                msg = f"API通信遇到错误{e}"
                logger.error(f'{e.__class__.__name__}: {msg}',
                             exc_info=e if log_verbose else None)
                yield {"code": 500, "msg": msg}

        def ret_sync(response, as_json):
            try:
                with response as r:
                    for chunk in r.iter_text(None):
                        if not chunk:  # fastchat api yield empty bytes on start and end
                            continue
                        if as_json:
                            try:
                                if chunk.startswith("data: "):
                                    data = json.loads(chunk[6:-2])
                                elif chunk.startswith(":"):  # skip sse comment line
                                    continue
                                else:
                                    data = json.loads(chunk)
@@ -179,7 +180,7 @@ class ApiRequest:
                            except Exception as e:
                                msg = f"接口返回json错误 {chunk}’。错误信息是:{e}"
                                logger.error(f'{e.__class__.__name__}: {msg}',
                                             exc_info=e if log_verbose else None)
                        else:
                            # print(chunk, end="", flush=True)
                            yield chunk
@@ -194,7 +195,7 @@ class ApiRequest:
            except Exception as e:
                msg = f"API通信遇到错误{e}"
                logger.error(f'{e.__class__.__name__}: {msg}',
                             exc_info=e if log_verbose else None)
                yield {"code": 500, "msg": msg}

        if self._use_async:
@@ -203,16 +204,17 @@ class ApiRequest:
            return ret_sync(response, as_json)

    def _get_response_value(
            self,
            response: httpx.Response,
            as_json: bool = False,
            value_func: Callable = None,
    ):
        '''
        转换同步或异步请求返回的响应
        `as_json`: 返回json
        `value_func`: 用户可以自定义返回值该函数接受response或json
        '''

        def to_json(r):
            try:
                return r.json()
@@ -220,7 +222,7 @@ class ApiRequest:
            except Exception as e:
                msg = "API未能返回正确的JSON。" + str(e)
                if log_verbose:
                    logger.error(f'{e.__class__.__name__}: {msg}',
                                 exc_info=e if log_verbose else None)
                return {"code": 500, "msg": msg, "data": None}

        if value_func is None:
@@ -250,10 +252,10 @@ class ApiRequest:
        return self._get_response_value(response, as_json=True, value_func=lambda r: r["data"])

    def get_prompt_template(
            self,
            type: str = "llm_chat",
            name: str = "default",
            **kwargs,
    ) -> str:
        data = {
            "type": type,
@@ -297,15 +299,19 @@ class ApiRequest:
        response = self.post("/chat/chat", json=data, stream=True, **kwargs)
        return self._httpx_stream2generator(response, as_json=True)
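chat_chat posts to /chat/chat with stream=True and returns the generator built by _httpx_stream2generator, so callers iterate over parsed JSON chunks. A usage sketch; the base URL, model name and the "text" field follow the defaults used elsewhere in this project and should be treated as assumptions:

# Usage sketch for ApiRequest.chat_chat (values are illustrative).
api = ApiRequest(base_url="http://127.0.0.1:7861")
for d in api.chat_chat("你好", model="chatglm3-6b", stream=True):
    if isinstance(d, dict) and (text := d.get("text")):
        print(text, end="", flush=True)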
@deprecated(
since="0.3.0",
message="自定义Agent问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
removal="0.3.0")
def agent_chat( def agent_chat(
self, self,
query: str, query: str,
history: List[Dict] = [], history: List[Dict] = [],
stream: bool = True, stream: bool = True,
model: str = LLM_MODELS[0], model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE, temperature: float = TEMPERATURE,
max_tokens: int = None, max_tokens: int = None,
prompt_name: str = "default", prompt_name: str = "default",
): ):
''' '''
对应api.py/chat/agent_chat 接口 对应api.py/chat/agent_chat 接口
@ -327,17 +333,17 @@ class ApiRequest:
return self._httpx_stream2generator(response, as_json=True) return self._httpx_stream2generator(response, as_json=True)
def knowledge_base_chat( def knowledge_base_chat(
self, self,
query: str, query: str,
knowledge_base_name: str, knowledge_base_name: str,
top_k: int = VECTOR_SEARCH_TOP_K, top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD, score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [], history: List[Dict] = [],
stream: bool = True, stream: bool = True,
model: str = LLM_MODELS[0], model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE, temperature: float = TEMPERATURE,
max_tokens: int = None, max_tokens: int = None,
prompt_name: str = "default", prompt_name: str = "default",
): ):
''' '''
对应api.py/chat/knowledge_base_chat接口 对应api.py/chat/knowledge_base_chat接口
@@ -366,28 +372,29 @@ class ApiRequest:
        return self._httpx_stream2generator(response, as_json=True)

    def upload_temp_docs(
            self,
            files: List[Union[str, Path, bytes]],
            knowledge_id: str = None,
            chunk_size=CHUNK_SIZE,
            chunk_overlap=OVERLAP_SIZE,
            zh_title_enhance=ZH_TITLE_ENHANCE,
    ):
        '''
        Corresponds to the api.py /knowledge_base/upload_temp_docs endpoint
        '''

        def convert_file(file, filename=None):
            if isinstance(file, bytes):  # raw bytes
                file = BytesIO(file)
            elif hasattr(file, "read"):  # a file-like object
                filename = filename or file.name
            else:  # a local path
                file = Path(file).absolute().open("rb")
                filename = filename or os.path.split(file.name)[-1]
            return filename, file

        files = [convert_file(file) for file in files]
        data = {
            "knowledge_id": knowledge_id,
            "chunk_size": chunk_size,
            "chunk_overlap": chunk_overlap,
@@ -402,17 +409,17 @@ class ApiRequest:
        return self._get_response_value(response, as_json=True)

    def file_chat(
            self,
            query: str,
            knowledge_id: str,
            top_k: int = VECTOR_SEARCH_TOP_K,
            score_threshold: float = SCORE_THRESHOLD,
            history: List[Dict] = [],
            stream: bool = True,
            model: str = LLM_MODELS[0],
            temperature: float = TEMPERATURE,
            max_tokens: int = None,
            prompt_name: str = "default",
    ):
        '''
        Corresponds to the api.py /chat/file_chat endpoint
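`file_chat` pairs with `upload_temp_docs`: files are first uploaded to a temporary knowledge base, and the returned id is then passed as `knowledge_id`. A hedged sketch of that flow, reusing the assumed `api` instance and the `temp_kb_id` from the previous example:

```python
# Sketch: continues the previous example; `temp_kb_id` came from upload_temp_docs.
if temp_kb_id:
    for chunk in api.file_chat(
        query="Summarize the uploaded documents.",
        knowledge_id=temp_kb_id,
        top_k=3,
        stream=True,
    ):
        print(chunk)  # streamed JSON chunks; schema depends on the server
```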
@@ -430,9 +437,6 @@ class ApiRequest:
            "prompt_name": prompt_name,
        }
        response = self.post(
            "/chat/file_chat",
            json=data,
@@ -440,18 +444,23 @@ class ApiRequest:
        )
        return self._httpx_stream2generator(response, as_json=True)

    @deprecated(
        since="0.3.0",
        message="搜索引擎问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
        removal="0.3.0"
    )
    def search_engine_chat(
            self,
            query: str,
            search_engine_name: str,
            top_k: int = SEARCH_ENGINE_TOP_K,
            history: List[Dict] = [],
            stream: bool = True,
            model: str = LLM_MODELS[0],
            temperature: float = TEMPERATURE,
            max_tokens: int = None,
            prompt_name: str = "default",
            split_result: bool = False,
    ):
        '''
        Corresponds to the api.py /chat/search_engine_chat endpoint
@@ -482,7 +491,7 @@ class ApiRequest:
    # Knowledge base operations

    def list_knowledge_bases(
            self,
    ):
        '''
        Corresponds to the api.py /knowledge_base/list_knowledge_bases endpoint
@@ -493,10 +502,10 @@ class ApiRequest:
            value_func=lambda r: r.get("data", []))

    def create_knowledge_base(
            self,
            knowledge_base_name: str,
            vector_store_type: str = DEFAULT_VS_TYPE,
            embed_model: str = EMBEDDING_MODEL,
    ):
        '''
        Corresponds to the api.py /knowledge_base/create_knowledge_base endpoint
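A hedged sketch of creating a knowledge base through this client; the knowledge base name and the explicit `"faiss"` vector store type are placeholders — in practice the defaults come from `DEFAULT_VS_TYPE` and `EMBEDDING_MODEL` in the server configs.

```python
# Sketch: `api` is assumed to be an ApiRequest instance.
resp = api.create_knowledge_base(
    knowledge_base_name="product_docs",
    vector_store_type="faiss",  # placeholder; DEFAULT_VS_TYPE is the usual default
)
print(resp)  # expected to be a JSON payload with a status code and message

print(api.list_knowledge_bases())  # the new KB should now appear in the list
```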
@@ -514,8 +523,8 @@ class ApiRequest:
        return self._get_response_value(response, as_json=True)

    def delete_knowledge_base(
            self,
            knowledge_base_name: str,
    ):
        '''
        Corresponds to the api.py /knowledge_base/delete_knowledge_base endpoint
@@ -527,8 +536,8 @@ class ApiRequest:
        return self._get_response_value(response, as_json=True)

    def list_kb_docs(
            self,
            knowledge_base_name: str,
    ):
        '''
        Corresponds to the api.py /knowledge_base/list_files endpoint
@@ -542,13 +551,13 @@ class ApiRequest:
            value_func=lambda r: r.get("data", []))

    def search_kb_docs(
            self,
            knowledge_base_name: str,
            query: str = "",
            top_k: int = VECTOR_SEARCH_TOP_K,
            score_threshold: int = SCORE_THRESHOLD,
            file_name: str = "",
            metadata: dict = {},
    ) -> List:
        '''
        Corresponds to the api.py /knowledge_base/search_docs endpoint
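A hedged sketch of querying stored chunks directly, without going through the LLM; the knowledge base name is a placeholder and the structure of each returned document (chunk text plus metadata) is assumed rather than shown in this diff.

```python
# Sketch: `api` is assumed to be an ApiRequest instance.
docs = api.search_kb_docs(
    knowledge_base_name="product_docs",
    query="installation steps",
    top_k=5,
)
for doc in docs:
    print(doc)  # each entry is assumed to carry the chunk text and its metadata

# file_name / metadata appear intended as filters when no query is given:
all_from_file = api.search_kb_docs(
    knowledge_base_name="product_docs",
    file_name="install.md",
)
```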
@@ -569,9 +578,9 @@ class ApiRequest:
        return self._get_response_value(response, as_json=True)

    def update_docs_by_id(
            self,
            knowledge_base_name: str,
            docs: Dict[str, Dict],
    ) -> bool:
        '''
        Corresponds to the api.py /knowledge_base/update_docs_by_id endpoint
@@ -587,32 +596,33 @@ class ApiRequest:
        return self._get_response_value(response)

    def upload_kb_docs(
            self,
            files: List[Union[str, Path, bytes]],
            knowledge_base_name: str,
            override: bool = False,
            to_vector_store: bool = True,
            chunk_size=CHUNK_SIZE,
            chunk_overlap=OVERLAP_SIZE,
            zh_title_enhance=ZH_TITLE_ENHANCE,
            docs: Dict = {},
            not_refresh_vs_cache: bool = False,
    ):
        '''
        Corresponds to the api.py /knowledge_base/upload_docs endpoint
        '''

        def convert_file(file, filename=None):
            if isinstance(file, bytes):  # raw bytes
                file = BytesIO(file)
            elif hasattr(file, "read"):  # a file-like object
                filename = filename or file.name
            else:  # a local path
                file = Path(file).absolute().open("rb")
                filename = filename or os.path.split(file.name)[-1]
            return filename, file

        files = [convert_file(file) for file in files]
        data = {
            "knowledge_base_name": knowledge_base_name,
            "override": override,
            "to_vector_store": to_vector_store,
@@ -633,11 +643,11 @@ class ApiRequest:
        return self._get_response_value(response, as_json=True)

    def delete_kb_docs(
            self,
            knowledge_base_name: str,
            file_names: List[str],
            delete_content: bool = False,
            not_refresh_vs_cache: bool = False,
    ):
        '''
        Corresponds to the api.py /knowledge_base/delete_docs endpoint
@@ -655,8 +665,7 @@ class ApiRequest:
        )
        return self._get_response_value(response, as_json=True)

    def update_kb_info(self, knowledge_base_name, kb_info):
        '''
        Corresponds to the api.py /knowledge_base/update_info endpoint
        '''
@@ -672,15 +681,15 @@ class ApiRequest:
        return self._get_response_value(response, as_json=True)

    def update_kb_docs(
            self,
            knowledge_base_name: str,
            file_names: List[str],
            override_custom_docs: bool = False,
            chunk_size=CHUNK_SIZE,
            chunk_overlap=OVERLAP_SIZE,
            zh_title_enhance=ZH_TITLE_ENHANCE,
            docs: Dict = {},
            not_refresh_vs_cache: bool = False,
    ):
        '''
        Corresponds to the api.py /knowledge_base/update_docs endpoint
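`update_kb_docs` appears to re-split and re-embed files that are already registered in a knowledge base, which is useful after changing chunking settings. A hedged sketch with placeholder names:

```python
# Sketch: `api` is assumed to be an ApiRequest instance.
resp = api.update_kb_docs(
    knowledge_base_name="product_docs",
    file_names=["manuals/faq.pdf"],
    chunk_size=400,   # re-chunk with a larger window
    chunk_overlap=80,
    not_refresh_vs_cache=False,
)
print(resp)
```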
@@ -706,14 +715,14 @@ class ApiRequest:
        return self._get_response_value(response, as_json=True)

    def recreate_vector_store(
            self,
            knowledge_base_name: str,
            allow_empty_kb: bool = True,
            vs_type: str = DEFAULT_VS_TYPE,
            embed_model: str = EMBEDDING_MODEL,
            chunk_size=CHUNK_SIZE,
            chunk_overlap=OVERLAP_SIZE,
            zh_title_enhance=ZH_TITLE_ENHANCE,
    ):
        '''
        Corresponds to the api.py /knowledge_base/recreate_vector_store endpoint
@@ -738,8 +747,8 @@ class ApiRequest:
    # LLM model operations

    def list_running_models(
            self,
            controller_address: str = None,
    ):
        '''
        Get the list of models currently running in FastChat
@@ -755,8 +764,7 @@ class ApiRequest:
            "/llm_model/list_running_models",
            json=data,
        )
        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", []))

    def get_default_llm_model(self, local_first: bool = True) -> Tuple[str, bool]:
        '''
@@ -764,6 +772,7 @@ class ApiRequest:
        When local_first=True, prefer a running local model; otherwise follow the order configured in LLM_MODELS.
        The return value is (model_name, is_local_model).
        '''

        def ret_sync():
            running_models = self.list_running_models()
            if not running_models:
@@ -780,7 +789,7 @@ class ApiRequest:
                        model = m
                        break
            if not model:  # none of the models configured in LLM_MODELS is running
                model = list(running_models)[0]
            is_local = not running_models[model].get("online_api")
            return model, is_local
@@ -801,7 +810,7 @@ class ApiRequest:
                        model = m
                        break
            if not model:  # none of the models configured in LLM_MODELS is running
                model = list(running_models)[0]
            is_local = not running_models[model].get("online_api")
            return model, is_local
@@ -812,8 +821,8 @@ class ApiRequest:
            return ret_sync()
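The selection logic walks the configured LLM_MODELS in order, skips entries that are not running, and falls back to the first running model when nothing configured is available; `local_first` prefers models that do not carry the `online_api` flag. A hedged sketch of using the result to pick the model for subsequent chat calls:

```python
# Sketch: `api` is assumed to be an ApiRequest instance.
model_name, is_local = api.get_default_llm_model(local_first=True)
print(f"using {model_name} ({'local' if is_local else 'online API'})")

for chunk in api.knowledge_base_chat(
    query="hello",
    knowledge_base_name="product_docs",
    model=model_name,
):
    print(chunk)
```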
    def list_config_models(
            self,
            types: List[str] = ["local", "online"],
    ) -> Dict[str, Dict]:
        '''
        Get the models configured in the server's configs; the result has the form {"type": {model_name: config}, ...}
@@ -825,23 +834,23 @@ class ApiRequest:
            "/llm_model/list_config_models",
            json=data,
        )
        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))

    def get_model_config(
            self,
            model_name: str = None,
    ) -> Dict:
        '''
        Get a model's configuration from the server
        '''
        data = {
            "model_name": model_name,
        }
        response = self.post(
            "/llm_model/get_model_config",
            json=data,
        )
        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))

    def list_search_engines(self) -> List[str]:
        '''
@@ -850,12 +859,12 @@ class ApiRequest:
        response = self.post(
            "/server/list_search_engines",
        )
        return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))
    def stop_llm_model(
            self,
            model_name: str,
            controller_address: str = None,
    ):
        '''
        Stop a running LLM model
@@ -873,10 +882,10 @@ class ApiRequest:
        return self._get_response_value(response, as_json=True)

    def change_llm_model(
            self,
            model_name: str,
            new_model_name: str,
            controller_address: str = None,
    ):
        '''
        Ask the fastchat controller to switch to another LLM model
@@ -959,10 +968,10 @@ class ApiRequest:
            return ret_sync()

    def embed_texts(
            self,
            texts: List[str],
            embed_model: str = EMBEDDING_MODEL,
            to_query: bool = False,
    ) -> List[List[float]]:
        '''
        Vectorize texts; available models include local embed_models and online models that support embeddings
@@ -979,10 +988,10 @@ class ApiRequest:
        return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data"))
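A hedged sketch of the embeddings helper; `to_query=True` is assumed to switch to the query-style encoding that some embedding models distinguish from the document-style encoding used at indexing time.

```python
# Sketch: `api` is assumed to be an ApiRequest instance.
doc_vectors = api.embed_texts(
    texts=["Langchain-Chatchat is a RAG knowledge-base project."],
    to_query=False,  # encode as documents (indexing side)
)
query_vector = api.embed_texts(
    texts=["what is this project?"],
    to_query=True,   # encode as a query (retrieval side); assumed semantics
)[0]
print(len(doc_vectors), len(query_vector))
```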
    def chat_feedback(
            self,
            message_id: str,
            score: int,
            reason: str = "",
    ) -> int:
        '''
        Submit feedback (a rating and optional reason) for a chat message
@@ -1019,9 +1028,9 @@ def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
    return the success message if the API request succeeded
    '''
    if (isinstance(data, dict)
            and key in data
            and "code" in data
            and data["code"] == 200):
        return data[key]
    return ""