mirror of https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 13:23:16 +08:00

publish 0.2.10 (#2797)

New features:
- Improve OCR of PDF files by filtering out meaningless small images, by @liunux4odoo #2525
- Support the Gemini online model, by @yhfgyyf #2630
- Support the GLM-4 online model, by @zRzRzRzRzRzRzR
- Elasticsearch: support HTTPS connections, by @xldistance #2390
- Improve OCR of PPT and DOC knowledge-base files, by @596192804 #2013
- Update the Agent dialogue feature, by @zRzRzRzRzRzRzR
- Obtain connections from the connection pool when creating objects, instead of opening a new connection on every method call, by @Lijia0 #2480
- Make ChatOpenAI check whether the token count exceeds the model's context length, by @glide-the
- Update database startup error handling and project milestones, by @zRzRzRzRzRzRzR #2659
- Update configuration files / documentation / dependencies, by @imClumsyPanda @zRzRzRzRzRzRzR
- Add a Japanese README, by @eltociear #2787

Fixes:
- PGVector vector store connection error after the langchain update, by @HALIndex #2591
- Minimax model worker error, by @xyhshen
- ES store could not perform vector search; add mappings to create the vector index, by @MSZheng20 #2688

This commit is contained in:
parent ee6a28b565
commit 9c525b7fa5

README.md (83 lines changed)
@@ -1,6 +1,5 @@

🌍 [READ THIS IN ENGLISH](README_en.md)
🌍 [日本語で読む](README_ja.md)

@@ -8,6 +7,8 @@

An open-source, offline-deployable retrieval-augmented generation (RAG) knowledge-base project built on LLMs such as ChatGLM and application frameworks such as Langchain.

⚠️ `0.2.10` will be the last release of the `0.2.x` series. The `0.2.x` series will stop receiving updates and technical support, and all effort will go into developing the more capable `Langchain-Chatchat 0.3.x`.

---

## Table of Contents

@@ -15,23 +16,31 @@

* [Introduction](README.md#介绍)
* [Pain Points Addressed](README.md#解决的痛点)
* [Quick Start](README.md#快速上手)
* [1. Environment Setup](README.md#1-环境配置)
* [2. Model Download](README.md#2-模型下载)
* [3. Initialize the Knowledge Base and Config Files](README.md#3-初始化知识库和配置文件)
* [4. One-Click Startup](README.md#4-一键启动)
* [5. Startup UI Examples](README.md#5-启动界面示例)
* [Contact Us](README.md#联系我们)

## Introduction

🤖️ A question-answering application over local knowledge bases built on the ideas of [langchain](https://github.com/langchain-ai/langchain), aiming to provide a knowledge-base QA solution that is friendly to Chinese scenarios and open-source models and that can run fully offline.

💡 Inspired by [GanymedeNil](https://github.com/GanymedeNil)'s project [document.ai](https://github.com/GanymedeNil/document.ai) and the [ChatGLM-6B Pull Request](https://github.com/THUDM/ChatGLM-6B/pull/216) created by [AlexZhangji](https://github.com/AlexZhangji), this project builds a local knowledge-base QA application that can be implemented end to end with open-source models. The latest version uses [FastChat](https://github.com/lm-sys/FastChat) to serve models such as Vicuna, Alpaca, LLaMA, Koala and RWKV, and relies on the [langchain](https://github.com/langchain-ai/langchain) framework to expose services through an API based on [FastAPI](https://github.com/tiangolo/fastapi) or through a WebUI based on [Streamlit](https://github.com/streamlit/streamlit).

✅ Relying on the open-source LLM and Embedding models supported by this project, it can be deployed fully **offline and privately** using **open-source** models only. The project also supports calling the OpenAI GPT API, and integrations with more models and model APIs will keep being added.

⛓️ The implementation principle is shown in the figure below. The process is: load files -> read text -> split text -> vectorize the text -> vectorize the question -> match the `top k` text chunks most similar to the question vector -> add the matched text to the `prompt` as context together with the question -> submit to the `LLM` to generate the answer.

📺 [Video introducing the principle](https://www.bilibili.com/video/BV13M4y1e7cN/?share_source=copy_web&vd_source=e6c5aafe684f30fbe41925d61ca6d514)
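The pipeline above can be summarized in a few lines of code. The following is a minimal, illustrative sketch of the retrieval-augmented prompting loop; `embed`, `vector_store` and `llm` are hypothetical placeholders, not the project's actual classes:

```python
# Minimal RAG sketch: embed the question, retrieve top-k chunks, build a prompt, ask the LLM.
def rag_answer(question: str, vector_store, embed, llm, top_k: int = 3) -> str:
    query_vec = embed(question)                       # vectorize the question
    chunks = vector_store.search(query_vec, k=top_k)  # match the most similar top-k text chunks
    context = "\n\n".join(chunks)
    prompt = f"Known information:\n{context}\n\nAnswer the question based on the known information: {question}"
    return llm(prompt)                                # submit to the LLM to generate the answer
```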
@@ -43,7 +52,8 @@

🚩 This project does not involve fine-tuning or training, but fine-tuning or training can be used to improve its results.

🌐 The code used by `v13` of the [AutoDL image](https://www.codewithgpu.com/i/chatchat-space/Langchain-Chatchat/Langchain-Chatchat) has been updated to project version `v0.2.9`.

🐳 The [Docker image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.6) has been updated to version `0.2.7`.

@@ -53,7 +63,10 @@

docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.7
```

🧩 This project has a very complete [Wiki](https://github.com/chatchat-space/Langchain-Chatchat/wiki/). The README is only a brief introduction, __just an entry-level tutorial that gets a basic setup running__. If you want to understand the project in more depth, or would like to contribute to it, please visit the [Wiki](https://github.com/chatchat-space/Langchain-Chatchat/wiki/).

## Pain Points Addressed
@@ -63,17 +76,19 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.7

We support the mainstream local LLMs and Embedding models on the market, as well as open-source local vector databases.
See the [Wiki](https://github.com/chatchat-space/Langchain-Chatchat/wiki/) for the full list of supported models and databases.

## Quick Start

### 1. Environment Setup

+ First, make sure your machine has Python 3.8 - 3.11 installed (we strongly recommend Python 3.11).

```
$ python --version
Python 3.11.7
```

Next, create a virtual environment and install the project's dependencies inside it:

```shell

# Clone the repository
@@ -89,33 +104,44 @@ $ pip install -r requirements_webui.txt

# The default dependencies cover the basic runtime (FAISS vector store). To use other vector stores such as milvus/pg_vector, uncomment the corresponding dependencies in requirements.txt before installing.
```

Please note that the LangChain-Chatchat `0.2.x` series targets the Langchain `0.0.x` series. If you are using a Langchain `0.1.x` release, you need to downgrade your `Langchain` version.

### 2. Model Download

To run this project locally or in an offline environment, first download the required models. Open-source LLM and Embedding models can usually be downloaded from [HuggingFace](https://huggingface.co/models).

Take the default LLM model [THUDM/ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b) and the default Embedding model [BAAI/bge-large-zh](https://huggingface.co/BAAI/bge-large-zh) used by this project as an example.

To download the models, first [install Git LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage) and then run:

```Shell
$ git lfs install
$ git clone https://huggingface.co/THUDM/chatglm3-6b
$ git clone https://huggingface.co/BAAI/bge-large-zh
```

### 3. Initialize the Knowledge Base and Config Files

Initialize your own knowledge base and copy the example config files as follows:

```shell
$ python copy_config_example.py
$ python init_database.py --recreate-vs
```

### 4. One-Click Startup

Start the project with the following command:

```shell
$ python startup.py -a
```

### 5. Startup UI Examples

If startup succeeds, you will see the following UI.
@@ -134,19 +160,32 @@ $ python startup.py -a

[WebUI screenshot]

### Note

The steps above are only meant to get you started quickly. For more features and custom startup options, please refer to the [Wiki](https://github.com/chatchat-space/Langchain-Chatchat/wiki/).

---

## Project Milestones

+ `April 2023`: `Langchain-ChatGLM 0.1.0` released, supporting local knowledge-base question answering based on the ChatGLM-6B model.
+ `August 2023`: `Langchain-ChatGLM` renamed to `Langchain-Chatchat`; `0.2.0` released, using `fastchat` as the model-loading solution and supporting more models and databases.
+ `October 2023`: `Langchain-Chatchat 0.2.5` released, introducing the Agent features; the open-source project won third prize at the hackathon held by `Founder Park & Zhipu AI & Zilliz`.
+ `December 2023`: the `Langchain-Chatchat` open-source project passed **20K** stars.
+ `January 2024`: `LangChain 0.1.x` was released; after `Langchain-Chatchat 0.2.x` ships the stable release `0.2.10`, updates and technical support will stop, and all effort will go into developing the more capable `Langchain-Chatchat 0.3.x`.

+ 🔥 Let's look forward to the future stories of Chatchat together ···

---

## Contact Us

### Telegram

[Telegram](https://t.me/+RjliQ3jnJ1YyN2E9)

### Project Discussion Group
@@ -158,4 +197,4 @@ $ python startup.py -a

<img src="img/official_wechat_mp_account.png" alt="QR code" width="300" />

🎉 Official WeChat public account of the Langchain-Chatchat project. Scan the QR code to follow.

README_en.md (38 lines changed)
@@ -8,6 +8,10 @@

An LLM application that aims to implement knowledge-base and search-engine based QA on top of Langchain and open-source or remote LLM APIs.

⚠️ `0.2.10` will be the last release of the `0.2.x` series. The `0.2.x` series will stop receiving updates and technical support, and all effort will go into developing the more capable `Langchain-Chatchat 0.3.x`.

---

## Table of Contents

@@ -25,7 +29,8 @@ LLM API.

## Introduction

🤖️ A Q&A application based on a local knowledge base, implemented using the ideas of [langchain](https://github.com/langchain-ai/langchain). The goal is to build a KBQA (knowledge-based Q&A) solution that is friendly to Chinese scenarios and open-source models and can run both offline and online.

💡 Inspired by [document.ai](https://github.com/GanymedeNil/document.ai)

@@ -56,10 +61,9 @@ The main process analysis from the aspect of document process:

🚩 Training and fine-tuning are not involved in this project, but one can still improve its performance by doing them.

🌐 The [AutoDL image](https://www.codewithgpu.com/i/chatchat-space/Langchain-Chatchat/Langchain-Chatchat) is supported; in v13 the code has been updated to v0.2.9.

🐳 The [Docker image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.7) has been updated to 0.2.7.

## Pain Points Addressed

@@ -99,7 +103,9 @@ $ pip install -r requirements_webui.txt

# The default dependencies cover the basic runtime (FAISS vector store). To use other vector stores such as milvus/pg_vector, uncomment the corresponding dependencies in requirements.txt before installing.
```

Please note that the LangChain-Chatchat `0.2.x` series targets the Langchain `0.0.x` series. If you are using a Langchain `0.1.x` release, you need to downgrade.

### Model Download

@@ -159,15 +165,33 @@ please refer to the [Wiki](https://github.com/chatchat-space/Langchain-Chatchat/

---

## Project Milestones

+ `April 2023`: `Langchain-ChatGLM 0.1.0` released, supporting local knowledge-base question answering based on the ChatGLM-6B model.
+ `August 2023`: `Langchain-ChatGLM` was renamed to `Langchain-Chatchat`; `0.2.0` was released, using `fastchat` as the model-loading solution and supporting more models and databases.
+ `October 2023`: `Langchain-Chatchat 0.2.5` was released, the Agent features were launched, and the open-source project won third prize at the hackathon held by `Founder Park & Zhipu AI & Zilliz`.
+ `December 2023`: the `Langchain-Chatchat` open-source project received more than **20K** stars.
+ `January 2024`: `LangChain 0.1.x` was launched; after the stable `Langchain-Chatchat 0.2.10` release, updates and technical support will stop, and all efforts will go into developing the more capable `Langchain-Chatchat 0.3.x`.

+ 🔥 Let's look forward to the future Chatchat stories together ···

---

## Contact Us

### Telegram

[Telegram](https://t.me/+RjliQ3jnJ1YyN2E9)

### WeChat Group

<img src="img/qr_code_87.jpg" alt="QR code" width="300" height="300" />

### WeChat Official Account
@@ -5,4 +5,4 @@ from .server_config import *
from .prompt_config import *


-VERSION = "v0.2.9"
+VERSION = "v0.2.10"
@@ -21,10 +21,9 @@ OVERLAP_SIZE = 50
# Number of matched vectors from the knowledge base
VECTOR_SEARCH_TOP_K = 3

-# Distance threshold for knowledge-base matching, in the range 0-1. The smaller the SCORE, the smaller the distance and the higher the relevance.
-# A value of 1 effectively disables filtering. Measured distance scores for bge-large mostly fall between 0.01 and 0.7,
-# and highly similar texts score at most around 0.55, so a threshold of 0.6 is recommended for bge.
-SCORE_THRESHOLD = 0.6
+# Distance threshold for knowledge-base matching, usually in the range 0-1. The smaller the SCORE, the smaller the distance and the higher the relevance.
+# Some users have reported matching scores above 1, so the default is set to 1 for compatibility; the adjustable range in the WebUI is 0-2.
+SCORE_THRESHOLD = 1.0

# Default search engine. Options: bing, duckduckgo, metaphor
DEFAULT_SEARCH_ENGINE = "duckduckgo"
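To make the threshold semantics above concrete, here is a small sketch (illustrative only, not the project's retrieval code) of how a distance threshold plus the `top k` limit would filter search hits; `docs_with_scores` is a hypothetical list of (document, distance) pairs returned by a vector store:

```python
# Keep only hits whose distance is at or below the threshold, then take the k closest.
# Smaller distance means higher relevance; a threshold of 1.0 keeps almost every hit.
def filter_hits(docs_with_scores, threshold=1.0, k=3):
    kept = [(doc, score) for doc, score in docs_with_scores if score <= threshold]
    return sorted(kept, key=lambda pair: pair[1])[:k]
```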
@ -49,12 +48,17 @@ BING_SUBSCRIPTION_KEY = ""
|
|||||||
# metaphor搜索需要KEY
|
# metaphor搜索需要KEY
|
||||||
METAPHOR_API_KEY = ""
|
METAPHOR_API_KEY = ""
|
||||||
|
|
||||||
|
# 心知天气 API KEY,用于天气Agent。申请:https://www.seniverse.com/
|
||||||
|
SENIVERSE_API_KEY = ""
|
||||||
|
|
||||||
# 是否开启中文标题加强,以及标题增强的相关配置
|
# 是否开启中文标题加强,以及标题增强的相关配置
|
||||||
# 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记;
|
# 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记;
|
||||||
# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
|
# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
|
||||||
ZH_TITLE_ENHANCE = False
|
ZH_TITLE_ENHANCE = False
|
||||||
|
|
||||||
|
# PDF OCR 控制:只对宽高超过页面一定比例(图片宽/页面宽,图片高/页面高)的图片进行 OCR。
|
||||||
|
# 这样可以避免 PDF 中一些小图片的干扰,提高非扫描版 PDF 处理速度
|
||||||
|
PDF_OCR_THRESHOLD = (0.6, 0.6)
|
||||||
|
|
||||||
# 每个知识库的初始化介绍,用于在初始化知识库时显示和Agent调用,没写则没有介绍,不会被Agent调用。
|
# 每个知识库的初始化介绍,用于在初始化知识库时显示和Agent调用,没写则没有介绍,不会被Agent调用。
|
||||||
KB_INFO = {
|
KB_INFO = {
|
||||||
|
|||||||
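As an illustration of the ZH_TITLE_ENHANCE idea described above (a rough sketch only, not the project's actual implementation), chunks detected as titles are flagged in the metadata and prepended to the body chunks that follow them:

```python
# Sketch: mark title chunks in metadata and prefix later chunks with the nearest preceding title.
# `chunks` is a hypothetical list of objects with .page_content and .metadata;
# `is_title` is a hypothetical predicate that decides whether a piece of text looks like a heading.
def enhance_with_titles(chunks, is_title):
    last_title = None
    for chunk in chunks:
        if is_title(chunk.page_content):
            chunk.metadata["is_title"] = True        # hypothetical metadata flag
            last_title = chunk.page_content
        elif last_title:
            # concatenate the body text with the title one level above it
            chunk.page_content = f"{last_title}\n{chunk.page_content}"
    return chunks
```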
@@ -6,9 +6,9 @@ import os
MODEL_ROOT_PATH = ""

# Selected Embedding model name
-EMBEDDING_MODEL = "bge-large-zh"
+EMBEDDING_MODEL = "bge-large-zh-v1.5"

-# Device to run the Embedding model on. "auto" detects it automatically; it can also be set to "cuda", "mps" or "cpu".
+# Device to run the Embedding model on. "auto" detects it automatically (with a warning); it can also be set to "cuda", "mps", "cpu" or "xpu".
EMBEDDING_DEVICE = "auto"

# Selected reranker model
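For reference, an `"auto"` device setting typically resolves to something like the sketch below (illustrative only; the project's actual detection logic lives elsewhere and also handles `xpu`):

```python
import torch

def detect_device(configured: str = "auto") -> str:
    # Resolve "auto" to a concrete device; any explicit setting is returned unchanged.
    if configured != "auto":
        return configured
    if torch.cuda.is_available():
        return "cuda"
    if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```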
@@ -26,44 +26,33 @@ EMBEDDING_MODEL_OUTPUT_PATH = "output"
# Here we use the two currently mainstream offline models; chatglm3-6b is the default model to load.
# If you do not have enough GPU memory, you can use Qwen-1_8B-Chat, which needs only 3.8 GB of memory in FP16.

-# chatglm3-6b may output the role tag <|user|> and answer its own questions; see the project wiki -> FAQ -> Q20.
-LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"]  # "Qwen-1_8B-Chat",
+LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"]

-# Name of the AgentLM model (optional; if set, it pins the model used by the Chain after entering the Agent; if unset, LLM_MODELS[0] is used)
Agent_MODEL = None

-# Device to run the LLM on. "auto" detects it automatically; it can also be set to "cuda", "mps" or "cpu".
+# Device to run the LLM on. "auto" detects it automatically (with a warning); it can also be set to "cuda", "mps", "cpu" or "xpu".
LLM_DEVICE = "auto"

-# Number of history dialogue turns
HISTORY_LEN = 3

-# Maximum length supported by the model; if unset, the model's default maximum length is used; if set, it is the user-defined maximum length.
-MAX_TOKENS = None
+MAX_TOKENS = 2048

-# General LLM dialogue parameters
TEMPERATURE = 0.7
-# TOP_P = 0.95  # not yet supported by ChatOpenAI

ONLINE_LLM_MODEL = {
-    # Online models. Set a different port for each online API in server_config.
    "openai-api": {
-        "model_name": "gpt-3.5-turbo",
+        "model_name": "gpt-4",
        "api_base_url": "https://api.openai.com/v1",
        "api_key": "",
        "openai_proxy": "",
    },

-    # Register and get an api key at http://open.bigmodel.cn
+    # Zhipu AI API; register and get an api key at http://open.bigmodel.cn
    "zhipu-api": {
        "api_key": "",
-        "version": "chatglm_turbo",  # options include "chatglm_turbo"
+        "version": "glm-4",
        "provider": "ChatGLMWorker",
    },

    # Register and get an api key at https://api.minimax.chat/
    "minimax-api": {
        "group_id": "",

@@ -72,13 +61,12 @@ ONLINE_LLM_MODEL = {
        "provider": "MiniMaxWorker",
    },

    # Register and get an api key at https://xinghuo.xfyun.cn/
    "xinghuo-api": {
        "APPID": "",
        "APISecret": "",
        "api_key": "",
-        "version": "v1.5",  # iFlytek Spark model version in use; options include "v3.0", "v1.5", "v2.0"
+        "version": "v3.0",  # iFlytek Spark model version in use; options include "v3.0", "v2.0", "v1.5"
        "provider": "XingHuoWorker",
    },

@@ -93,8 +81,8 @@ ONLINE_LLM_MODEL = {

    # Volcano Ark API; see the docs at https://www.volcengine.com/docs/82379
    "fangzhou-api": {
-        "version": "chatglm-6b-model",  # currently supports "chatglm-6b-model"; see the Ark section of the supported-model list in the docs for more.
-        "version_url": "",  # instead of a version, you can directly fill in the API URL of the model deployment created in Ark
+        "version": "chatglm-6b-model",
+        "version_url": "",
        "api_key": "",
        "secret_key": "",
        "provider": "FangZhouWorker",

@@ -102,15 +90,15 @@ ONLINE_LLM_MODEL = {

    # Alibaba Cloud Tongyi Qianwen API; see the docs at https://help.aliyun.com/zh/dashscope/developer-reference/api-details
    "qwen-api": {
-        "version": "qwen-turbo",  # options include "qwen-turbo", "qwen-plus"
-        "api_key": "",  # create one on the DashScope API-KEY management page in the Alibaba Cloud console
+        "version": "qwen-max",
+        "api_key": "",
        "provider": "QwenWorker",
        "embed_model": "text-embedding-v1"  # embedding model name
    },

    # Baichuan API; see https://www.baichuan-ai.com/home#api-enter for how to apply
    "baichuan-api": {
-        "version": "Baichuan2-53B",  # currently supports "Baichuan2-53B"; see the official docs.
+        "version": "Baichuan2-53B",
        "api_key": "",
        "secret_key": "",
        "provider": "BaiChuanWorker",

@@ -132,6 +120,11 @@ ONLINE_LLM_MODEL = {
        "secret_key": "",
        "provider": "TianGongWorker",
    },
+    # Gemini API https://makersuite.google.com/app/apikey
+    "gemini-api": {
+        "api_key": "",
+        "provider": "GeminiWorker",
+    }

}
@@ -143,6 +136,7 @@
# - GanymedeNil/text2vec-large-chinese
# - text2vec-large-chinese
# 2.2 If the local paths above do not exist, the huggingface model is used instead.

MODEL_PATH = {
    "embed_model": {
        "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",

@@ -169,55 +163,59 @@ MODEL_PATH = {
    },

    "llm_model": {
-       # Some of the models below have not been fully tested; support is inferred from the fastchat and vllm model lists.
        "chatglm2-6b": "THUDM/chatglm2-6b",
        "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",

        "chatglm3-6b": "THUDM/chatglm3-6b",
        "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
-       "chatglm3-6b-base": "THUDM/chatglm3-6b-base",

-       "Qwen-1_8B": "Qwen/Qwen-1_8B",
+       "Orion-14B-Chat": "OrionStarAI/Orion-14B-Chat",
+       "Orion-14B-Chat-Plugin": "OrionStarAI/Orion-14B-Chat-Plugin",
+       "Orion-14B-LongChat": "OrionStarAI/Orion-14B-LongChat",
+
+       "Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
+       "Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
+       "Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",

        "Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
-       "Qwen-1_8B-Chat-Int8": "Qwen/Qwen-1_8B-Chat-Int8",
-       "Qwen-1_8B-Chat-Int4": "Qwen/Qwen-1_8B-Chat-Int4",
-       "Qwen-7B": "Qwen/Qwen-7B",
        "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
-       "Qwen-14B": "Qwen/Qwen-14B",
        "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
-       "Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8",
-       # Under newer transformers versions you must manually edit the model's config.json and add
-       # `disable_exllama: true` to the quantization_config dict to start the quantized qwen models.
-       "Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4",
-       "Qwen-72B": "Qwen/Qwen-72B",
        "Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
-       "Qwen-72B-Chat-Int8": "Qwen/Qwen-72B-Chat-Int8",
-       "Qwen-72B-Chat-Int4": "Qwen/Qwen-72B-Chat-Int4",

-       "baichuan2-13b": "baichuan-inc/Baichuan2-13B-Chat",
-       "baichuan2-7b": "baichuan-inc/Baichuan2-7B-Chat",
-       "baichuan-7b": "baichuan-inc/Baichuan-7B",
-       "baichuan-13b": "baichuan-inc/Baichuan-13B",
+       "baichuan-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
        "baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
+       "baichuan2-7b-chat": "baichuan-inc/Baichuan2-7B-Chat",
+       "baichuan2-13b-chat": "baichuan-inc/Baichuan2-13B-Chat",

-       "aquila-7b": "BAAI/Aquila-7B",
-       "aquilachat-7b": "BAAI/AquilaChat-7B",
        "internlm-7b": "internlm/internlm-7b",
        "internlm-chat-7b": "internlm/internlm-chat-7b",
+       "internlm2-chat-7b": "internlm/internlm2-chat-7b",
+       "internlm2-chat-20b": "internlm/internlm2-chat-20b",
+
+       "BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat",
+       "BlueLM-7B-Chat-32k": "vivo-ai/BlueLM-7B-Chat-32k",
+
+       "Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat",
+
+       "agentlm-7b": "THUDM/agentlm-7b",
+       "agentlm-13b": "THUDM/agentlm-13b",
+       "agentlm-70b": "THUDM/agentlm-70b",

        "falcon-7b": "tiiuae/falcon-7b",
        "falcon-40b": "tiiuae/falcon-40b",
        "falcon-rw-7b": "tiiuae/falcon-rw-7b",
+
+       "aquila-7b": "BAAI/Aquila-7B",
+       "aquilachat-7b": "BAAI/AquilaChat-7B",
+       "open_llama_13b": "openlm-research/open_llama_13b",
+       "vicuna-13b-v1.5": "lmsys/vicuna-13b-v1.5",
+       "koala": "young-geng/koala",
+       "mpt-7b": "mosaicml/mpt-7b",
+       "mpt-7b-storywriter": "mosaicml/mpt-7b-storywriter",
+       "mpt-30b": "mosaicml/mpt-30b",
+       "opt-66b": "facebook/opt-66b",
+       "opt-iml-max-30b": "facebook/opt-iml-max-30b",

        "gpt2": "gpt2",
        "gpt2-xl": "gpt2-xl",

        "gpt-j-6b": "EleutherAI/gpt-j-6b",
        "gpt4all-j": "nomic-ai/gpt4all-j",
        "gpt-neox-20b": "EleutherAI/gpt-neox-20b",

@@ -225,63 +223,51 @@ MODEL_PATH = {
        "oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
        "dolly-v2-12b": "databricks/dolly-v2-12b",
        "stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b",
-       "Llama-2-13b-hf": "meta-llama/Llama-2-13b-hf",
-       "Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
-       "open_llama_13b": "openlm-research/open_llama_13b",
-       "vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3",
-       "koala": "young-geng/koala",
-       "mpt-7b": "mosaicml/mpt-7b",
-       "mpt-7b-storywriter": "mosaicml/mpt-7b-storywriter",
-       "mpt-30b": "mosaicml/mpt-30b",
-       "opt-66b": "facebook/opt-66b",
-       "opt-iml-max-30b": "facebook/opt-iml-max-30b",
-       "agentlm-7b": "THUDM/agentlm-7b",
-       "agentlm-13b": "THUDM/agentlm-13b",
-       "agentlm-70b": "THUDM/agentlm-70b",
-       "Yi-34B-Chat": "01-ai/Yi-34B-Chat",
    },

-   "reranker":{
-       "bge-reranker-large":"BAAI/bge-reranker-large",
-       "bge-reranker-base":"BAAI/bge-reranker-base",
-       # TODO: add online rerankers such as cohere
-   }
+   "reranker": {
+       "bge-reranker-large": "BAAI/bge-reranker-large",
+       "bge-reranker-base": "BAAI/bge-reranker-base",
+   }
}

# The following usually does not need to be changed.

# nltk model storage path
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")

+# Using vLLM may reduce model reasoning ability and make Agent tasks fail.
VLLM_MODEL_DICT = {
-   "aquila-7b": "BAAI/Aquila-7B",
-   "aquilachat-7b": "BAAI/AquilaChat-7B",
-   "baichuan-7b": "baichuan-inc/Baichuan-7B",
-   "baichuan-13b": "baichuan-inc/Baichuan-13B",
-   "baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
    "chatglm2-6b": "THUDM/chatglm2-6b",
    "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
    "chatglm3-6b": "THUDM/chatglm3-6b",
    "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
+
+   "Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
+   "Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
+   "Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",
+
+   "Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
+   "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
+   "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
+   "Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
+
+   "baichuan-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
+   "baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
+   "baichuan2-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
+   "baichuan2-13b-chat": "baichuan-inc/Baichuan-13B-Chat",

    "BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat",
    "BlueLM-7B-Chat-32k": "vivo-ai/BlueLM-7B-Chat-32k",

-   # Note: the tokenizer and model of the bloom family are separate, so although vllm supports them, they are incompatible with the fschat framework.
-   # "bloom": "bigscience/bloom",
-   # "bloomz": "bigscience/bloomz",
-   # "bloomz-560m": "bigscience/bloomz-560m",
-   # "bloomz-7b1": "bigscience/bloomz-7b1",
-   # "bloomz-1b7": "bigscience/bloomz-1b7",

    "internlm-7b": "internlm/internlm-7b",
    "internlm-chat-7b": "internlm/internlm-chat-7b",
+   "internlm2-chat-7b": "internlm/Models/internlm2-chat-7b",
+   "internlm2-chat-20b": "internlm/Models/internlm2-chat-20b",
+
+   "aquila-7b": "BAAI/Aquila-7B",
+   "aquilachat-7b": "BAAI/AquilaChat-7B",

    "falcon-7b": "tiiuae/falcon-7b",
    "falcon-40b": "tiiuae/falcon-40b",
    "falcon-rw-7b": "tiiuae/falcon-rw-7b",

@@ -294,8 +280,6 @@ VLLM_MODEL_DICT = {
    "oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
    "dolly-v2-12b": "databricks/dolly-v2-12b",
    "stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b",
-   "Llama-2-13b-hf": "meta-llama/Llama-2-13b-hf",
-   "Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
    "open_llama_13b": "openlm-research/open_llama_13b",
    "vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3",
    "koala": "young-geng/koala",

@@ -305,37 +289,14 @@ VLLM_MODEL_DICT = {
    "opt-66b": "facebook/opt-66b",
    "opt-iml-max-30b": "facebook/opt-iml-max-30b",

-   "Qwen-1_8B": "Qwen/Qwen-1_8B",
-   "Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
-   "Qwen-1_8B-Chat-Int8": "Qwen/Qwen-1_8B-Chat-Int8",
-   "Qwen-1_8B-Chat-Int4": "Qwen/Qwen-1_8B-Chat-Int4",
-   "Qwen-7B": "Qwen/Qwen-7B",
-   "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
-   "Qwen-14B": "Qwen/Qwen-14B",
-   "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
-   "Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8",
-   "Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4",
-   "Qwen-72B": "Qwen/Qwen-72B",
-   "Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
-   "Qwen-72B-Chat-Int8": "Qwen/Qwen-72B-Chat-Int8",
-   "Qwen-72B-Chat-Int4": "Qwen/Qwen-72B-Chat-Int4",
-
-   "agentlm-7b": "THUDM/agentlm-7b",
-   "agentlm-13b": "THUDM/agentlm-13b",
-   "agentlm-70b": "THUDM/agentlm-70b",
}

-# Models you consider Agent-capable can be added here; once added, no warning is shown in the WebUI.
-# In our tests, only the following models natively support Agent use.
SUPPORT_AGENT_MODEL = [
-   "azure-api",
-   "openai-api",
-   "qwen-api",
-   "Qwen",
-   "chatglm3",
-   "xinghuo-api",
+   "openai-api",            # GPT-4 models
+   "qwen-api",              # Qwen Max models
+   "zhipu-api",             # Zhipu AI GLM-4 models
+   "Qwen",                  # all local Qwen-series models
+   "chatglm3-6b",
+   "internlm2-chat-20b",
+   "Orion-14B-Chat-Plugin",
]
@@ -40,8 +40,6 @@ FSCHAT_MODEL_WORKERS = {
        "device": LLM_DEVICE,
        # False or 'vllm': the inference acceleration framework to use. If vllm runs into HuggingFace communication problems, see doc/FAQ.
        # vllm support for some models is still immature, so it is disabled by default.
-       # fschat==0.2.33 has a bug; to use vllm, edit fastchat.server.vllm_worker in the source,
-       # changing the stop=list(stop) argument of sampling_params = SamplingParams on line 103 to stop=[i for i in stop if i != ""].
        "infer_turbo": False,

        # Parameters required for multi-GPU loading in model_worker

@@ -92,11 +90,10 @@ FSCHAT_MODEL_WORKERS = {
        # 'disable_log_requests': False
    },
-   # The default configuration can be overridden as in the following example:
-   # "Qwen-1_8B-Chat": {  # uses the IP and port from "default"
-   #     "device": "cpu",
-   # },
-   "chatglm3-6b": {  # uses the IP and port from "default"
+   "Qwen-1_8B-Chat": {
+       "device": "cpu",
+   },
+   "chatglm3-6b": {
        "device": "cuda",
    },

@@ -128,14 +125,11 @@ FSCHAT_MODEL_WORKERS = {
    "tiangong-api": {
        "port": 21009,
    },
+   "gemini-api": {
+       "port": 21010,
+   },
}

-# fastchat multi model worker server
-FSCHAT_MULTI_MODEL_WORKERS = {
-   # TODO:
-}
-
-# fastchat controller server
FSCHAT_CONTROLLER = {
    "host": DEFAULT_BIND_HOST,
    "port": 20001,
@@ -1,2 +1,4 @@
from .mypdfloader import RapidOCRPDFLoader
from .myimgloader import RapidOCRLoader
+from .mydocloader import RapidOCRDocLoader
+from .mypptloader import RapidOCRPPTLoader
document_loaders/mydocloader.py (new file, 71 lines)

@@ -0,0 +1,71 @@
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from typing import List
import tqdm


class RapidOCRDocLoader(UnstructuredFileLoader):
    def _get_elements(self) -> List:
        def doc2text(filepath):
            from docx.table import _Cell, Table
            from docx.oxml.table import CT_Tbl
            from docx.oxml.text.paragraph import CT_P
            from docx.text.paragraph import Paragraph
            from docx import Document, ImagePart
            from PIL import Image
            from io import BytesIO
            import numpy as np
            from rapidocr_onnxruntime import RapidOCR
            ocr = RapidOCR()
            doc = Document(filepath)
            resp = ""

            def iter_block_items(parent):
                from docx.document import Document
                if isinstance(parent, Document):
                    parent_elm = parent.element.body
                elif isinstance(parent, _Cell):
                    parent_elm = parent._tc
                else:
                    raise ValueError("RapidOCRDocLoader parse fail")

                for child in parent_elm.iterchildren():
                    if isinstance(child, CT_P):
                        yield Paragraph(child, parent)
                    elif isinstance(child, CT_Tbl):
                        yield Table(child, parent)

            b_unit = tqdm.tqdm(total=len(doc.paragraphs)+len(doc.tables),
                               desc="RapidOCRDocLoader block index: 0")
            for i, block in enumerate(iter_block_items(doc)):
                b_unit.set_description(
                    "RapidOCRDocLoader block index: {}".format(i))
                b_unit.refresh()
                if isinstance(block, Paragraph):
                    resp += block.text.strip() + "\n"
                    images = block._element.xpath('.//pic:pic')  # all images in the paragraph
                    for image in images:
                        for img_id in image.xpath('.//a:blip/@r:embed'):  # image id
                            part = doc.part.related_parts[img_id]  # look up the image part by id
                            if isinstance(part, ImagePart):
                                image = Image.open(BytesIO(part._blob))
                                result, _ = ocr(np.array(image))
                                if result:
                                    ocr_result = [line[1] for line in result]
                                    resp += "\n".join(ocr_result)
                elif isinstance(block, Table):
                    for row in block.rows:
                        for cell in row.cells:
                            for paragraph in cell.paragraphs:
                                resp += paragraph.text.strip() + "\n"
                b_unit.update(1)
            return resp

        text = doc2text(self.file_path)
        from unstructured.partition.text import partition_text
        return partition_text(text=text, **self.unstructured_kwargs)


if __name__ == '__main__':
    loader = RapidOCRDocLoader(file_path="../tests/samples/ocr_test.docx")
    docs = loader.load()
    print(docs)
@@ -1,5 +1,6 @@
from typing import List
from langchain.document_loaders.unstructured import UnstructuredFileLoader
+from configs import PDF_OCR_THRESHOLD
from document_loaders.ocr import get_ocr
import tqdm

@@ -15,23 +16,25 @@ class RapidOCRPDFLoader(UnstructuredFileLoader):
            b_unit = tqdm.tqdm(total=doc.page_count, desc="RapidOCRPDFLoader context page index: 0")
            for i, page in enumerate(doc):
-               # update the description
                b_unit.set_description("RapidOCRPDFLoader context page index: {}".format(i))
-               # show the progress bar update immediately
                b_unit.refresh()
-               # TODO: adjust the processing order according to the order of text and images
                text = page.get_text("")
                resp += text + "\n"

-               img_list = page.get_images()
+               img_list = page.get_image_info(xrefs=True)
                for img in img_list:
-                   pix = fitz.Pixmap(doc, img[0])
-                   img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1)
-                   result, _ = ocr(img_array)
-                   if result:
-                       ocr_result = [line[1] for line in result]
-                       resp += "\n".join(ocr_result)
+                   if xref := img.get("xref"):
+                       bbox = img["bbox"]
+                       # check whether the image size exceeds the configured threshold
+                       if ((bbox[2] - bbox[0]) / (page.rect.width) < PDF_OCR_THRESHOLD[0]
+                               or (bbox[3] - bbox[1]) / (page.rect.height) < PDF_OCR_THRESHOLD[1]):
+                           continue
+                       pix = fitz.Pixmap(doc, xref)
+                       img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1)
+                       result, _ = ocr(img_array)
+                       if result:
+                           ocr_result = [line[1] for line in result]
+                           resp += "\n".join(ocr_result)

                # update progress
                b_unit.update(1)
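A hedged usage sketch for the loader changed above, mirroring the `__main__` blocks of the other loaders in this commit: with `PDF_OCR_THRESHOLD` in effect, only images that are large relative to the page are OCRed. The sample file path is a placeholder, not a file shipped with the project:

```python
# Illustrative usage; the sample PDF path is hypothetical.
from document_loaders.mypdfloader import RapidOCRPDFLoader

loader = RapidOCRPDFLoader(file_path="../tests/samples/ocr_test.pdf")
docs = loader.load()   # page text plus OCR of sufficiently large images
print(docs)
```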
document_loaders/mypptloader.py (new file, 59 lines)

@@ -0,0 +1,59 @@
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from typing import List
import tqdm


class RapidOCRPPTLoader(UnstructuredFileLoader):
    def _get_elements(self) -> List:
        def ppt2text(filepath):
            from pptx import Presentation
            from PIL import Image
            import numpy as np
            from io import BytesIO
            from rapidocr_onnxruntime import RapidOCR
            ocr = RapidOCR()
            prs = Presentation(filepath)
            resp = ""

            def extract_text(shape):
                nonlocal resp
                if shape.has_text_frame:
                    resp += shape.text.strip() + "\n"
                if shape.has_table:
                    for row in shape.table.rows:
                        for cell in row.cells:
                            for paragraph in cell.text_frame.paragraphs:
                                resp += paragraph.text.strip() + "\n"
                if shape.shape_type == 13:  # 13 means picture
                    image = Image.open(BytesIO(shape.image.blob))
                    result, _ = ocr(np.array(image))
                    if result:
                        ocr_result = [line[1] for line in result]
                        resp += "\n".join(ocr_result)
                elif shape.shape_type == 6:  # 6 means group shape
                    for child_shape in shape.shapes:
                        extract_text(child_shape)

            b_unit = tqdm.tqdm(total=len(prs.slides),
                               desc="RapidOCRPPTLoader slide index: 1")
            # iterate over all slides
            for slide_number, slide in enumerate(prs.slides, start=1):
                b_unit.set_description(
                    "RapidOCRPPTLoader slide index: {}".format(slide_number))
                b_unit.refresh()
                sorted_shapes = sorted(slide.shapes,
                                       key=lambda x: (x.top, x.left))  # top-to-bottom, left-to-right
                for shape in sorted_shapes:
                    extract_text(shape)
                b_unit.update(1)
            return resp

        text = ppt2text(self.file_path)
        from unstructured.partition.text import partition_text
        return partition_text(text=text, **self.unstructured_kwargs)


if __name__ == '__main__':
    loader = RapidOCRPPTLoader(file_path="../tests/samples/ocr_test.pptx")
    docs = loader.load()
    print(docs)
@@ -7,31 +7,35 @@
The merged model is saved under the directory of the original embedding model, named <original model name>+Merge_Keywords_<timestamp>.
'''
import sys

sys.path.append("..")
+import os
+import torch

from datetime import datetime
from configs import (
    MODEL_PATH,
    EMBEDDING_MODEL,
    EMBEDDING_KEYWORD_FILE,
)
-import os
-import torch
from safetensors.torch import save_model
from sentence_transformers import SentenceTransformer
+from langchain_core._api import deprecated


+@deprecated(
+    since="0.3.0",
+    message="自定义关键词 Langchain-Chatchat 0.3.x 重写, 0.2.x中相关功能将废弃",
+    removal="0.3.0"
+)
def get_keyword_embedding(bert_model, tokenizer, key_words):
    tokenizer_output = tokenizer(key_words, return_tensors="pt", padding=True, truncation=True)
-   # No need to manually convert to tensor as we've set return_tensors="pt"
    input_ids = tokenizer_output['input_ids']
-   # Remove the first and last token for each sequence in the batch
    input_ids = input_ids[:, 1:-1]

    keyword_embedding = bert_model.embeddings.word_embeddings(input_ids)
    keyword_embedding = torch.mean(keyword_embedding, 1)

    return keyword_embedding

@@ -47,14 +51,11 @@ def add_keyword_to_model(model_name=EMBEDDING_MODEL, keyword_file: str = "", out
    bert_model = word_embedding_model.auto_model
    tokenizer = word_embedding_model.tokenizer
    key_words_embedding = get_keyword_embedding(bert_model, tokenizer, key_words)
-   # key_words_embedding = st_model.encode(key_words)

    embedding_weight = bert_model.embeddings.word_embeddings.weight
    embedding_weight_len = len(embedding_weight)
    tokenizer.add_tokens(key_words)
    bert_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32)

-   # key_words_embedding_tensor = torch.from_numpy(key_words_embedding)
    embedding_weight = bert_model.embeddings.word_embeddings.weight
    with torch.no_grad():
        embedding_weight[embedding_weight_len:embedding_weight_len + key_words_len, :] = key_words_embedding

@@ -76,46 +77,3 @@ def add_keyword_to_embedding_model(path: str = EMBEDDING_KEYWORD_FILE):
    output_model_name = "{}_Merge_Keywords_{}".format(EMBEDDING_MODEL, current_time)
    output_model_path = os.path.join(model_parent_directory, output_model_name)
    add_keyword_to_model(model_name, keyword_file, output_model_path)
-
-
-if __name__ == '__main__':
-    add_keyword_to_embedding_model(EMBEDDING_KEYWORD_FILE)
-
-    # input_model_name = ""
-    # output_model_path = ""
-    # # comparison of tokenizer test cases before and after adding keywords
-    # def print_token_ids(output, tokenizer, sentences):
-    #     for idx, ids in enumerate(output['input_ids']):
-    #         print(f'sentence={sentences[idx]}')
-    #         print(f'ids={ids}')
-    #         for id in ids:
-    #             decoded_id = tokenizer.decode(id)
-    #             print(f'    {decoded_id}->{id}')
-    #
-    # sentences = [
-    #     '数据科学与大数据技术',
-    #     'Langchain-Chatchat'
-    # ]
-    #
-    # st_no_keywords = SentenceTransformer(input_model_name)
-    # tokenizer_without_keywords = st_no_keywords.tokenizer
-    # print("===== tokenizer with no keywords added =====")
-    # output = tokenizer_without_keywords(sentences)
-    # print_token_ids(output, tokenizer_without_keywords, sentences)
-    # print(f'-------- embedding with no keywords added -----')
-    # embeddings = st_no_keywords.encode(sentences)
-    # print(embeddings)
-    #
-    # print("--------------------------------------------")
-    #
-    # st_with_keywords = SentenceTransformer(output_model_path)
-    # tokenizer_with_keywords = st_with_keywords.tokenizer
-    # print("===== tokenizer with keyword added =====")
-    # output = tokenizer_with_keywords(sentences)
-    # print_token_ids(output, tokenizer_with_keywords, sentences)
-    #
-    # print(f'-------- embedding with keywords added -----')
-    # embeddings = st_with_keywords.encode(sentences)
-    # print(embeddings)
Binary file not shown (Before: 272 KiB).
Binary file not shown (Before: 218 KiB, After: 318 KiB).

@@ -1 +1 @@
-Subproject commit 2f24adb218f23eab00d7fcd7ccf5072f2f35cb3c
+Subproject commit 28f664aa08f8191a70339c9ecbe7a89b35a1032a
requirements.txt (113 lines changed)

@@ -1,77 +1,66 @@
torch==2.1.2
torchvision==0.16.2
torchaudio==2.1.2
xformers==0.0.23.post1
transformers==4.36.2
sentence_transformers==2.2.2

langchain==0.0.354
langchain-experimental==0.0.47
pydantic==1.10.13
fschat==0.2.35
openai==1.9.0
fastapi==0.109.0
sse_starlette==1.8.2
nltk==3.8.1
uvicorn==0.24.0.post1
starlette==0.35.0
unstructured[all-docs] # ==0.11.8
python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.25
faiss-cpu==1.7.4
accelerate==0.24.1
spacy==3.7.2
PyMuPDF==1.23.16
rapidocr_onnxruntime==1.3.8
requests==2.31.0
pathlib==1.0.1
pytest==7.4.3
numexpr==2.8.6
strsimpy==0.2.1
markdownify==0.11.6
tiktoken==0.5.2
tqdm==4.66.1
websockets==12.0
numpy==1.24.4
pandas==2.0.3
einops==0.7.0
transformers_stream_generator==0.0.4
vllm==0.2.7; sys_platform == "linux"
llama-index==0.9.35

-zhipuai==1.0.7 # zhipu
#jq==1.6.0
# beautifulsoup4==4.12.2
# pysrt==1.1.2
# dashscope==1.13.6 # qwen
# volcengine==1.0.119 # fangzhou
# uncomment libs if you want to use corresponding vector store
# pymilvus==2.3.4
# psycopg2==2.9.9
# pgvector==0.2.4
#flash-attn==2.4.2 # For Orion-14B-Chat and Qwen-14B-Chat
#autoawq==0.1.8 # For Int4
#rapidocr_paddle[gpu]==1.3.11 # gpu acceleration for ocr of pdf and image files

arxiv==2.1.0
youtube-search==2.1.2
duckduckgo-search==3.9.9
metaphor-python==0.1.23
streamlit==1.30.0
streamlit-option-menu==0.3.12
streamlit-antd-components==0.3.1

streamlit-chatbox==1.1.11
streamlit-modal==0.1.0
streamlit-aggrid==0.3.4.post3
httpx==0.26.0
watchdog==3.0.0
jwt==1.3.1
@@ -1,24 +1,23 @@
torch~=2.1.2
torchvision~=0.16.2
torchaudio~=2.1.2
xformers>=0.0.23.post1
transformers==4.36.2
sentence_transformers==2.2.2

langchain==0.0.354
langchain-experimental==0.0.47
pydantic==1.10.13
fschat==0.2.35
openai~=1.9.0
fastapi~=0.109.0
sse_starlette==1.8.2
nltk>=3.8.1
uvicorn>=0.24.0.post1
starlette~=0.35.0
unstructured[all-docs]==0.11.0
python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19
faiss-cpu~=1.7.4
accelerate~=0.24.1
spacy~=3.7.2
PyMuPDF~=1.23.8

@@ -26,7 +25,7 @@ rapidocr_onnxruntime==1.3.8
requests~=2.31.0
pathlib~=1.0.1
pytest~=7.4.3
numexpr~=2.8.6
strsimpy~=0.2.1
markdownify~=0.11.6
tiktoken~=0.5.2

@@ -36,31 +35,23 @@ numpy~=1.24.4
pandas~=2.0.3
einops>=0.7.0
transformers_stream_generator==0.0.4
vllm==0.2.7; sys_platform == "linux"
httpx==0.26.0
llama-index==0.9.35

-zhipuai==1.0.7 # zhipu
# jq==1.6.0
# beautifulsoup4~=4.12.2
# pysrt~=1.1.2
# dashscope==1.13.6
# arxiv~=2.1.0
# youtube-search~=2.1.2
# duckduckgo-search~=3.9.9
# metaphor-python~=0.1.23
# volcengine>=1.0.119

# uncomment libs if you want to use corresponding vector store
# pymilvus>=2.3.4
# psycopg2==2.9.9
# pgvector>=0.2.4
#flash-attn==2.4.2 # For Orion-14B-Chat and Qwen-14B-Chat
#autoawq==0.1.8 # For Int4
#rapidocr_paddle[gpu]==1.3.11 # gpu acceleration for ocr of pdf and image files
@@ -1,64 +1,33 @@
-# API requirements
-
 langchain==0.0.354
 langchain-experimental==0.0.47
 pydantic==1.10.13
-fschat==0.2.34
-openai~=1.7.1
-fastapi~=0.108.0
-sse_starlette==1.8.2
-nltk>=3.8.1
-uvicorn>=0.24.0.post1
-starlette~=0.32.0
-unstructured[all-docs]==0.11.0
+fschat~=0.2.35
+openai~=1.9.0
+fastapi~=0.109.0
+sse_starlette~=1.8.2
+nltk~=3.8.1
+uvicorn~=0.24.0.post1
+starlette~=0.35.0
+unstructured[all-docs]~=0.12.0
 python-magic-bin; sys_platform == 'win32'
-SQLAlchemy==2.0.19
+SQLAlchemy~=2.0.25
 faiss-cpu~=1.7.4
+accelerate~=0.24.1
+spacy~=3.7.2
+PyMuPDF~=1.23.16
+rapidocr_onnxruntime~=1.3.8
 requests~=2.31.0
 pathlib~=1.0.1
 pytest~=7.4.3
-numexpr~=2.8.6 # max version for py38
-strsimpy~=0.2.1
-markdownify~=0.11.6
-tiktoken~=0.5.2
-tqdm>=4.66.1
-websockets>=12.0
-numpy~=1.24.4
-pandas~=2.0.3
-einops>=0.7.0
-transformers_stream_generator==0.0.4
-vllm==0.2.6; sys_platform == "linux"
-httpx[brotli,http2,socks]==0.25.2
-requests
-pathlib
-pytest
+llama-index==0.9.35
 
-# Online api libs dependencies
-
-zhipuai==1.0.7
 dashscope==1.13.6
-# volcengine>=1.0.119
-
-# uncomment libs if you want to use corresponding vector store
-# pymilvus>=2.3.4
-# psycopg2==2.9.9
-# pgvector>=0.2.4
-
-# Agent and Search Tools
-
 arxiv~=2.1.0
 youtube-search~=2.1.2
 duckduckgo-search~=3.9.9
 metaphor-python~=0.1.23
+watchdog~=3.0.0
 
-# WebUI requirements
-
-streamlit>=1.29.0
-streamlit-option-menu>=0.3.6
-streamlit-antd-components>=0.3.0
-streamlit-chatbox>=1.1.11
-streamlit-modal>=0.1.0
-streamlit-aggrid>=0.3.4.post3
-httpx[brotli,http2,socks]>=0.25.2
-watchdog>=3.0.0
+# volcengine>=1.0.119
+# pymilvus>=2.3.4
+# psycopg2==2.9.9
+# pgvector>=0.2.4
@@ -1,10 +1,8 @@
-# WebUI requirements
-
-streamlit>=1.29.0
-streamlit-option-menu>=0.3.6
-streamlit-antd-components>=0.3.0
-streamlit-chatbox>=1.1.11
-streamlit-modal>=0.1.0
-streamlit-aggrid>=0.3.4.post3
-httpx[brotli,http2,socks]>=0.25.2
-watchdog>=3.0.0
+streamlit~=1.30.0
+streamlit-option-menu~=0.3.12
+streamlit-antd-components~=0.3.1
+streamlit-chatbox~=1.1.11
+streamlit-modal~=0.1.0
+streamlit-aggrid~=0.3.4.post3
+httpx~=0.26.0
+watchdog~=3.0.0
@@ -1,22 +1,19 @@
 """
-This file is a modified version for ChatGLM3-6B the original ChatGLM3Agent.py file from the langchain repo.
+This file is a modified version for ChatGLM3-6B the original glm3_agent.py file from the langchain repo.
 """
 from __future__ import annotations
 
-import yaml
-from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
-from langchain.memory import ConversationBufferWindowMemory
-from typing import Any, List, Sequence, Tuple, Optional, Union
-import os
-from langchain.agents.agent import Agent
-from langchain.chains.llm import LLMChain
-from langchain.prompts.chat import (
-    ChatPromptTemplate,
-    HumanMessagePromptTemplate,
-    SystemMessagePromptTemplate, MessagesPlaceholder,
-)
 import json
 import logging
+from typing import Any, List, Sequence, Tuple, Optional, Union
+from pydantic.schema import model_schema
+
+from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
+from langchain.memory import ConversationBufferWindowMemory
+from langchain.agents.agent import Agent
+from langchain.chains.llm import LLMChain
+from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
 from langchain.agents.agent import AgentOutputParser
 from langchain.output_parsers import OutputFixingParser
 from langchain.pydantic_v1 import Field
@@ -43,12 +40,18 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
         first_index = min([text.find(token) if token in text else len(text) for token in special_tokens])
         text = text[:first_index]
         if "tool_call" in text:
-            tool_name_end = text.find("```")
-            tool_name = text[:tool_name_end].strip()
-            input_para = text.split("='")[-1].split("'")[0]
+            action_end = text.find("```")
+            action = text[:action_end].strip()
+            params_str_start = text.find("(") + 1
+            params_str_end = text.rfind(")")
+            params_str = text[params_str_start:params_str_end]
+
+            params_pairs = [param.split("=") for param in params_str.split(",") if "=" in param]
+            params = {pair[0].strip(): pair[1].strip().strip("'\"") for pair in params_pairs}
+
             action_json = {
-                "action": tool_name,
-                "action_input": input_para
+                "action": action,
+                "action_input": params
             }
         else:
             action_json = {
@@ -109,10 +112,6 @@ class StructuredGLM3ChatAgent(Agent):
         else:
             return agent_scratchpad
 
-    @classmethod
-    def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
-        pass
-
     @classmethod
     def _get_default_output_parser(
             cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
@@ -121,7 +120,7 @@ class StructuredGLM3ChatAgent(Agent):
 
     @property
     def _stop(self) -> List[str]:
-        return ["```<observation>"]
+        return ["<|observation|>"]
 
     @classmethod
     def create_prompt(
@@ -131,44 +130,25 @@ class StructuredGLM3ChatAgent(Agent):
             input_variables: Optional[List[str]] = None,
            memory_prompts: Optional[List[BasePromptTemplate]] = None,
     ) -> BasePromptTemplate:
-        def tool_config_from_file(tool_name, directory="server/agent/tools/"):
-            """search tool yaml and return simplified json format"""
-            file_path = os.path.join(directory, f"{tool_name.lower()}.yaml")
-            try:
-                with open(file_path, 'r', encoding='utf-8') as file:
-                    tool_config = yaml.safe_load(file)
-                # Simplify the structure if needed
-                simplified_config = {
-                    "name": tool_config.get("name", ""),
-                    "description": tool_config.get("description", ""),
-                    "parameters": tool_config.get("parameters", {})
-                }
-                return simplified_config
-            except FileNotFoundError:
-                logger.error(f"File not found: {file_path}")
-                return None
-            except Exception as e:
-                logger.error(f"An error occurred while reading {file_path}: {e}")
-                return None
-
         tools_json = []
         tool_names = []
         for tool in tools:
-            tool_config = tool_config_from_file(tool.name)
-            if tool_config:
-                tools_json.append(tool_config)
-                tool_names.append(tool.name)
-
-        # Format the tools for output
+            tool_schema = model_schema(tool.args_schema) if tool.args_schema else {}
+            simplified_config_langchain = {
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": tool_schema.get("properties", {})
+            }
+            tools_json.append(simplified_config_langchain)
+            tool_names.append(tool.name)
         formatted_tools = "\n".join([
             f"{tool['name']}: {tool['description']}, args: {tool['parameters']}"
             for tool in tools_json
         ])
         formatted_tools = formatted_tools.replace("'", "\\'").replace("{", "{{").replace("}", "}}")
 
         template = prompt.format(tool_names=tool_names,
                                  tools=formatted_tools,
-                                 history="{history}",
+                                 history="None",
                                  input="{input}",
                                  agent_scratchpad="{agent_scratchpad}")
 
@@ -225,7 +205,6 @@ def initialize_glm3_agent(
         tools: Sequence[BaseTool],
         llm: BaseLanguageModel,
         prompt: str = None,
-        callback_manager: Optional[BaseCallbackManager] = None,
         memory: Optional[ConversationBufferWindowMemory] = None,
         agent_kwargs: Optional[dict] = None,
         *,
@@ -238,14 +217,12 @@ def initialize_glm3_agent(
         llm=llm,
         tools=tools,
         prompt=prompt,
-        callback_manager=callback_manager, **agent_kwargs
+        **agent_kwargs
     )
     return AgentExecutor.from_agent_and_tools(
         agent=agent_obj,
         tools=tools,
-        callback_manager=callback_manager,
         memory=memory,
         tags=tags_,
         **kwargs,
     )
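As a quick illustration of what the reworked output parser above now expects, the sketch below feeds it a typical ChatGLM3-style `tool_call` block and recovers the keyword arguments as a dict. It is a minimal standalone reproduction of the slicing and key=value extraction shown in the diff, not the project's public API; the sample model output is invented for the example.

```python
# Minimal sketch of the new tool_call parsing shown in the diff above.
# Only the slicing and key=value extraction mirror the updated parser;
# the sample response text is made up.
def parse_tool_call(text: str) -> dict:
    action_end = text.find("```")
    action = text[:action_end].strip()                     # tool name before the code fence
    params_str = text[text.find("(") + 1:text.rfind(")")]  # everything between the parentheses
    params_pairs = [p.split("=") for p in params_str.split(",") if "=" in p]
    params = {k.strip(): v.strip().strip("'\"") for k, v in params_pairs}
    return {"action": action, "action_input": params}

fence = "`" * 3  # avoid writing a literal code fence inside this example
sample = f"weather_check\n{fence}python\ntool_call(location='上海 浦东')\n{fence}"
print(parse_tool_call(sample))
# {'action': 'weather_check', 'action_input': {'location': '上海 浦东'}}
```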
@@ -1,5 +1,3 @@
-
-## 由于工具类无法传参,所以使用全局变量来传递模型和对应的知识库介绍
 class ModelContainer:
     def __init__(self):
         self.MODEL = None
@@ -3,7 +3,7 @@ from .search_knowledgebase_simple import search_knowledgebase_simple
 from .search_knowledgebase_once import search_knowledgebase_once, KnowledgeSearchInput
 from .search_knowledgebase_complex import search_knowledgebase_complex, KnowledgeSearchInput
 from .calculate import calculate, CalculatorInput
-from .weather_check import weathercheck, WhetherSchema
+from .weather_check import weathercheck, WeatherInput
 from .shell import shell, ShellInput
 from .search_internet import search_internet, SearchInternetInput
 from .wolfram import wolfram, WolframInput
@@ -1,10 +0,0 @@
-name: arxiv
-description: A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.
-parameters:
-  type: object
-  properties:
-    query:
-      type: string
-      description: The search query title
-  required:
-    - query
@@ -1,10 +0,0 @@
-name: calculate
-description: Useful for when you need to answer questions about simple calculations
-parameters:
-  type: object
-  properties:
-    query:
-      type: string
-      description: The formula to be calculated
-  required:
-    - query
@@ -1,10 +0,0 @@
-name: search_internet
-description: Use this tool to surf internet and get information
-parameters:
-  type: object
-  properties:
-    query:
-      type: string
-      description: Query for Internet search
-  required:
-    - query
@@ -1,10 +0,0 @@
-name: search_knowledgebase_complex
-description: Use this tool to search local knowledgebase and get information
-parameters:
-  type: object
-  properties:
-    query:
-      type: string
-      description: The query to be searched
-  required:
-    - query
@@ -1,10 +0,0 @@
-name: search_youtube
-description: Use this tools to search youtube videos
-parameters:
-  type: object
-  properties:
-    query:
-      type: string
-      description: Query for Videos search
-  required:
-    - query
@@ -1,10 +0,0 @@
-name: shell
-description: Use Linux Shell to execute Linux commands
-parameters:
-  type: object
-  properties:
-    query:
-      type: string
-      description: The command to execute
-  required:
-    - query
@@ -1,338 +1,29 @@
-from __future__ import annotations
-
-## 单独运行的时候需要添加
-import sys
-import os
-
-sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
-
-import re
-import warnings
-from typing import Dict
-
-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForChainRun,
-    CallbackManagerForChainRun,
-)
-from langchain.chains.base import Chain
-from langchain.chains.llm import LLMChain
-from langchain.pydantic_v1 import Extra, root_validator
-from langchain.schema import BasePromptTemplate
-from langchain.schema.language_model import BaseLanguageModel
-import requests
-from typing import List, Any, Optional
-from datetime import datetime
-from langchain.prompts import PromptTemplate
-from server.agent import model_container
-from pydantic import BaseModel, Field
-
-## 使用和风天气API查询天气
-KEY = "ac880e5a877042809ac7ffdd19d95b0d"
-# key长这样,这里提供了示例的key,这个key没法使用,你需要自己去注册和风天气的账号,然后在这里填入你的key
-
-
-_PROMPT_TEMPLATE = """
-用户会提出一个关于天气的问题,你的目标是拆分出用户问题中的区,市 并按照我提供的工具回答。
-例如 用户提出的问题是: 上海浦东未来1小时天气情况?
-则 提取的市和区是: 上海 浦东
-如果用户提出的问题是: 上海未来1小时天气情况?
-则 提取的市和区是: 上海 None
-请注意以下内容:
-1. 如果你没有找到区的内容,则一定要使用 None 替代,否则程序无法运行
-2. 如果用户没有指定市 则直接返回缺少信息
-
-问题: ${{用户的问题}}
-
-你的回答格式应该按照下面的内容,请注意,格式内的```text 等标记都必须输出,这是我用来提取答案的标记。
-```text
-
-${{拆分的市和区,中间用空格隔开}}
-```
-... weathercheck(市 区)...
-```output
-
-${{提取后的答案}}
-```
-答案: ${{答案}}
-
-
-这是一个例子:
-问题: 上海浦东未来1小时天气情况?
-
-```text
-上海 浦东
-```
-...weathercheck(上海 浦东)...
-
-```output
-预报时间: 1小时后
-具体时间: 今天 18:00
-温度: 24°C
-天气: 多云
-风向: 西南风
-风速: 7级
-湿度: 88%
-降水概率: 16%
-
-Answer: 上海浦东一小时后的天气是多云。
-
-现在,这是我的问题:
-
-问题: {question}
-"""
-PROMPT = PromptTemplate(
-    input_variables=["question"],
-    template=_PROMPT_TEMPLATE,
-)
-
-
-def get_city_info(location, adm, key):
-    base_url = 'https://geoapi.qweather.com/v2/city/lookup?'
-    params = {'location': location, 'adm': adm, 'key': key}
-    response = requests.get(base_url, params=params)
-    data = response.json()
-    return data
-
-
-def format_weather_data(data, place):
-    hourly_forecast = data['hourly']
-    formatted_data = f"\n 这是查询到的关于{place}未来24小时的天气信息: \n"
-    for forecast in hourly_forecast:
-        # 将预报时间转换为datetime对象
-        forecast_time = datetime.strptime(forecast['fxTime'], '%Y-%m-%dT%H:%M%z')
-        # 获取预报时间的时区
-        forecast_tz = forecast_time.tzinfo
-        # 获取当前时间(使用预报时间的时区)
-        now = datetime.now(forecast_tz)
-        # 计算预报日期与当前日期的差值
-        days_diff = (forecast_time.date() - now.date()).days
-        if days_diff == 0:
-            forecast_date_str = '今天'
-        elif days_diff == 1:
-            forecast_date_str = '明天'
-        elif days_diff == 2:
-            forecast_date_str = '后天'
-        else:
-            forecast_date_str = str(days_diff) + '天后'
-        forecast_time_str = forecast_date_str + ' ' + forecast_time.strftime('%H:%M')
-        # 计算预报时间与当前时间的差值
-        time_diff = forecast_time - now
-        # 将差值转换为小时
-        hours_diff = time_diff.total_seconds() // 3600
-        if hours_diff < 1:
-            hours_diff_str = '1小时后'
-        elif hours_diff >= 24:
-            # 如果超过24小时,转换为天数
-            days_diff = hours_diff // 24
-            hours_diff_str = str(int(days_diff)) + '天'
-        else:
-            hours_diff_str = str(int(hours_diff)) + '小时'
-        # 将预报时间和当前时间的差值添加到输出中
-        formatted_data += '预报时间: ' + forecast_time_str + ' 距离现在有: ' + hours_diff_str + '\n'
-        formatted_data += '温度: ' + forecast['temp'] + '°C\n'
-        formatted_data += '天气: ' + forecast['text'] + '\n'
-        formatted_data += '风向: ' + forecast['windDir'] + '\n'
-        formatted_data += '风速: ' + forecast['windSpeed'] + '级\n'
-        formatted_data += '湿度: ' + forecast['humidity'] + '%\n'
-        formatted_data += '降水概率: ' + forecast['pop'] + '%\n'
-        # formatted_data += '降水量: ' + forecast['precip'] + 'mm\n'
-        formatted_data += '\n'
-    return formatted_data
-
-
-def get_weather(key, location_id, place):
-    url = "https://devapi.qweather.com/v7/weather/24h?"
-    params = {
-        'location': location_id,
-        'key': key,
-    }
-    response = requests.get(url, params=params)
-    data = response.json()
-    return format_weather_data(data, place)
-
-
-def split_query(query):
-    parts = query.split()
-    adm = parts[0]
-    if len(parts) == 1:
-        return adm, adm
-    location = parts[1] if parts[1] != 'None' else adm
-    return location, adm
-
-
-def weather(query):
-    location, adm = split_query(query)
-    key = KEY
-    if key == "":
-        return "请先在代码中填入和风天气API Key"
-    try:
-        city_info = get_city_info(location=location, adm=adm, key=key)
-        location_id = city_info['location'][0]['id']
-        place = adm + "市" + location + "区"
-
-        weather_data = get_weather(key=key, location_id=location_id, place=place)
-        return weather_data + "以上是查询到的天气信息,请你查收\n"
-    except KeyError:
-        try:
-            city_info = get_city_info(location=adm, adm=adm, key=key)
-            location_id = city_info['location'][0]['id']
-            place = adm + "市"
-            weather_data = get_weather(key=key, location_id=location_id, place=place)
-            return weather_data + "重要提醒:用户提供的市和区中,区的信息不存在,或者出现错别字,因此该信息是关于市的天气,请你查收\n"
-        except KeyError:
-            return "输入的地区不存在,无法提供天气预报"
-
-
-class LLMWeatherChain(Chain):
-    llm_chain: LLMChain
-    llm: Optional[BaseLanguageModel] = None
-    """[Deprecated] LLM wrapper to use."""
-    prompt: BasePromptTemplate = PROMPT
-    """[Deprecated] Prompt to use to translate to python if necessary."""
-    input_key: str = "question"  #: :meta private:
-    output_key: str = "answer"  #: :meta private:
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-        arbitrary_types_allowed = True
-
-    @root_validator(pre=True)
-    def raise_deprecation(cls, values: Dict) -> Dict:
-        if "llm" in values:
-            warnings.warn(
-                "Directly instantiating an LLMWeatherChain with an llm is deprecated. "
-                "Please instantiate with llm_chain argument or using the from_llm "
-                "class method."
-            )
-            if "llm_chain" not in values and values["llm"] is not None:
-                prompt = values.get("prompt", PROMPT)
-                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
-        return values
-
-    @property
-    def input_keys(self) -> List[str]:
-        """Expect input key.
-
-        :meta private:
-        """
-        return [self.input_key]
-
-    @property
-    def output_keys(self) -> List[str]:
-        """Expect output key.
-
-        :meta private:
-        """
-        return [self.output_key]
-
-    def _evaluate_expression(self, expression: str) -> str:
-        try:
-            output = weather(expression)
-        except Exception as e:
-            output = "输入的信息有误,请再次尝试"
-        return output
-
-    def _process_llm_result(
-            self, llm_output: str, run_manager: CallbackManagerForChainRun
-    ) -> Dict[str, str]:
-
-        run_manager.on_text(llm_output, color="green", verbose=self.verbose)
-
-        llm_output = llm_output.strip()
-        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
-        if text_match:
-            expression = text_match.group(1)
-            output = self._evaluate_expression(expression)
-            run_manager.on_text("\nAnswer: ", verbose=self.verbose)
-            run_manager.on_text(output, color="yellow", verbose=self.verbose)
-            answer = "Answer: " + output
-        elif llm_output.startswith("Answer:"):
-            answer = llm_output
-        elif "Answer:" in llm_output:
-            answer = "Answer: " + llm_output.split("Answer:")[-1]
-        else:
-            return {self.output_key: f"输入的格式不对: {llm_output},应该输入 (市 区)的组合"}
-        return {self.output_key: answer}
-
-    async def _aprocess_llm_result(
-            self,
-            llm_output: str,
-            run_manager: AsyncCallbackManagerForChainRun,
-    ) -> Dict[str, str]:
-        await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
-        llm_output = llm_output.strip()
-        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
-
-        if text_match:
-            expression = text_match.group(1)
-            output = self._evaluate_expression(expression)
-            await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
-            await run_manager.on_text(output, color="yellow", verbose=self.verbose)
-            answer = "Answer: " + output
-        elif llm_output.startswith("Answer:"):
-            answer = llm_output
-        elif "Answer:" in llm_output:
-            answer = "Answer: " + llm_output.split("Answer:")[-1]
-        else:
-            raise ValueError(f"unknown format from LLM: {llm_output}")
-        return {self.output_key: answer}
-
-    def _call(
-            self,
-            inputs: Dict[str, str],
-            run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
-        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
-        _run_manager.on_text(inputs[self.input_key])
-        llm_output = self.llm_chain.predict(
-            question=inputs[self.input_key],
-            stop=["```output"],
-            callbacks=_run_manager.get_child(),
-        )
-        return self._process_llm_result(llm_output, _run_manager)
-
-    async def _acall(
-            self,
-            inputs: Dict[str, str],
-            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
-        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
-        await _run_manager.on_text(inputs[self.input_key])
-        llm_output = await self.llm_chain.apredict(
-            question=inputs[self.input_key],
-            stop=["```output"],
-            callbacks=_run_manager.get_child(),
-        )
-        return await self._aprocess_llm_result(llm_output, _run_manager)
-
-    @property
-    def _chain_type(self) -> str:
-        return "llm_weather_chain"
-
-    @classmethod
-    def from_llm(
-            cls,
-            llm: BaseLanguageModel,
-            prompt: BasePromptTemplate = PROMPT,
-            **kwargs: Any,
-    ) -> LLMWeatherChain:
-        llm_chain = LLMChain(llm=llm, prompt=prompt)
-        return cls(llm_chain=llm_chain, **kwargs)
-
-
-def weathercheck(query: str):
-    model = model_container.MODEL
-    llm_weather = LLMWeatherChain.from_llm(model, verbose=True, prompt=PROMPT)
-    ans = llm_weather.run(query)
-    return ans
-
-
-class WhetherSchema(BaseModel):
-    location: str = Field(description="应该是一个地区的名称,用空格隔开,例如:上海 浦东,如果没有区的信息,可以只输入上海")
-
-
-if __name__ == '__main__':
-    result = weathercheck("苏州姑苏区今晚热不热?")
+"""
+更简单的单参数输入工具实现,用于查询现在天气的情况
+"""
+from pydantic import BaseModel, Field
+import requests
+from configs.kb_config import SENIVERSE_API_KEY
+
+
+def weather(location: str, api_key: str):
+    url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c"
+    response = requests.get(url)
+    if response.status_code == 200:
+        data = response.json()
+        weather = {
+            "temperature": data["results"][0]["now"]["temperature"],
+            "description": data["results"][0]["now"]["text"],
+        }
+        return weather
+    else:
+        raise Exception(
+            f"Failed to retrieve weather: {response.status_code}")
+
+
+def weathercheck(location: str):
+    return weather(location, SENIVERSE_API_KEY)
+
+
+class WeatherInput(BaseModel):
+    location: str = Field(description="City name,include city and county")
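For reference, a minimal sketch of how the rewritten single-parameter weather tool above could be exercised on its own. The key string and city are placeholders; in the project the key comes from configs.kb_config.SENIVERSE_API_KEY, and the response fields follow the Seniverse v3 "weather/now" JSON used in the new code.

```python
# Standalone sketch of the new Seniverse-based weather lookup.
# "your-seniverse-key" is a placeholder for SENIVERSE_API_KEY from configs.kb_config.
import requests

def weather_now(location: str, api_key: str) -> dict:
    url = (
        "https://api.seniverse.com/v3/weather/now.json"
        f"?key={api_key}&location={location}&language=zh-Hans&unit=c"
    )
    resp = requests.get(url, timeout=10)
    resp.raise_for_status()
    now = resp.json()["results"][0]["now"]
    # Same two fields the new weathercheck() tool returns.
    return {"temperature": now["temperature"], "description": now["text"]}

if __name__ == "__main__":
    print(weather_now("上海", "your-seniverse-key"))
    # e.g. {'temperature': '24', 'description': '多云'}
```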
@@ -1,10 +0,0 @@
-name: weather_check
-description: Use Weather API to get weather information
-parameters:
-  type: object
-  properties:
-    query:
-      type: string
-      description: City name,include city and county,like "厦门市思明区"
-  required:
-    - query
@@ -1,10 +0,0 @@
-name: wolfram
-description: Useful for when you need to calculate difficult math formulas
-parameters:
-  type: object
-  properties:
-    query:
-      type: string
-      description: The formula to be calculated
-  required:
-    - query
@@ -1,8 +1,6 @@
 from langchain.tools import Tool
 from server.agent.tools import *
 
-## 请注意,如果你是为了使用AgentLM,在这里,你应该使用英文版本。
-
 tools = [
     Tool.from_function(
         func=calculate,
@@ -20,7 +18,7 @@ tools = [
         func=weathercheck,
         name="weather_check",
         description="",
-        args_schema=WhetherSchema,
+        args_schema=WeatherInput,
     ),
     Tool.from_function(
         func=shell,
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Union, Optional
+from typing import Any, Dict, List
 
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema import LLMResult
@@ -1,22 +1,23 @@
-from langchain.memory import ConversationBufferWindowMemory
-from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
-from server.agent.tools_select import tools, tool_names
-from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
-from langchain.agents import LLMSingleActionAgent, AgentExecutor
-from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
+import json
+import asyncio
+
 from fastapi import Body
 from sse_starlette.sse import EventSourceResponse
 from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
-from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
 from langchain.chains import LLMChain
-from typing import AsyncIterable, Optional
-import asyncio
-from typing import List
-from server.chat.utils import History
-import json
-from server.agent import model_container
+from langchain.memory import ConversationBufferWindowMemory
+from langchain.agents import LLMSingleActionAgent, AgentExecutor
+from typing import AsyncIterable, Optional, List
+
+from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
 from server.knowledge_base.kb_service.base import get_kb_details
+from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
+from server.agent.tools_select import tools, tool_names
+from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
+from server.chat.utils import History
+from server.agent import model_container
+from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
 
 
 async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
@@ -33,7 +34,6 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
                      max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
                      prompt_name: str = Body("default",
                                              description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
-                     # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
                      ):
     history = [History.from_data(h) for h in history]
 
@@ -55,12 +55,10 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
             callbacks=[callback],
         )
 
-        ## 传入全局变量来实现agent调用
         kb_list = {x["kb_name"]: x for x in get_kb_details()}
         model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()}
 
         if Agent_MODEL:
-            ## 如果有指定使用Agent模型来完成任务
             model_agent = get_ChatOpenAI(
                 model_name=Agent_MODEL,
                 temperature=temperature,
@@ -79,23 +77,17 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
         )
         output_parser = CustomOutputParser()
         llm_chain = LLMChain(llm=model, prompt=prompt_template_agent)
-        # 把history转成agent的memory
         memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
         for message in history:
-            # 检查消息的角色
             if message.role == 'user':
-                # 添加用户消息
                 memory.chat_memory.add_user_message(message.content)
             else:
-                # 添加AI消息
                 memory.chat_memory.add_ai_message(message.content)
-        if "chatglm3" in model_container.MODEL.model_name:
+        if "chatglm3" in model_container.MODEL.model_name or "zhipu-api" in model_container.MODEL.model_name:
             agent_executor = initialize_glm3_agent(
                 llm=model,
                 tools=tools,
                 callback_manager=None,
-                # Langchain Prompt is not constructed directly here, it is constructed inside the GLM3 agent.
                 prompt=prompt_template,
                 input_variables=["input", "intermediate_steps", "history"],
                 memory=memory,
@@ -155,7 +147,6 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
         answer = ""
         final_answer = ""
         async for chunk in callback.aiter():
-            # Use server-sent-events to stream the response
             data = json.loads(chunk)
             if data["status"] == Status.start or data["status"] == Status.complete:
                 continue
@@ -181,7 +172,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
     await task
 
     return EventSourceResponse(agent_chat_iterator(query=query,
                                                    history=history,
                                                    model_name=model_name,
                                                    prompt_name=prompt_name),
                                )
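To make the history-to-memory handoff in the agent_chat diff above easier to follow, here is a minimal sketch of the same pattern in isolation: prior chat turns are replayed into a ConversationBufferWindowMemory that the GLM3 agent executor then receives. The HISTORY_LEN constant and the sample history rows are placeholders; in the project they come from configs and the request body.

```python
# Sketch only: replaying prior turns into the window memory that
# initialize_glm3_agent(..., memory=memory) is later constructed with.
from langchain.memory import ConversationBufferWindowMemory

HISTORY_LEN = 3  # placeholder; the project reads this from configs

history = [
    {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
    {"role": "assistant", "content": "虎头虎脑"},
]

memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
for message in history:
    if message["role"] == "user":
        memory.chat_memory.add_user_message(message["content"])
    else:
        memory.chat_memory.add_ai_message(message["content"])

# The populated memory is then handed to the GLM3 agent executor together
# with the tools list, exactly as agent_chat does above.
```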
@@ -1,23 +1,23 @@
 from langchain.utilities.bing_search import BingSearchAPIWrapper
 from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
 from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
-                     LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE,
-                     TEXT_SPLITTER_NAME, OVERLAP_SIZE)
-from fastapi import Body
-from sse_starlette import EventSourceResponse
-from fastapi.concurrency import run_in_threadpool
-from server.utils import wrap_done, get_ChatOpenAI
-from server.utils import BaseResponse, get_prompt_template
+                     LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE, OVERLAP_SIZE)
 from langchain.chains import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
-from typing import AsyncIterable
-import asyncio
 from langchain.prompts.chat import ChatPromptTemplate
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from typing import List, Optional, Dict
-from server.chat.utils import History
 from langchain.docstore.document import Document
+from fastapi import Body
+from fastapi.concurrency import run_in_threadpool
+from sse_starlette import EventSourceResponse
+from server.utils import wrap_done, get_ChatOpenAI
+from server.utils import BaseResponse, get_prompt_template
+from server.chat.utils import History
+from typing import AsyncIterable
+import asyncio
 import json
+from typing import List, Optional, Dict
 from strsimpy.normalized_levenshtein import NormalizedLevenshtein
 from markdownify import markdownify
 
@@ -38,11 +38,11 @@ def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
 
 
 def metaphor_search(
         text: str,
         result_len: int = SEARCH_ENGINE_TOP_K,
         split_result: bool = False,
         chunk_size: int = 500,
         chunk_overlap: int = OVERLAP_SIZE,
 ) -> List[Dict]:
     from metaphor_python import Metaphor
 
@@ -58,13 +58,13 @@ def metaphor_search(
     # metaphor 返回的内容都是长文本,需要分词再检索
     if split_result:
         docs = [Document(page_content=x.extract,
                          metadata={"link": x.url, "title": x.title})
                 for x in contents]
         text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
                                                        chunk_size=chunk_size,
                                                        chunk_overlap=chunk_overlap)
         splitted_docs = text_splitter.split_documents(docs)
 
         # 将切分好的文档放入临时向量库,重新筛选出TOP_K个文档
         if len(splitted_docs) > result_len:
             normal = NormalizedLevenshtein()
@@ -74,13 +74,13 @@ def metaphor_search(
             splitted_docs = splitted_docs[:result_len]
 
         docs = [{"snippet": x.page_content,
                  "link": x.metadata["link"],
                  "title": x.metadata["title"]}
                 for x in splitted_docs]
     else:
         docs = [{"snippet": x.extract,
                  "link": x.url,
                  "title": x.title}
                 for x in contents]
 
     return docs
@@ -113,25 +113,27 @@ async def lookup_search_engine(
     docs = search_result2docs(results)
     return docs
 
 
 async def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
                              search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
                              top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
                              history: List[History] = Body([],
                                                            description="历史对话",
                                                            examples=[[
                                                                {"role": "user",
                                                                 "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                                                {"role": "assistant",
                                                                 "content": "虎头虎脑"}]]
                                                            ),
                              stream: bool = Body(False, description="流式输出"),
                              model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                              temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-                             max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
-                             prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
-                             split_result: bool = Body(False, description="是否对搜索结果进行拆分(主要用于metaphor搜索引擎)")
-                             ):
+                             max_tokens: Optional[int] = Body(None,
+                                                              description="限制LLM生成Token数量,默认None代表模型最大值"),
+                             prompt_name: str = Body("default",
+                                                     description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
+                             split_result: bool = Body(False,
+                                                       description="是否对搜索结果进行拆分(主要用于metaphor搜索引擎)")
+                             ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
 
@@ -198,9 +200,9 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
     await task
 
     return EventSourceResponse(search_engine_chat_iterator(query=query,
                                                            search_engine_name=search_engine_name,
                                                            top_k=top_k,
                                                            history=history,
                                                            model_name=model_name,
                                                            prompt_name=prompt_name),
                                )
@@ -83,7 +83,7 @@ def add_file_to_db(session,
                    kb_file: KnowledgeFile,
                    docs_count: int = 0,
                    custom_docs: bool = False,
-                   doc_infos: List[str] = [],  # 形式:[{"id": str, "metadata": dict}, ...]
+                   doc_infos: List[Dict] = [],  # 形式:[{"id": str, "metadata": dict}, ...]
                    ):
     kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
     if kb:
@@ -16,7 +16,6 @@ def embed_texts(
 ) -> BaseResponse:
     '''
     对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]])
-    TODO: 也许需要加入缓存机制,减少 token 消耗
     '''
     try:
         if embed_model in list_embed_models():  # 使用本地Embeddings模型
@@ -13,9 +13,9 @@ def list_kbs():
 
 
 def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
               vector_store_type: str = Body("faiss"),
               embed_model: str = Body(EMBEDDING_MODEL),
               ) -> BaseResponse:
     # Create selected knowledge base
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")
@@ -39,8 +39,8 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
 
 
 def delete_kb(
         knowledge_base_name: str = Body(..., examples=["samples"])
 ) -> BaseResponse:
     # Delete selected knowledge base
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")
@@ -55,8 +55,6 @@ class _FaissPool(CachePool):
                 embed_model: str = EMBEDDING_MODEL,
                 embed_device: str = embedding_device(),
                 ) -> FAISS:
-        # TODO: 整个Embeddings加载逻辑有些混乱,待清理
-        # create an empty vector store
         embeddings = EmbeddingsFunAdapter(embed_model)
         doc = Document(page_content="init", metadata={})
         vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
@@ -95,7 +95,6 @@ def _save_files_in_thread(files: List[UploadFile],
                     and not override
                     and os.path.getsize(file_path) == len(file_content)
             ):
-                # TODO: filesize 不同后的处理
                 file_status = f"文件 {filename} 已存在。"
                 logger.warn(file_status)
                 return dict(code=404, msg=file_status, data=data)
@@ -116,7 +115,6 @@ def _save_files_in_thread(files: List[UploadFile],
         yield result
 
 
-# TODO: 等langchain.document_loaders支持内存文件的时候再开通
 # def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
 #                knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
 #                override: bool = Form(False, description="覆盖已有文件"),
@@ -24,7 +24,7 @@ from server.knowledge_base.utils import (
     list_kbs_from_folder, list_files_from_folder,
 )
 
-from typing import List, Union, Dict, Optional
+from typing import List, Union, Dict, Optional, Tuple
 
 from server.embeddings_api import embed_texts, aembed_texts, embed_documents
 from server.knowledge_base.model.kb_document_model import DocumentWithVSId
@@ -191,7 +191,6 @@ class KBService(ABC):
         '''
         传入参数为: {doc_id: Document, ...}
         如果对应 doc_id 的值为 None,或其 page_content 为空,则删除该文档
-        TODO:是否要支持新增 docs ?
         '''
         self.del_doc_by_ids(list(docs.keys()))
         docs = []
@@ -261,7 +260,7 @@ class KBService(ABC):
                    query: str,
                    top_k: int,
                    score_threshold: float,
-                   ) -> List[Document]:
+                   ) -> List[Tuple[Document, float]]:
         """
         搜索知识库子类实自己逻辑
         """
@@ -6,6 +6,7 @@ from langchain.schema import Document
 from langchain.vectorstores.elasticsearch import ElasticsearchStore
 from configs import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, CACHED_VS_NUM
 from server.knowledge_base.kb_service.base import KBService, SupportedVSType
+from server.knowledge_base.utils import KnowledgeFile
 from server.utils import load_local_embeddings
 from elasticsearch import Elasticsearch,BadRequestError
 from configs import logger
@@ -15,7 +16,7 @@ class ESKBService(KBService):
 
     def do_init(self):
         self.kb_path = self.get_kb_path(self.kb_name)
-        self.index_name = self.kb_path.split("/")[-1]
+        self.index_name = os.path.split(self.kb_path)[-1]
         self.IP = kbs_config[self.vs_type()]['host']
         self.PORT = kbs_config[self.vs_type()]['port']
         self.user = kbs_config[self.vs_type()].get("user",'')
@@ -38,7 +39,16 @@ class ESKBService(KBService):
             raise e
         try:
             # 首先尝试通过es_client_python创建
-            self.es_client_python.indices.create(index=self.index_name)
+            mappings = {
+                "properties": {
+                    "dense_vector": {
+                        "type": "dense_vector",
+                        "dims": self.dims_length,
+                        "index": True
+                    }
+                }
+            }
+            self.es_client_python.indices.create(index=self.index_name, mappings=mappings)
         except BadRequestError as e:
             logger.error("创建索引失败,重新")
             logger.error(e)
@@ -80,9 +90,9 @@ class ESKBService(KBService):
         except Exception as e:
             logger.error("创建索引失败...")
             logger.error(e)
             # raise e
 
 
     @staticmethod
     def get_kb_path(knowledge_base_name: str):
@@ -220,7 +230,12 @@ class ESKBService(KBService):
             shutil.rmtree(self.kb_path)
 
 
+if __name__ == '__main__':
+    esKBService = ESKBService("test")
+    #esKBService.clear_vs()
+    #esKBService.create_kb()
+    esKBService.add_doc(KnowledgeFile(filename="README.md", knowledge_base_name="test"))
+    print(esKBService.search_docs("如何启动api服务"))
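A minimal standalone sketch of the explicit dense_vector mapping that the Elasticsearch change above now sends when the knowledge-base index is created. The host, index name and dimension below are placeholders; the call itself is the same elasticsearch-py indices.create API used in the diff.

```python
# Sketch: create a KB index with an explicit dense_vector mapping,
# as the updated do_init() now does. Host, index name and dims are placeholders.
from elasticsearch import Elasticsearch

es = Elasticsearch("http://127.0.0.1:9200")
dims_length = 768  # dimension of the local embedding model

mappings = {
    "properties": {
        "dense_vector": {
            "type": "dense_vector",
            "dims": dims_length,
            "index": True,
        }
    }
}
es.indices.create(index="test", mappings=mappings)
```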
@@ -7,7 +7,7 @@ from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafe
 from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
 from server.utils import torch_gc
 from langchain.docstore.document import Document
-from typing import List, Dict, Optional
+from typing import List, Dict, Optional, Tuple
 
 
 class FaissKBService(KBService):
@@ -61,7 +61,7 @@ class FaissKBService(KBService):
                    query: str,
                    top_k: int,
                    score_threshold: float = SCORE_THRESHOLD,
-                   ) -> List[Document]:
+                   ) -> List[Tuple[Document, float]]:
         embed_func = EmbeddingsFunAdapter(self.embed_model)
         embeddings = embed_func.embed_query(query)
         with self.load_vector_store().acquire() as vs:
@@ -18,13 +18,10 @@ class MilvusKBService(KBService):
         from pymilvus import Collection
         return Collection(milvus_name)
 
-    # def save_vector_store(self):
-    #     if self.milvus.col:
-    #         self.milvus.col.flush()
-
     def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
         result = []
         if self.milvus.col:
+            # ids = [int(id) for id in ids]  # for milvus if needed #pr 2725
             data_list = self.milvus.col.query(expr=f'pk in {ids}', output_fields=["*"])
             for data in data_list:
                 text = data.pop("text")
@@ -73,7 +70,6 @@ class MilvusKBService(KBService):
         return score_threshold_process(score_threshold, top_k, docs)
 
     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
-        # TODO: workaround for bug #10492 in langchain
         for doc in docs:
             for k, v in doc.metadata.items():
                 doc.metadata[k] = str(v)
@@ -11,25 +11,27 @@ from server.knowledge_base.kb_service.base import SupportedVSType, KBService, Em
    score_threshold_process
from server.knowledge_base.utils import KnowledgeFile
import shutil
+import sqlalchemy
+from sqlalchemy.engine.base import Engine
+from sqlalchemy.orm import Session


class PGKBService(KBService):
-   pg_vector: PGVector
+   engine: Engine = sqlalchemy.create_engine(kbs_config.get("pg").get("connection_uri"), pool_size=10)

    def _load_pg_vector(self):
        self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                                  collection_name=self.kb_name,
                                  distance_strategy=DistanceStrategy.EUCLIDEAN,
+                                 connection=PGKBService.engine,
                                  connection_string=kbs_config.get("pg").get("connection_uri"))

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
-       with self.pg_vector.connect() as connect:
+       with Session(PGKBService.engine) as session:
            stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id in :ids")
            results = [Document(page_content=row[0], metadata=row[1]) for row in
-                      connect.execute(stmt, parameters={'ids': ids}).fetchall()]
+                      session.execute(stmt, {'ids': ids}).fetchall()]
            return results

-   # TODO:
    def del_doc_by_ids(self, ids: List[str]) -> bool:
        return super().del_doc_by_ids(ids)

@@ -43,8 +45,8 @@ class PGKBService(KBService):
        return SupportedVSType.PG

    def do_drop_kb(self):
-       with self.pg_vector.connect() as connect:
-           connect.execute(text(f'''
+       with Session(PGKBService.engine) as session:
+           session.execute(text(f'''
                    -- 删除 langchain_pg_embedding 表中关联到 langchain_pg_collection 表中 的记录
                    DELETE FROM langchain_pg_embedding
                    WHERE collection_id IN (
@@ -53,11 +55,10 @@ class PGKBService(KBService):
                    -- 删除 langchain_pg_collection 表中 记录
                    DELETE FROM langchain_pg_collection WHERE name = '{self.kb_name}';
            '''))
-           connect.commit()
+           session.commit()
            shutil.rmtree(self.kb_path)

    def do_search(self, query: str, top_k: int, score_threshold: float):
-       self._load_pg_vector()
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k)
@@ -69,13 +70,13 @@ class PGKBService(KBService):
        return doc_infos

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
-       with self.pg_vector.connect() as connect:
+       with Session(PGKBService.engine) as session:
            filepath = kb_file.filepath.replace('\\', '\\\\')
-           connect.execute(
+           session.execute(
                text(
                    ''' DELETE FROM langchain_pg_embedding WHERE cmetadata::jsonb @> '{"source": "filepath"}'::jsonb;'''.replace(
                        "filepath", filepath)))
-           connect.commit()
+           session.commit()

    def do_clear_vs(self):
        self.pg_vector.delete_collection()
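The PGVector changes replace per-call connections with one pooled SQLAlchemy engine shared by the class, and run raw SQL through ORM sessions that borrow connections from that pool. A minimal sketch of the same pattern; the connection URI and table below are placeholders, not this project's configuration:

```python
import sqlalchemy
from sqlalchemy import text
from sqlalchemy.orm import Session

# One pooled engine per process; each operation checks a connection out of
# the pool instead of opening a brand-new connection.
engine = sqlalchemy.create_engine(
    "postgresql+psycopg2://user:password@127.0.0.1:5432/langchain",  # placeholder URI
    pool_size=10,
)

def count_embeddings() -> int:
    # Sessions are cheap: they take a pooled connection on entry and
    # return it to the pool when the with-block exits.
    with Session(engine) as session:
        return session.execute(text("SELECT count(*) FROM langchain_pg_embedding")).scalar()
```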
@@ -16,13 +16,10 @@ class ZillizKBService(KBService):
        from pymilvus import Collection
        return Collection(zilliz_name)

-   # def save_vector_store(self):
-   #     if self.zilliz.col:
-   #         self.zilliz.col.flush()

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        result = []
        if self.zilliz.col:
+           # ids = [int(id) for id in ids]  # for zilliz if needed #pr 2725
            data_list = self.zilliz.col.query(expr=f'pk in {ids}', output_fields=["*"])
            for data in data_list:
                text = data.pop("text")
@@ -50,8 +47,7 @@ class ZillizKBService(KBService):
    def _load_zilliz(self):
        zilliz_args = kbs_config.get("zilliz")
        self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                             collection_name=self.kb_name, connection_args=zilliz_args)

    def do_init(self):
        self._load_zilliz()
@@ -95,9 +91,7 @@ class ZillizKBService(KBService):

if __name__ == '__main__':
    from server.db.base import Base, engine

    Base.metadata.create_all(bind=engine)
    zillizService = ZillizKBService("test")

@@ -13,7 +13,6 @@ from server.db.repository.knowledge_metadata_repository import add_summary_to_db
from langchain.docstore.document import Document


-# TODO 暂不考虑文件更新,需要重新删除相关文档,再重新添加
class KBSummaryService(ABC):
    kb_name: str
    embed_model: str
@@ -112,12 +112,6 @@ class SummaryAdapter:
               docs: List[DocumentWithVSId] = []) -> List[Document]:

        logger.info("start summary")
-       # TODO 暂不处理文档中涉及语义重复、上下文缺失、document was longer than the context length 的问题
-       # merge_docs = self._drop_overlap(docs)
-       # # 将merge_docs中的句子合并成一个文档
-       # text = self._join_docs(merge_docs)
-       # 根据段落于句子的分隔符,将文档分成chunk,每个chunk长度小于token_max长度

        """
        这个过程分成两个部分:
        1. 对每个文档进行处理,得到每个文档的摘要
@@ -91,9 +91,14 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
               "JSONLoader": [".json"],
               "JSONLinesLoader": [".jsonl"],
               "CSVLoader": [".csv"],
-              # "FilteredCSVLoader": [".csv"], # 需要自己指定,目前还没有支持
+              # "FilteredCSVLoader": [".csv"], 如果使用自定义分割csv
               "RapidOCRPDFLoader": [".pdf"],
+              "RapidOCRDocLoader": ['.docx', '.doc'],
+              "RapidOCRPPTLoader": ['.ppt', '.pptx', ],
               "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
+              "UnstructuredFileLoader": ['.eml', '.msg', '.rst',
+                                         '.rtf', '.txt', '.xml',
+                                         '.epub', '.odt','.tsv'],
               "UnstructuredEmailLoader": ['.eml', '.msg'],
               "UnstructuredEPubLoader": ['.epub'],
               "UnstructuredExcelLoader": ['.xlsx', '.xls', '.xlsd'],
@@ -109,7 +114,6 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
               "UnstructuredXMLLoader": ['.xml'],
               "UnstructuredPowerPointLoader": ['.ppt', '.pptx'],
               "EverNoteLoader": ['.enex'],
-              "UnstructuredFileLoader": ['.txt'],
               }
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]

@@ -141,15 +145,14 @@ def get_LoaderClass(file_extension):
        if file_extension in extensions:
            return LoaderClass


-# 把一些向量化共用逻辑从KnowledgeFile抽取出来,等langchain支持内存文件的时候,可以将非磁盘文件向量化
def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
    '''
    根据loader_name和文件路径或内容返回文档加载器。
    '''
    loader_kwargs = loader_kwargs or {}
    try:
-       if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader","FilteredCSVLoader"]:
+       if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader", "FilteredCSVLoader",
+                          "RapidOCRDocLoader", "RapidOCRPPTLoader"]:
            document_loaders_module = importlib.import_module('document_loaders')
        else:
            document_loaders_module = importlib.import_module('langchain.document_loaders')
@@ -171,7 +174,6 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
            if encode_detect is None:
                encode_detect = {"encoding": "utf-8"}
            loader_kwargs["encoding"] = encode_detect["encoding"]
-       ## TODO:支持更多的自定义CSV读取逻辑

        elif loader_name == "JSONLoader":
            loader_kwargs.setdefault("jq_schema", ".")
@@ -259,6 +261,10 @@ def make_text_splitter(
            text_splitter_module = importlib.import_module('langchain.text_splitter')
            TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
            text_splitter = TextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

+   # If you use SpacyTextSplitter you can use GPU to do split likes Issue #1287
+   # text_splitter._tokenizer.max_length = 37016792
+   # text_splitter._tokenizer.prefer_gpu()
    return text_splitter

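`LOADER_DICT` maps loader class names to the file extensions they handle, and `get_loader` imports the project's custom OCR loaders from its local `document_loaders` package while everything else comes from `langchain.document_loaders`. A small sketch of the same lookup idea; the trimmed dictionary and helper names below are assumptions for illustration, not the project's full table:

```python
import importlib
from typing import Optional

# Assumed subset of the extension table shown in the diff above.
LOADER_DICT = {
    "RapidOCRPDFLoader": [".pdf"],
    "RapidOCRDocLoader": [".docx", ".doc"],
    "RapidOCRPPTLoader": [".ppt", ".pptx"],
    "CSVLoader": [".csv"],
}
CUSTOM_LOADERS = {"RapidOCRPDFLoader", "RapidOCRDocLoader", "RapidOCRPPTLoader"}

def loader_name_for(ext: str) -> Optional[str]:
    # Reverse lookup: file extension -> loader class name.
    for name, exts in LOADER_DICT.items():
        if ext.lower() in exts:
            return name
    return None

def import_loader_class(name: str):
    # Custom OCR loaders live in the project's own document_loaders package;
    # everything else is resolved from langchain.document_loaders.
    module = "document_loaders" if name in CUSTOM_LOADERS else "langchain.document_loaders"
    return getattr(importlib.import_module(module), name)
```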
server/minx_chat_openai.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Tuple
+)
+import sys
+import logging
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    import tiktoken
+
+
+class MinxChatOpenAI:
+
+    @staticmethod
+    def import_tiktoken() -> Any:
+        try:
+            import tiktoken
+        except ImportError:
+            raise ValueError(
+                "Could not import tiktoken python package. "
+                "This is needed in order to calculate get_token_ids. "
+                "Please install it with `pip install tiktoken`."
+            )
+        return tiktoken
+
+    @staticmethod
+    def get_encoding_model(self) -> Tuple[str, "tiktoken.Encoding"]:
+        tiktoken_ = MinxChatOpenAI.import_tiktoken()
+        if self.tiktoken_model_name is not None:
+            model = self.tiktoken_model_name
+        else:
+            model = self.model_name
+            if model == "gpt-3.5-turbo":
+                # gpt-3.5-turbo may change over time.
+                # Returning num tokens assuming gpt-3.5-turbo-0301.
+                model = "gpt-3.5-turbo-0301"
+            elif model == "gpt-4":
+                # gpt-4 may change over time.
+                # Returning num tokens assuming gpt-4-0314.
+                model = "gpt-4-0314"
+        # Returns the number of tokens used by a list of messages.
+        try:
+            encoding = tiktoken_.encoding_for_model(model)
+        except Exception as e:
+            logger.warning("Warning: model not found. Using cl100k_base encoding.")
+            model = "cl100k_base"
+            encoding = tiktoken_.get_encoding(model)
+        return model, encoding
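`MinxChatOpenAI.get_encoding_model` is patched onto `ChatOpenAI` later in this commit so that token counting falls back to the `cl100k_base` encoding when tiktoken does not recognize the model name (typical for self-hosted or OpenAI-compatible endpoints). A standalone sketch of that fallback; the model names here are only examples:

```python
import tiktoken

def get_encoding(model_name: str) -> "tiktoken.Encoding":
    # Ask tiktoken for the model-specific encoding; unknown model names
    # fall back to the general cl100k_base tokenizer.
    try:
        return tiktoken.encoding_for_model(model_name)
    except KeyError:
        return tiktoken.get_encoding("cl100k_base")

def count_tokens(text: str, model_name: str = "gpt-3.5-turbo") -> int:
    # Rough token count used to check text against a model's context length.
    return len(get_encoding(model_name).encode(text))
```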
@@ -8,3 +8,4 @@ from .qwen import QwenWorker
from .baichuan import BaiChuanWorker
from .azure import AzureWorker
from .tiangong import TianGongWorker
+from .gemini import GeminiWorker
@@ -67,12 +67,10 @@ class AzureWorker(ApiModelWorker):
            self.logger.error(f"请求 Azure API 时发生错误:{resp}")

    def get_embeddings(self, params):
-       # TODO: 支持embeddings
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
-       # TODO: 确认模板是否需要修改
        return conv.Conversation(
            name=self.model_names[0],
            system_message="You are a helpful, respectful and honest assistant.",
@@ -88,12 +88,10 @@ class BaiChuanWorker(ApiModelWorker):
                yield data

    def get_embeddings(self, params):
-       # TODO: 支持embeddings
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
-       # TODO: 确认模板是否需要修改
        return conv.Conversation(
            name=self.model_names[0],
            system_message="",
@@ -125,8 +125,6 @@ class ApiModelWorker(BaseModelWorker):


    def count_token(self, params):
-       # TODO:需要完善
-       # print("count token")
        prompt = params["prompt"]
        return {"count": len(str(prompt)), "error_code": 0}

@@ -12,16 +12,16 @@ class FangZhouWorker(ApiModelWorker):
    """

    def __init__(
            self,
            *,
            model_names: List[str] = ["fangzhou-api"],
            controller_addr: str = None,
            worker_addr: str = None,
            version: Literal["chatglm-6b-model"] = "chatglm-6b-model",
            **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
-       kwargs.setdefault("context_len", 16384)  # TODO: 不同的模型有不同的大小
+       kwargs.setdefault("context_len", 16384)
        super().__init__(**kwargs)
        self.version = version

@@ -53,15 +53,15 @@ class FangZhouWorker(ApiModelWorker):
                if error := resp.error:
                    if error.code_n > 0:
                        data = {
                            "error_code": error.code_n,
                            "text": error.message,
                            "error": {
                                "message": error.message,
                                "type": "invalid_request_error",
                                "param": None,
                                "code": None,
-                           }
                            }
+                       }
                        self.logger.error(f"请求方舟 API 时发生错误:{data}")
                        yield data
                elif chunk := resp.choice.message.content:
@@ -77,7 +77,6 @@ class FangZhouWorker(ApiModelWorker):
                    break

    def get_embeddings(self, params):
-       # TODO: 支持embeddings
        print("embedding")
        print(params)

server/model_workers/gemini.py (new file, 123 lines)
@@ -0,0 +1,123 @@
+import sys
+from fastchat.conversation import Conversation
+from server.model_workers.base import *
+from server.utils import get_httpx_client
+from fastchat import conversation as conv
+import json, httpx
+from typing import List, Dict
+from configs import logger, log_verbose
+
+
+class GeminiWorker(ApiModelWorker):
+    def __init__(
+            self,
+            *,
+            controller_addr: str = None,
+            worker_addr: str = None,
+            model_names: List[str] = ["gemini-api"],
+            **kwargs,
+    ):
+        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
+        kwargs.setdefault("context_len", 4096)
+        super().__init__(**kwargs)
+
+    def create_gemini_messages(self, messages) -> json:
+        has_history = any(msg['role'] == 'assistant' for msg in messages)
+        gemini_msg = []
+
+        for msg in messages:
+            role = msg['role']
+            content = msg['content']
+            if role == 'system':
+                continue
+            if has_history:
+                if role == 'assistant':
+                    role = "model"
+                transformed_msg = {"role": role, "parts": [{"text": content}]}
+            else:
+                if role == 'user':
+                    transformed_msg = {"parts": [{"text": content}]}
+
+            gemini_msg.append(transformed_msg)
+
+        msg = dict(contents=gemini_msg)
+        return msg
+
+    def do_chat(self, params: ApiChatParams) -> Dict:
+        params.load_config(self.model_names[0])
+        data = self.create_gemini_messages(messages=params.messages)
+        generationConfig = dict(
+            temperature=params.temperature,
+            topK=1,
+            topP=1,
+            maxOutputTokens=4096,
+            stopSequences=[]
+        )
+
+        data['generationConfig'] = generationConfig
+        url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent" + '?key=' + params.api_key
+        headers = {
+            'Content-Type': 'application/json',
+        }
+        if log_verbose:
+            logger.info(f'{self.__class__.__name__}:url: {url}')
+            logger.info(f'{self.__class__.__name__}:headers: {headers}')
+            logger.info(f'{self.__class__.__name__}:data: {data}')
+
+        text = ""
+        json_string = ""
+        timeout = httpx.Timeout(60.0)
+        client = get_httpx_client(timeout=timeout)
+        with client.stream("POST", url, headers=headers, json=data) as response:
+            for line in response.iter_lines():
+                line = line.strip()
+                if not line or "[DONE]" in line:
+                    continue
+
+                json_string += line
+
+        try:
+            resp = json.loads(json_string)
+            if 'candidates' in resp:
+                for candidate in resp['candidates']:
+                    content = candidate.get('content', {})
+                    parts = content.get('parts', [])
+                    for part in parts:
+                        if 'text' in part:
+                            text += part['text']
+                    yield {
+                        "error_code": 0,
+                        "text": text
+                    }
+                    print(text)
+        except json.JSONDecodeError as e:
+            print("Failed to decode JSON:", e)
+            print("Invalid JSON string:", json_string)
+
+    def get_embeddings(self, params):
+        print("embedding")
+        print(params)
+
+    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
+        return conv.Conversation(
+            name=self.model_names[0],
+            system_message="You are a helpful, respectful and honest assistant.",
+            messages=[],
+            roles=["user", "assistant"],
+            sep="\n### ",
+            stop_str="###",
+        )
+
+
+if __name__ == "__main__":
+    import uvicorn
+    from server.utils import MakeFastAPIOffline
+    from fastchat.serve.base_model_worker import app
+
+    worker = GeminiWorker(
+        controller_addr="http://127.0.0.1:20001",
+        worker_addr="http://127.0.0.1:21012",
+    )
+    sys.modules["fastchat.serve.model_worker"].worker = worker
+    MakeFastAPIOffline(app)
+    uvicorn.run(app, port=21012)
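The new `GeminiWorker` converts OpenAI-style chat messages into the `contents` structure expected by the Gemini `generateContent` REST endpoint, mapping the assistant role to `"model"` and dropping system messages. A rough standalone sketch of that conversion, simplified compared with the worker above:

```python
from typing import Dict, List

def to_gemini_contents(messages: List[Dict[str, str]]) -> Dict:
    # OpenAI-style [{"role": ..., "content": ...}] -> Gemini {"contents": [...]}
    contents = []
    for msg in messages:
        role = msg["role"]
        if role == "system":      # the v1beta chat payload has no system role; drop it
            continue
        if role == "assistant":   # Gemini calls the model side "model"
            role = "model"
        contents.append({"role": role, "parts": [{"text": msg["content"]}]})
    return {"contents": contents}

# Example:
# to_gemini_contents([{"role": "user", "content": "你好"}])
# -> {"contents": [{"role": "user", "parts": [{"text": "你好"}]}]}
```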
@@ -28,7 +28,7 @@ class MiniMaxWorker(ApiModelWorker):

    def validate_messages(self, messages: List[Dict]) -> List[Dict]:
        role_maps = {
-           "user": self.user_role,
+           "USER": self.user_role,
            "assistant": self.ai_role,
            "system": "system",
        }
@@ -37,7 +37,6 @@ class MiniMaxWorker(ApiModelWorker):

    def do_chat(self, params: ApiChatParams) -> Dict:
        # 按照官网推荐,直接调用abab 5.5模型
-       # TODO: 支持指定回复要求,支持指定用户名称、AI名称
        params.load_config(self.model_names[0])

        url = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}'
@@ -55,7 +54,7 @@ class MiniMaxWorker(ApiModelWorker):
            "temperature": params.temperature,
            "top_p": params.top_p,
            "tokens_to_generate": params.max_tokens or 1024,
-           # TODO: 以下参数为minimax特有,传入空值会出错。
+           # 以下参数为minimax特有,传入空值会出错。
            # "prompt": params.system_message or self.conv.system_message,
            # "bot_setting": [],
            # "role_meta": params.role_meta,
@@ -140,15 +139,13 @@ class MiniMaxWorker(ApiModelWorker):
                self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
                return data
            i += batch_size
-       return {"code": 200, "data": embeddings}
+       return {"code": 200, "data": result}

    def get_embeddings(self, params):
-       # TODO: 支持embeddings
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
-       # TODO: 确认模板是否需要修改
        return conv.Conversation(
            name=self.model_names[0],
            system_message="你是MiniMax自主研发的大型语言模型,回答问题简洁有条理。",
@@ -84,30 +84,6 @@ class QianFanWorker(ApiModelWorker):

    def do_chat(self, params: ApiChatParams) -> Dict:
        params.load_config(self.model_names[0])
-       # import qianfan
-
-       # comp = qianfan.ChatCompletion(model=params.version,
-       #                               endpoint=params.version_url,
-       #                               ak=params.api_key,
-       #                               sk=params.secret_key,)
-       # text = ""
-       # for resp in comp.do(messages=params.messages,
-       #                     temperature=params.temperature,
-       #                     top_p=params.top_p,
-       #                     stream=True):
-       #     if resp.code == 200:
-       #         if chunk := resp.body.get("result"):
-       #             text += chunk
-       #             yield {
-       #                 "error_code": 0,
-       #                 "text": text
-       #             }
-       #     else:
-       #         yield {
-       #             "error_code": resp.code,
-       #             "text": str(resp.body),
-       #         }

        BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
                   '/{model_version}?access_token={access_token}'

@@ -190,19 +166,19 @@ class QianFanWorker(ApiModelWorker):
        i = 0
        batch_size = 10
        while i < len(params.texts):
-           texts = params.texts[i:i+batch_size]
+           texts = params.texts[i:i + batch_size]
            resp = client.post(url, json={"input": texts}).json()
            if "error_code" in resp:
                data = {
                    "code": resp["error_code"],
                    "msg": resp["error_msg"],
                    "error": {
                        "message": resp["error_msg"],
                        "type": "invalid_request_error",
                        "param": None,
                        "code": None,
                    }
                }
                self.logger.error(f"请求千帆 API 时发生错误:{data}")
                return data
            else:
@@ -211,14 +187,11 @@ class QianFanWorker(ApiModelWorker):
            i += batch_size
        return {"code": 200, "data": result}

-   # TODO: qianfan支持续写模型
    def get_embeddings(self, params):
-       # TODO: 支持embeddings
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
-       # TODO: 确认模板是否需要修改
        return conv.Conversation(
            name=self.model_names[0],
            system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
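The Qianfan embedding path above posts texts to the API in batches of 10 and accumulates the returned vectors. A generic sketch of that batching loop; `embed_batch` is a placeholder for the actual HTTP call, not a function from this repository:

```python
from typing import Callable, List

def embed_in_batches(texts: List[str],
                     embed_batch: Callable[[List[str]], List[List[float]]],
                     batch_size: int = 10) -> List[List[float]]:
    # Send at most batch_size texts per request and concatenate the vectors,
    # mirroring the while-loop in the diff above.
    result: List[List[float]] = []
    i = 0
    while i < len(texts):
        result.extend(embed_batch(texts[i:i + batch_size]))
        i += batch_size
    return result
```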
@@ -100,12 +100,10 @@ class QwenWorker(ApiModelWorker):
        return {"code": 200, "data": result}

    def get_embeddings(self, params):
-       # TODO: 支持embeddings
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
-       # TODO: 确认模板是否需要修改
        return conv.Conversation(
            name=self.model_names[0],
            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
@@ -11,16 +11,15 @@ from typing import List, Literal, Dict
import requests


class TianGongWorker(ApiModelWorker):
    def __init__(
            self,
            *,
            controller_addr: str = None,
            worker_addr: str = None,
            model_names: List[str] = ["tiangong-api"],
            version: Literal["SkyChat-MegaVerse"] = "SkyChat-MegaVerse",
            **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 32768)
@@ -34,18 +33,18 @@ class TianGongWorker(ApiModelWorker):
        data = {
            "messages": params.messages,
            "model": "SkyChat-MegaVerse"
        }
        timestamp = str(int(time.time()))
        sign_content = params.api_key + params.secret_key + timestamp
        sign_result = hashlib.md5(sign_content.encode('utf-8')).hexdigest()
-       headers={
+       headers = {
            "app_key": params.api_key,
            "timestamp": timestamp,
            "sign": sign_result,
            "Content-Type": "application/json",
            "stream": "true"  # or change to "false" 不处理流式返回内容
        }

        # 发起请求并获取响应
        response = requests.post(url, headers=headers, json=data, stream=True)

@@ -56,27 +55,25 @@ class TianGongWorker(ApiModelWorker):
                # 处理接收到的数据
                # print(line.decode('utf-8'))
                resp = json.loads(line)
                if resp["code"] == 200:
                    text += resp['resp_data']['reply']
                    yield {
                        "error_code": 0,
                        "text": text
                    }
                else:
                    data = {
                        "error_code": resp["code"],
                        "text": resp["code_msg"]
                    }
                    self.logger.error(f"请求天工 API 时出错:{data}")
                    yield data

    def get_embeddings(self, params):
-       # TODO: 支持embeddings
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
-       # TODO: 确认模板是否需要修改
        return conv.Conversation(
            name=self.model_names[0],
            system_message="",
@@ -85,5 +82,3 @@ class TianGongWorker(ApiModelWorker):
            sep="\n### ",
            stop_str="###",
        )

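`TianGongWorker` signs every request with an MD5 digest of `api_key + secret_key + timestamp`, sent alongside the raw timestamp. A minimal sketch of building those headers; the key values passed in are placeholders:

```python
import hashlib
import time

def tiangong_headers(api_key: str, secret_key: str) -> dict:
    # sign = md5(api_key + secret_key + unix_timestamp); the server recomputes
    # the digest from the same fields to verify the caller.
    timestamp = str(int(time.time()))
    sign = hashlib.md5((api_key + secret_key + timestamp).encode("utf-8")).hexdigest()
    return {
        "app_key": api_key,
        "timestamp": timestamp,
        "sign": sign,
        "Content-Type": "application/json",
        "stream": "true",
    }
```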
@@ -37,12 +37,11 @@ class XingHuoWorker(ApiModelWorker):
            **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
-       kwargs.setdefault("context_len", 8000)  # TODO: V1模型的最大长度为4000,需要自行修改
+       kwargs.setdefault("context_len", 8000)
        super().__init__(**kwargs)
        self.version = version

    def do_chat(self, params: ApiChatParams) -> Dict:
-       # TODO: 当前每次对话都要重新连接websocket,确认是否可以保持连接
        params.load_config(self.model_names[0])

        version_mapping = {
@@ -73,12 +72,10 @@ class XingHuoWorker(ApiModelWorker):
            yield {"error_code": 0, "text": text}

    def get_embeddings(self, params):
-       # TODO: 支持embeddings
        print("embedding")
        print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
-       # TODO: 确认模板是否需要修改
        return conv.Conversation(
            name=self.model_names[0],
            system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
@@ -4,93 +4,89 @@ from fastchat import conversation as conv
import sys
from typing import List, Dict, Iterator, Literal
from configs import logger, log_verbose
+import requests
+import jwt
+import time
+import json
+
+
+def generate_token(apikey: str, exp_seconds: int):
+    try:
+        id, secret = apikey.split(".")
+    except Exception as e:
+        raise Exception("invalid apikey", e)
+
+    payload = {
+        "api_key": id,
+        "exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
+        "timestamp": int(round(time.time() * 1000)),
+    }
+
+    return jwt.encode(
+        payload,
+        secret,
+        algorithm="HS256",
+        headers={"alg": "HS256", "sign_type": "SIGN"},
+    )


class ChatGLMWorker(ApiModelWorker):
-   DEFAULT_EMBED_MODEL = "text_embedding"

    def __init__(
            self,
            *,
            model_names: List[str] = ["zhipu-api"],
            controller_addr: str = None,
            worker_addr: str = None,
            version: Literal["chatglm_turbo"] = "chatglm_turbo",
            **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
-       kwargs.setdefault("context_len", 32768)
+       kwargs.setdefault("context_len", 4096)
        super().__init__(**kwargs)
        self.version = version

    def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
-       # TODO: 维护request_id
-       import zhipuai
-
        params.load_config(self.model_names[0])
-       zhipuai.api_key = params.api_key
-
-       if log_verbose:
-           logger.info(f'{self.__class__.__name__}:params: {params}')
-
-       response = zhipuai.model_api.sse_invoke(
-           model=params.version,
-           prompt=params.messages,
-           temperature=params.temperature,
-           top_p=params.top_p,
-           incremental=False,
-       )
-       for e in response.events():
-           if e.event == "add":
-               yield {"error_code": 0, "text": e.data}
-           elif e.event in ["error", "interrupted"]:
-               data = {
-                   "error_code": 500,
-                   "text": e.data,
-                   "error": {
-                       "message": e.data,
-                       "type": "invalid_request_error",
-                       "param": None,
-                       "code": None,
-                   }
-               }
-               self.logger.error(f"请求智谱 API 时发生错误:{data}")
-               yield data
-
-   def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
-       import zhipuai
-
-       params.load_config(self.model_names[0])
-       zhipuai.api_key = params.api_key
-
-       embeddings = []
-       try:
-           for t in params.texts:
-               response = zhipuai.model_api.invoke(model=params.embed_model or self.DEFAULT_EMBED_MODEL, prompt=t)
-               if response["code"] == 200:
-                   embeddings.append(response["data"]["embedding"])
-               else:
-                   self.logger.error(f"请求智谱 API 时发生错误:{response}")
-                   return response  # dict with code & msg
-       except Exception as e:
-           self.logger.error(f"请求智谱 API 时发生错误:{data}")
-           data = {"code": 500, "msg": f"对文本向量化时出错:{e}"}
-           return data
-
-       return {"code": 200, "data": embeddings}
+       token = generate_token(params.api_key, 60)
+       headers = {
+           "Content-Type": "application/json",
+           "Authorization": f"Bearer {token}"
+       }
+       data = {
+           "model": params.version,
+           "messages": params.messages,
+           "max_tokens": params.max_tokens,
+           "temperature": params.temperature,
+           "stream": False
+       }
+       url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
+       response = requests.post(url, headers=headers, json=data)
+       # for chunk in response.iter_lines():
+       #     if chunk:
+       #         chunk_str = chunk.decode('utf-8')
+       #         json_start_pos = chunk_str.find('{"id"')
+       #         if json_start_pos != -1:
+       #             json_str = chunk_str[json_start_pos:]
+       #             json_data = json.loads(json_str)
+       #             for choice in json_data.get('choices', []):
+       #                 delta = choice.get('delta', {})
+       #                 content = delta.get('content', '')
+       #                 yield {"error_code": 0, "text": content}
+       ans = response.json()
+       content = ans["choices"][0]["message"]["content"]
+       yield {"error_code": 0, "text": content}

    def get_embeddings(self, params):
-       # TODO: 支持embeddings
+       # 临时解决方案,不支持embedding
        print("embedding")
-       # print(params)
+       print(params)

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
-       # 这里的是chatglm api的模板,其它API的conv_template需要定制
        return conv.Conversation(
            name=self.model_names[0],
-           system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
+           system_message="你是智谱AI小助手,请根据用户的提示来完成任务",
            messages=[],
-           roles=["Human", "Assistant", "System"],
+           roles=["user", "assistant", "system"],
            sep="\n###",
            stop_str="###",
        )
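The reworked `ChatGLMWorker` drops the `zhipuai` SDK and calls the open.bigmodel.cn HTTP endpoint directly, authenticating with a short-lived HS256 JWT derived from the `id.secret` API key. A standalone sketch of just the token step; it requires PyJWT, and the key string in the usage comment is a placeholder:

```python
import time
import jwt  # PyJWT

def zhipu_token(apikey: str, exp_seconds: int = 60) -> str:
    # The API key has the form "<id>.<secret>"; the id goes into the payload
    # and the secret signs the token.
    key_id, secret = apikey.split(".")
    now_ms = int(round(time.time() * 1000))
    payload = {"api_key": key_id,
               "exp": now_ms + exp_seconds * 1000,
               "timestamp": now_ms}
    return jwt.encode(payload, secret,
                      algorithm="HS256",
                      headers={"alg": "HS256", "sign_type": "SIGN"})

# Usage (placeholder key):
# headers = {"Content-Type": "application/json",
#            "Authorization": f"Bearer {zhipu_token('id.secret')}"}
```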
File diff suppressed because one or more lines are too long
@@ -12,10 +12,23 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
import httpx
-from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Literal,
+    Optional,
+    Callable,
+    Generator,
+    Dict,
+    Any,
+    Awaitable,
+    Union,
+    Tuple
+)
import logging
import torch

+from server.minx_chat_openai import MinxChatOpenAI


async def wrap_done(fn: Awaitable, event: asyncio.Event):
    """Wrap an awaitable with a event to signal when it's done or an exception is raised."""
@@ -23,7 +36,6 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
        await fn
    except Exception as e:
        logging.exception(e)
-       # TODO: handle exception
        msg = f"Caught exception: {e}"
        logger.error(f'{e.__class__.__name__}: {msg}',
                     exc_info=e if log_verbose else None)
@@ -44,7 +56,7 @@ def get_ChatOpenAI(
    config = get_model_worker_config(model_name)
    if model_name == "openai-api":
        model_name = config.get("model_name")
+   ChatOpenAI._get_encoding_model = MinxChatOpenAI.get_encoding_model
    model = ChatOpenAI(
        streaming=streaming,
        verbose=verbose,
@@ -153,6 +165,7 @@ class ChatMessage(BaseModel):

def torch_gc():
    try:
+       import torch
        if torch.cuda.is_available():
            # with torch.cuda.device(DEVICE):
            torch.cuda.empty_cache()
@@ -390,7 +403,7 @@ def fschat_controller_address() -> str:


def fschat_model_worker_address(model_name: str = LLM_MODELS[0]) -> str:
-   if model := get_model_worker_config(model_name):  # TODO: depends fastchat
+   if model := get_model_worker_config(model_name):
        host = model["host"]
        if host == "0.0.0.0":
            host = "127.0.0.1"
@@ -435,7 +448,7 @@ def get_prompt_template(type: str, name: str) -> Optional[str]:

    from configs import prompt_config
    import importlib
-   importlib.reload(prompt_config)  # TODO: 检查configs/prompt_config.py文件有修改再重新加载
+   importlib.reload(prompt_config)
    return prompt_config.PROMPT_TEMPLATES[type].get(name)


@@ -489,69 +502,36 @@ def set_httpx_config(
        no_proxy.append(host)
    os.environ["NO_PROXY"] = ",".join(no_proxy)

-   # TODO: 简单的清除系统代理不是个好的选择,影响太多。似乎修改代理服务器的bypass列表更好。
-   # patch requests to use custom proxies instead of system settings
    def _get_proxies():
        return proxies

    import urllib.request
    urllib.request.getproxies = _get_proxies

-# 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch
-
-def is_mps_available():
-    return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
-
-def is_cuda_available():
-    return torch.cuda.is_available()
-
def detect_device() -> Literal["cuda", "mps", "cpu"]:
    try:
+       import torch
        if torch.cuda.is_available():
            return "cuda"
-       if is_mps_available():
+       if torch.backends.mps.is_available():
            return "mps"
    except:
        pass
    return "cpu"


-def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
+def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
    device = device or LLM_DEVICE
-   if device not in ["cuda", "mps", "cpu", "xpu"]:
-       logging.warning(f"device not in ['cuda', 'mps', 'cpu','xpu'], device = {device}")
+   if device not in ["cuda", "mps", "cpu"]:
        device = detect_device()
-   elif device == 'cuda' and not is_cuda_available() and is_mps_available():
-       logging.warning("cuda is not available, fallback to mps")
-       return "mps"
-   if device == 'mps' and not is_mps_available() and is_cuda_available():
-       logging.warning("mps is not available, fallback to cuda")
-       return "cuda"
-
-   # auto detect device if not specified
-   if device not in ["cuda", "mps", "cpu", "xpu"]:
-       return detect_device()
    return device


-def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu", "xpu"]:
-   device = device or LLM_DEVICE
+def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
+   device = device or EMBEDDING_DEVICE
    if device not in ["cuda", "mps", "cpu"]:
-       logging.warning(f"device not in ['cuda', 'mps', 'cpu','xpu'], device = {device}")
        device = detect_device()
-   elif device == 'cuda' and not is_cuda_available() and is_mps_available():
-       logging.warning("cuda is not available, fallback to mps")
-       return "mps"
-   if device == 'mps' and not is_mps_available() and is_cuda_available():
-       logging.warning("mps is not available, fallback to cuda")
-       return "cuda"
-
-   # auto detect device if not specified
-   if device not in ["cuda", "mps", "cpu"]:
-       return detect_device()
    return device


@@ -569,7 +549,7 @@ def run_in_thread_pool(
        thread = pool.submit(func, **kwargs)
        tasks.append(thread)

-   for obj in as_completed(tasks):  # TODO: Ctrl+c无法停止
+   for obj in as_completed(tasks):
        yield obj.result()

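The device helpers were simplified in this hunk: torch is now imported lazily inside the helpers so machines that only run the web UI or online-API workers do not need torch installed, and any invalid device name simply falls back to auto-detection. A compact sketch of the same fallback logic, assuming that simplified behaviour:

```python
from typing import Literal, Optional

def detect_device() -> Literal["cuda", "mps", "cpu"]:
    # Import torch lazily so this also works on machines without torch.
    try:
        import torch
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
    except Exception:
        pass
    return "cpu"

def resolve_device(device: Optional[str] = None) -> Literal["cuda", "mps", "cpu"]:
    # Anything that is not an explicit, valid choice falls back to auto-detection.
    return device if device in ("cuda", "mps", "cpu") else detect_device()
```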
109
startup.py
109
startup.py
@ -6,9 +6,8 @@ import sys
|
|||||||
from multiprocessing import Process
|
from multiprocessing import Process
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
|
from langchain_core._api import deprecated
|
||||||
|
|
||||||
|
|
||||||
# 设置numexpr最大线程数,默认为CPU核心数
|
|
||||||
try:
|
try:
|
||||||
import numexpr
|
import numexpr
|
||||||
|
|
||||||
@ -33,15 +32,18 @@ from configs import (
|
|||||||
HTTPX_DEFAULT_TIMEOUT,
|
HTTPX_DEFAULT_TIMEOUT,
|
||||||
)
|
)
|
||||||
from server.utils import (fschat_controller_address, fschat_model_worker_address,
|
from server.utils import (fschat_controller_address, fschat_model_worker_address,
|
||||||
fschat_openai_api_address, set_httpx_config, get_httpx_client,
|
fschat_openai_api_address, get_httpx_client, get_model_worker_config,
|
||||||
get_model_worker_config, get_all_model_worker_configs,
|
|
||||||
MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
|
MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
|
||||||
from server.knowledge_base.migrate import create_tables
|
from server.knowledge_base.migrate import create_tables
|
||||||
import argparse
|
import argparse
|
||||||
from typing import Tuple, List, Dict
|
from typing import List, Dict
|
||||||
from configs import VERSION
|
from configs import VERSION
|
||||||
|
|
||||||
|
|
||||||
|
@deprecated(
|
||||||
|
since="0.3.0",
|
||||||
|
message="模型启动功能将于 Langchain-Chatchat 0.3.x重写,支持更多模式和加速启动,0.2.x中相关功能将废弃",
|
||||||
|
removal="0.3.0")
|
||||||
def create_controller_app(
|
def create_controller_app(
|
||||||
dispatch_method: str,
|
dispatch_method: str,
|
||||||
log_level: str = "INFO",
|
log_level: str = "INFO",
|
||||||
@ -88,7 +90,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
|
|||||||
|
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
setattr(args, k, v)
|
setattr(args, k, v)
|
||||||
if worker_class := kwargs.get("langchain_model"): #Langchian支持的模型不用做操作
|
if worker_class := kwargs.get("langchain_model"): # Langchian支持的模型不用做操作
|
||||||
from fastchat.serve.base_model_worker import app
|
from fastchat.serve.base_model_worker import app
|
||||||
worker = ""
|
worker = ""
|
||||||
# 在线模型API
|
# 在线模型API
|
||||||
@ -107,12 +109,12 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
|
|||||||
import fastchat.serve.vllm_worker
|
import fastchat.serve.vllm_worker
|
||||||
from fastchat.serve.vllm_worker import VLLMWorker, app, worker_id
from vllm import AsyncLLMEngine
-from vllm.engine.arg_utils import AsyncEngineArgs,EngineArgs
+from vllm.engine.arg_utils import AsyncEngineArgs

-args.tokenizer = args.model_path  # add it here if the tokenizer differs from model_path
+args.tokenizer = args.model_path
args.tokenizer_mode = 'auto'
-args.trust_remote_code= True
+args.trust_remote_code = True
-args.download_dir= None
+args.download_dir = None
args.load_format = 'auto'
args.dtype = 'auto'
args.seed = 0
@@ -122,13 +124,13 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
args.block_size = 16
args.swap_space = 4  # GiB
args.gpu_memory_utilization = 0.90
args.max_num_batched_tokens = None  # maximum number of tokens per batch; depends on your GPU and model, too large a value will exhaust VRAM
args.max_num_seqs = 256
args.disable_log_stats = False
args.conv_template = None
args.limit_worker_concurrency = 5
args.no_register = False
args.num_gpus = 1  # vllm workers are sharded with tensor parallelism; set this to the number of GPUs
args.engine_use_ray = False
args.disable_log_requests = False

@@ -138,10 +140,10 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
args.quantization = None
args.max_log_len = None
args.tokenizer_revision = None

# new parameters required by vllm 0.2.2
args.max_paddings = 256

if args.model_path:
args.model = args.model_path
if args.num_gpus > 1:
@@ -154,16 +156,16 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
engine = AsyncLLMEngine.from_engine_args(engine_args)

worker = VLLMWorker(
-controller_addr = args.controller_address,
+controller_addr=args.controller_address,
-worker_addr = args.worker_address,
+worker_addr=args.worker_address,
-worker_id = worker_id,
+worker_id=worker_id,
-model_path = args.model_path,
+model_path=args.model_path,
-model_names = args.model_names,
+model_names=args.model_names,
-limit_worker_concurrency = args.limit_worker_concurrency,
+limit_worker_concurrency=args.limit_worker_concurrency,
-no_register = args.no_register,
+no_register=args.no_register,
-llm_engine = engine,
+llm_engine=engine,
-conv_template = args.conv_template,
+conv_template=args.conv_template,
)
sys.modules["fastchat.serve.vllm_worker"].engine = engine
sys.modules["fastchat.serve.vllm_worker"].worker = worker
sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level)
@@ -171,7 +173,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
else:
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id

args.gpus = "0"  # GPU indices; with multiple GPUs this can be "0,1,2,3"
args.max_gpu_memory = "22GiB"
args.num_gpus = 1  # model workers are sharded with model parallelism; set this to the number of GPUs
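The vLLM branch above copies the worker settings onto `args` and hands them to FastChat's `VLLMWorker` via `AsyncLLMEngine.from_engine_args(engine_args)`. A minimal, hedged sketch of the same flow outside `create_model_worker_app` (the model path is illustrative, not taken from the project's configs):

```python
# Sketch only: build a vLLM async engine from explicit engine args,
# mirroring the AsyncLLMEngine.from_engine_args(engine_args) call in the hunk above.
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs

engine_args = AsyncEngineArgs(
    model="/path/to/chatglm3-6b",      # hypothetical local model path
    tokenizer="/path/to/chatglm3-6b",
    tokenizer_mode="auto",
    trust_remote_code=True,
    dtype="auto",
    gpu_memory_utilization=0.90,       # same default the startup code sets
    max_num_seqs=256,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
```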
@@ -325,7 +327,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):

with get_httpx_client() as client:
r = client.post(worker_address + "/release",
json={"new_model_name": new_model_name, "keep_origin": keep_origin})
if r.status_code != 200:
msg = f"failed to release model: {model_name}"
logger.error(msg)
@@ -393,8 +395,8 @@ def run_model_worker(
# add interface to release and load model
@app.post("/release")
def release_model(
new_model_name: str = Body(None, description="释放后加载该模型"),
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
) -> Dict:
if keep_origin:
if new_model_name:
@@ -416,7 +418,7 @@ def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
set_httpx_config()

controller_addr = fschat_controller_address()
-app = create_openai_api_app(controller_addr, log_level=log_level)  # TODO: not support keys yet.
+app = create_openai_api_app(controller_addr, log_level=log_level)
_set_app_event(app, started_event)

host = FSCHAT_OPENAI_API["host"]
@@ -450,13 +452,13 @@ def run_webui(started_event: mp.Event = None, run_mode: str = None):
port = WEBUI_SERVER["port"]

cmd = ["streamlit", "run", "webui.py",
"--server.address", host,
"--server.port", str(port),
"--theme.base", "light",
"--theme.primaryColor", "#165dff",
"--theme.secondaryBackgroundColor", "#f5f5f5",
"--theme.textColor", "#000000",
]
if run_mode == "lite":
cmd += [
"--",
@@ -605,8 +607,10 @@ async def start_main_server():
Python 3.9 has `signal.strsignal(signalnum)` so this closure would not be needed.
Also, 3.8 includes `signal.valid_signals()` that can be used to create a mapping for the same purpose.
"""

def f(signal_received, frame):
raise KeyboardInterrupt(f"{signalname} received")

return f

# This will be inherited by the child process if it is forked (not spawned)
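The closure above is a signal-handler factory: it returns a handler that converts a received signal into a `KeyboardInterrupt`, so the async main loop can unwind and terminate its child processes cleanly. A minimal sketch of how such a factory is typically registered (the `handler_factory` name is illustrative and not the project's own wrapper):

```python
import signal

def handler_factory(signalname: str):
    """Return a handler that re-raises the named signal as KeyboardInterrupt."""
    def f(signal_received, frame):
        raise KeyboardInterrupt(f"{signalname} received")
    return f

# Register the handlers; on Python 3.9+ signal.strsignal() could supply the name.
signal.signal(signal.SIGINT, handler_factory("SIGINT"))
signal.signal(signal.SIGTERM, handler_factory("SIGTERM"))
```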
@@ -701,8 +705,8 @@ async def start_main_server():
for model_name in args.model_name:
config = get_model_worker_config(model_name)
if (config.get("online_api")
and config.get("worker_class")
and model_name in FSCHAT_MODEL_WORKERS):
e = manager.Event()
model_worker_started.append(e)
process = Process(
@@ -742,12 +746,12 @@ async def start_main_server():
else:
try:
# make sure tasks can exit cleanly after receiving SIGINT
-if p:= processes.get("controller"):
+if p := processes.get("controller"):
p.start()
p.name = f"{p.name} ({p.pid})"
controller_started.wait()  # wait for the controller to finish starting

-if p:= processes.get("openai_api"):
+if p := processes.get("openai_api"):
p.start()
p.name = f"{p.name} ({p.pid})"

@@ -763,24 +767,24 @@ async def start_main_server():
for e in model_worker_started:
e.wait()

-if p:= processes.get("api"):
+if p := processes.get("api"):
p.start()
p.name = f"{p.name} ({p.pid})"
api_started.wait()  # wait for api.py to finish starting

-if p:= processes.get("webui"):
+if p := processes.get("webui"):
p.start()
p.name = f"{p.name} ({p.pid})"
webui_started.wait()  # wait for webui.py to finish starting

dump_server_info(after_start=True, args=args)

while True:
cmd = queue.get()  # received a model-switch message
e = manager.Event()
if isinstance(cmd, list):
model_name, cmd, new_model_name = cmd
if cmd == "start":  # start the new model
logger.info(f"准备启动新模型进程:{new_model_name}")
process = Process(
target=run_model_worker,
@@ -831,7 +835,6 @@ async def start_main_server():
else:
logger.error(f"未找到模型进程:{model_name}")

# for process in processes.get("model_worker", {}).values():
#     process.join()
# for process in processes.get("online_api", {}).values():
@@ -866,10 +869,9 @@ async def start_main_server():
for p in processes.values():
logger.info("Process status: %s", p)

-if __name__ == "__main__":
-# make sure the database tables are created
-create_tables()

+if __name__ == "__main__":
+create_tables()
if sys.version_info < (3, 10):
loop = asyncio.get_event_loop()
else:
@@ -879,16 +881,15 @@ if __name__ == "__main__":
loop = asyncio.new_event_loop()

asyncio.set_event_loop(loop)
-# run the coroutine synchronously
-loop.run_until_complete(start_main_server())

+loop.run_until_complete(start_main_server())

# Example API calls after the services have started:
# import openai
# openai.api_key = "EMPTY" # Not support yet
# openai.api_base = "http://localhost:8888/v1"

-# model = "chatglm2-6b"
+# model = "chatglm3-6b"

# # create a chat completion
# completion = openai.ChatCompletion.create(
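The commented example above breaks off at `openai.ChatCompletion.create(`. Assuming the legacy pre-1.0 `openai` client that the comment clearly targets, a plausible completion of the snippet looks like this (the message content is illustrative):

```python
# Sketch of the truncated example, using the pre-1.0 openai client against the
# local FastChat OpenAI-compatible server started above.
import openai

openai.api_key = "EMPTY"  # the local server does not check keys yet
openai.api_base = "http://localhost:8888/v1"

model = "chatglm3-6b"

# create a chat completion
completion = openai.ChatCompletion.create(
    model=model,
    messages=[{"role": "user", "content": "你好"}],
)
print(completion.choices[0].message.content)
```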
BIN  tests/samples/ocr_test.docx  (new file; binary file not shown)
BIN  tests/samples/ocr_test.pptx  (new file; binary file not shown)
@@ -6,13 +6,12 @@ from datetime import datetime
import os
import re
import time
-from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
+from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES, LLM_MODELS,
DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
from server.knowledge_base.utils import LOADER_DICT
import uuid
from typing import List, Dict

chat_box = ChatBox(
assistant_avatar=os.path.join(
"img",
@@ -127,7 +126,6 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
chat_box.use_chat_name(conversation_name)
conversation_id = st.session_state["conversation_ids"][conversation_name]

-# TODO: bind the chat model to the conversation
def on_mode_change():
mode = st.session_state.dialogue_mode
text = f"已切换到 {mode} 模式。"
@@ -138,11 +136,11 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
st.toast(text)

dialogue_modes = ["LLM 对话",
"知识库问答",
"文件对话",
"搜索引擎问答",
"自定义Agent问答",
]
dialogue_mode = st.selectbox("请选择对话模式:",
dialogue_modes,
index=0,
@@ -166,12 +164,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
available_models = []
config_models = api.list_config_models()
if not is_lite:
-for k, v in config_models.get("local", {}).items():  # list models configured with a valid local path
+for k, v in config_models.get("local", {}).items():
if (v.get("model_path_exists")
and k not in running_models):
available_models.append(k)
-for k, v in config_models.get("online", {}).items():  # list models in ONLINE_MODELS that are accessed directly
+for k, v in config_models.get("online", {}).items():
-if not v.get("provider") and k not in running_models:
+if not v.get("provider") and k not in running_models and k in LLM_MODELS:
available_models.append(k)
llm_models = running_models + available_models
cur_llm_model = st.session_state.get("cur_llm_model", default_model)
@@ -250,14 +248,14 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
elif dialogue_mode == "文件对话":
with st.expander("文件对话配置", True):
files = st.file_uploader("上传知识文件:",
[i for ls in LOADER_DICT.values() for i in ls],
accept_multiple_files=True,
)
kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)

## BGE models can score above 1
score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
-if st.button("开始上传", disabled=len(files)==0):
+if st.button("开始上传", disabled=len(files) == 0):
st.session_state["file_chat_id"] = upload_temp_docs(files, api)
elif dialogue_mode == "搜索引擎问答":
search_engine_list = api.list_search_engines()
@@ -279,9 +277,9 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "

def on_feedback(
feedback,
message_id: str = "",
history_index: int = -1,
):
reason = feedback["text"]
score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index)
@@ -296,7 +294,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
}

if prompt := st.chat_input(chat_input_placeholder, key="prompt"):
if parse_command(text=prompt, modal=modal):  # the user entered a custom command
st.rerun()
else:
history = get_messages_history(history_len)
@@ -306,11 +304,11 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
text = ""
message_id = ""
r = api.chat_chat(prompt,
history=history,
conversation_id=conversation_id,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature)
for t in r:
if error_msg := check_error_msg(t):  # check whether error occured
st.error(error_msg)
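The loop above streams the answer from `ApiRequest.chat_chat`, accumulating chunks into `text` and updating the chat box as they arrive. A hedged, stand-alone sketch of the same client call outside Streamlit (the API address and the `"text"` payload key are assumptions; the default `api.py` address may differ in a given deployment):

```python
# Sketch: stream an LLM answer from api.py through the ApiRequest wrapper.
from webui_pages.utils import ApiRequest, check_error_msg

api = ApiRequest(base_url="http://127.0.0.1:7861")  # assumed default api.py address

answer = ""
for chunk in api.chat_chat("你好,介绍一下这个项目", history=[]):
    if msg := check_error_msg(chunk):   # the generator yields an error dict on failure
        print("error:", msg)
        break
    answer += chunk.get("text", "")     # "text" key is an assumption about the payload
print(answer)
```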
@@ -321,12 +319,12 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):

metadata = {
"message_id": message_id,
}
chat_box.update_msg(text, streaming=False, metadata=metadata)  # write the final string and remove the cursor
chat_box.show_feedback(**feedback_kwargs,
key=message_id,
on_submit=on_feedback,
kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})

elif dialogue_mode == "自定义Agent问答":
if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
@@ -373,13 +371,13 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
])
text = ""
for d in api.knowledge_base_chat(prompt,
knowledge_base_name=selected_kb,
top_k=kb_top_k,
score_threshold=score_threshold,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature):
if error_msg := check_error_msg(d):  # check whether error occured
st.error(error_msg)
elif chunk := d.get("answer"):
@@ -397,13 +395,13 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
])
text = ""
for d in api.file_chat(prompt,
knowledge_id=st.session_state["file_chat_id"],
top_k=kb_top_k,
score_threshold=score_threshold,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature):
if error_msg := check_error_msg(d):  # check whether error occured
st.error(error_msg)
elif chunk := d.get("answer"):
@@ -455,4 +453,4 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
file_name=f"{now:%Y-%m-%d %H.%M}_对话记录.md",
mime="text/markdown",
use_container_width=True,
)
@@ -7,15 +7,12 @@ from server.knowledge_base.utils import get_file_path, LOADER_DICT
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
from typing import Literal, Dict, Tuple
from configs import (kbs_config,
EMBEDDING_MODEL, DEFAULT_VS_TYPE,
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
from server.utils import list_embed_models, list_online_embed_models
import os
import time

-# SENTENCE_SIZE = 100

cell_renderer = JsCode("""function(params) {if(params.value==true){return '✓'}else{return '×'}}""")

@@ -32,7 +29,7 @@ def config_aggrid(
gb.configure_selection(
selection_mode=selection_mode,
use_checkbox=use_checkbox,
-# pre_selected_rows=st.session_state.get("selected_rows", [0]),
+pre_selected_rows=st.session_state.get("selected_rows", [0]),
)
gb.configure_pagination(
enabled=True,
@@ -59,7 +56,8 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
try:
kb_list = {x["kb_name"]: x for x in get_kb_details()}
except Exception as e:
-st.error("获取知识库信息错误,请检查是否已按照 `README.md` 中 `4 知识库初始化与迁移` 步骤完成初始化或迁移,或是否为数据库连接错误。")
+st.error(
+"获取知识库信息错误,请检查是否已按照 `README.md` 中 `4 知识库初始化与迁移` 步骤完成初始化或迁移,或是否为数据库连接错误。")
st.stop()
kb_names = list(kb_list.keys())

@@ -150,7 +148,8 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
[i for ls in LOADER_DICT.values() for i in ls],
accept_multiple_files=True,
)
-kb_info = st.text_area("请输入知识库介绍:", value=st.session_state["selected_kb_info"], max_chars=None, key=None,
+kb_info = st.text_area("请输入知识库介绍:", value=st.session_state["selected_kb_info"], max_chars=None,
+key=None,
help=None, on_change=None, args=None, kwargs=None)

if kb_info != st.session_state["selected_kb_info"]:
@@ -200,8 +199,8 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
doc_details = doc_details[[
"No", "file_name", "document_loader", "text_splitter", "docs_count", "in_folder", "in_db",
]]
-# doc_details["in_folder"] = doc_details["in_folder"].replace(True, "✓").replace(False, "×")
+doc_details["in_folder"] = doc_details["in_folder"].replace(True, "✓").replace(False, "×")
-# doc_details["in_db"] = doc_details["in_db"].replace(True, "✓").replace(False, "×")
+doc_details["in_db"] = doc_details["in_db"].replace(True, "✓").replace(False, "×")
gb = config_aggrid(
doc_details,
{
@@ -252,7 +251,8 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
st.write()
# split the files into chunks and load them into the vector store
if cols[1].button(
-"重新添加至向量库" if selected_rows and (pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库",
+"重新添加至向量库" if selected_rows and (
+pd.DataFrame(selected_rows)["in_db"]).any() else "添加至向量库",
disabled=not file_exists(kb, selected_rows)[0],
use_container_width=True,
):
@@ -285,39 +285,39 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):

st.divider()

-# cols = st.columns(3)
+cols = st.columns(3)

-# if cols[0].button(
+if cols[0].button(
-# "依据源文件重建向量库",
+"依据源文件重建向量库",
-# # help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。",
+help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。",
-# use_container_width=True,
+use_container_width=True,
-# type="primary",
+type="primary",
-# ):
+):
-# with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
+with st.spinner("向量库重构中,请耐心等待,勿刷新或关闭页面。"):
-# empty = st.empty()
+empty = st.empty()
-# empty.progress(0.0, "")
+empty.progress(0.0, "")
-# for d in api.recreate_vector_store(kb,
+for d in api.recreate_vector_store(kb,
-# chunk_size=chunk_size,
+chunk_size=chunk_size,
-# chunk_overlap=chunk_overlap,
+chunk_overlap=chunk_overlap,
-# zh_title_enhance=zh_title_enhance):
+zh_title_enhance=zh_title_enhance):
-# if msg := check_error_msg(d):
+if msg := check_error_msg(d):
-# st.toast(msg)
+st.toast(msg)
-# else:
+else:
-# empty.progress(d["finished"] / d["total"], d["msg"])
+empty.progress(d["finished"] / d["total"], d["msg"])
-# st.rerun()
+st.rerun()

-# if cols[2].button(
+if cols[2].button(
-# "删除知识库",
+"删除知识库",
-# use_container_width=True,
+use_container_width=True,
-# ):
+):
-# ret = api.delete_knowledge_base(kb)
+ret = api.delete_knowledge_base(kb)
-# st.toast(ret.get("msg", " "))
+st.toast(ret.get("msg", " "))
-# time.sleep(1)
+time.sleep(1)
-# st.rerun()
+st.rerun()

-# with st.sidebar:
+with st.sidebar:
-# keyword = st.text_input("查询关键字")
+keyword = st.text_input("查询关键字")
-# top_k = st.slider("匹配条数", 1, 100, 3)
+top_k = st.slider("匹配条数", 1, 100, 3)

st.write("文件内文档列表。双击进行修改,在删除列填入 Y 可删除对应行。")
docs = []
@@ -325,11 +325,12 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
if selected_rows:
file_name = selected_rows[0]["file_name"]
docs = api.search_kb_docs(knowledge_base_name=selected_kb, file_name=file_name)
-data = [{"seq": i+1, "id": x["id"], "page_content": x["page_content"], "source": x["metadata"].get("source"),
-"type": x["type"],
-"metadata": json.dumps(x["metadata"], ensure_ascii=False),
-"to_del": "",
-} for i, x in enumerate(docs)]
+data = [
+{"seq": i + 1, "id": x["id"], "page_content": x["page_content"], "source": x["metadata"].get("source"),
+"type": x["type"],
+"metadata": json.dumps(x["metadata"], ensure_ascii=False),
+"to_del": "",
+} for i, x in enumerate(docs)]
df = pd.DataFrame(data)

gb = GridOptionsBuilder.from_dataframe(df)
@@ -343,22 +344,24 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
edit_docs = AgGrid(df, gb.build())

if st.button("保存更改"):
-# origin_docs = {x["id"]: {"page_content": x["page_content"], "type": x["type"], "metadata": x["metadata"]} for x in docs}
+origin_docs = {
+x["id"]: {"page_content": x["page_content"], "type": x["type"], "metadata": x["metadata"]} for x in
+docs}
changed_docs = []
for index, row in edit_docs.data.iterrows():
-# origin_doc = origin_docs[row["id"]]
+origin_doc = origin_docs[row["id"]]
-# if row["page_content"] != origin_doc["page_content"]:
+if row["page_content"] != origin_doc["page_content"]:
if row["to_del"] not in ["Y", "y", 1]:
changed_docs.append({
"page_content": row["page_content"],
"type": row["type"],
"metadata": json.loads(row["metadata"]),
})

if changed_docs:
if api.update_kb_docs(knowledge_base_name=selected_kb,
file_names=[file_name],
docs={file_name: changed_docs}):
st.toast("更新文档成功")
else:
st.toast("更新文档失败")
@@ -1,7 +1,6 @@
# This module wraps requests to api.py and can be used by different WebUIs
# ApiRequest and AsyncApiRequest provide synchronous / asynchronous calls

from typing import *
from pathlib import Path
# The configs imported here are those of the requesting machine (e.g. the WebUI); they mainly supply frontend defaults and may differ from the server's in a distributed deployment
@@ -27,7 +26,7 @@ from io import BytesIO
from server.utils import set_httpx_config, api_address, get_httpx_client

from pprint import pprint
+from langchain_core._api import deprecated

set_httpx_config()

@@ -36,10 +35,11 @@ class ApiRequest:
'''
Wrapper around api.py calls (synchronous mode) that simplifies calling the API
'''

def __init__(
self,
base_url: str = api_address(),
timeout: float = HTTPX_DEFAULT_TIMEOUT,
):
self.base_url = base_url
self.timeout = timeout
@@ -55,12 +55,12 @@ class ApiRequest:
return self._client

def get(
self,
url: str,
params: Union[Dict, List[Tuple], bytes] = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any,
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
while retry > 0:
try:
@@ -75,13 +75,13 @@ class ApiRequest:
retry -= 1

def post(
self,
url: str,
data: Dict = None,
json: Dict = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
while retry > 0:
try:
@@ -97,13 +97,13 @@ class ApiRequest:
retry -= 1

def delete(
self,
url: str,
data: Dict = None,
json: Dict = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
while retry > 0:
try:
@@ -118,24 +118,25 @@ class ApiRequest:
retry -= 1

def _httpx_stream2generator(
self,
response: contextlib._GeneratorContextManager,
as_json: bool = False,
):
'''
Convert the GeneratorContextManager returned by httpx.stream into a plain generator
'''

async def ret_async(response, as_json):
try:
async with response as r:
async for chunk in r.aiter_text(None):
if not chunk:  # fastchat api yield empty bytes on start and end
continue
if as_json:
try:
if chunk.startswith("data: "):
data = json.loads(chunk[6:-2])
elif chunk.startswith(":"):  # skip sse comment line
continue
else:
data = json.loads(chunk)
@@ -143,7 +144,7 @@ class ApiRequest:
except Exception as e:
msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
else:
# print(chunk, end="", flush=True)
yield chunk
@@ -158,20 +159,20 @@ class ApiRequest:
except Exception as e:
msg = f"API通信遇到错误:{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
yield {"code": 500, "msg": msg}

def ret_sync(response, as_json):
try:
with response as r:
for chunk in r.iter_text(None):
if not chunk:  # fastchat api yield empty bytes on start and end
continue
if as_json:
try:
if chunk.startswith("data: "):
data = json.loads(chunk[6:-2])
elif chunk.startswith(":"):  # skip sse comment line
continue
else:
data = json.loads(chunk)
@@ -179,7 +180,7 @@ class ApiRequest:
except Exception as e:
msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
else:
# print(chunk, end="", flush=True)
yield chunk
@@ -194,7 +195,7 @@ class ApiRequest:
except Exception as e:
msg = f"API通信遇到错误:{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
yield {"code": 500, "msg": msg}

if self._use_async:
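Both branches of `_httpx_stream2generator` parse FastChat's SSE-style stream by hand: non-empty chunks that start with `data: ` carry a JSON payload (the prefix and trailing `\n\n` are stripped via `chunk[6:-2]`), lines starting with `:` are SSE comments and are skipped, and anything else is treated as plain JSON. A minimal stand-alone sketch of that parsing rule, assuming the same chunk framing:

```python
import json
from typing import Optional

def parse_sse_chunk(chunk: str) -> Optional[dict]:
    """Parse one streamed chunk the way _httpx_stream2generator does (sketch only)."""
    if not chunk:              # the fastchat api yields empty chunks at start and end
        return None
    if chunk.startswith(":"):  # SSE comment line, e.g. keep-alive pings
        return None
    if chunk.startswith("data: "):
        return json.loads(chunk[6:-2])  # strip the "data: " prefix and trailing "\n\n"
    return json.loads(chunk)

print(parse_sse_chunk('data: {"text": "你好"}\n\n'))  # -> {'text': '你好'}
```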
@@ -203,16 +204,17 @@ class ApiRequest:
return ret_sync(response, as_json)

def _get_response_value(
self,
response: httpx.Response,
as_json: bool = False,
value_func: Callable = None,
):
'''
Convert the response returned by a synchronous or asynchronous request
`as_json`: return JSON
`value_func`: user-defined return value; the function receives the response or the parsed JSON
'''

def to_json(r):
try:
return r.json()
@@ -220,7 +222,7 @@ class ApiRequest:
msg = "API未能返回正确的JSON。" + str(e)
if log_verbose:
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
return {"code": 500, "msg": msg, "data": None}

if value_func is None:
@@ -250,10 +252,10 @@ class ApiRequest:
return self._get_response_value(response, as_json=True, value_func=lambda r: r["data"])

def get_prompt_template(
self,
type: str = "llm_chat",
name: str = "default",
**kwargs,
) -> str:
data = {
"type": type,
@@ -297,15 +299,19 @@ class ApiRequest:
response = self.post("/chat/chat", json=data, stream=True, **kwargs)
return self._httpx_stream2generator(response, as_json=True)

+@deprecated(
+since="0.3.0",
+message="自定义Agent问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
+removal="0.3.0")
def agent_chat(
self,
query: str,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
):
'''
Corresponds to the api.py /chat/agent_chat endpoint
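The hunk above marks `agent_chat` with `langchain_core`'s `deprecated` decorator, so callers see a deprecation warning until the 0.3.x rewrite lands. A hedged sketch of the same pattern on a stand-alone function (the function and message here are illustrative, not part of the project):

```python
from langchain_core._api import deprecated

@deprecated(
    since="0.3.0",
    message="This helper will be rewritten in 0.3.x; the 0.2.x version is deprecated.",
    removal="0.3.0",
)
def legacy_helper(query: str) -> str:
    """A hypothetical 0.2.x-era helper kept only for backward compatibility."""
    return query.strip()

legacy_helper("hello")  # emits a deprecation warning before running normally
```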
@@ -327,17 +333,17 @@ class ApiRequest:
return self._httpx_stream2generator(response, as_json=True)

def knowledge_base_chat(
self,
query: str,
knowledge_base_name: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
):
'''
Corresponds to the api.py /chat/knowledge_base_chat endpoint
@@ -366,28 +372,29 @@ class ApiRequest:
return self._httpx_stream2generator(response, as_json=True)

def upload_temp_docs(
self,
files: List[Union[str, Path, bytes]],
knowledge_id: str = None,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
):
'''
Corresponds to the api.py /knowledge_base/upload_temp_docs endpoint
'''

def convert_file(file, filename=None):
if isinstance(file, bytes):  # raw bytes
file = BytesIO(file)
elif hasattr(file, "read"):  # a file io like object
filename = filename or file.name
else:  # a local path
file = Path(file).absolute().open("rb")
filename = filename or os.path.split(file.name)[-1]
return filename, file

files = [convert_file(file) for file in files]
-data={
+data = {
"knowledge_id": knowledge_id,
"chunk_size": chunk_size,
"chunk_overlap": chunk_overlap,
@@ -402,17 +409,17 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)

def file_chat(
self,
query: str,
knowledge_id: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
):
'''
Corresponds to the api.py /chat/file_chat endpoint
@@ -430,9 +437,6 @@ class ApiRequest:
"prompt_name": prompt_name,
}

-# print(f"received input message:")
-# pprint(data)

response = self.post(
"/chat/file_chat",
json=data,
@@ -440,18 +444,23 @@ class ApiRequest:
)
return self._httpx_stream2generator(response, as_json=True)

+@deprecated(
+since="0.3.0",
+message="搜索引擎问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
+removal="0.3.0"
+)
def search_engine_chat(
self,
query: str,
search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
split_result: bool = False,
):
'''
Corresponds to the api.py /chat/search_engine_chat endpoint
@@ -482,7 +491,7 @@ class ApiRequest:
# knowledge-base operations

def list_knowledge_bases(
self,
):
'''
Corresponds to the api.py /knowledge_base/list_knowledge_bases endpoint
@@ -493,10 +502,10 @@ class ApiRequest:
value_func=lambda r: r.get("data", []))

def create_knowledge_base(
self,
knowledge_base_name: str,
vector_store_type: str = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
):
'''
Corresponds to the api.py /knowledge_base/create_knowledge_base endpoint
@@ -514,8 +523,8 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)

def delete_knowledge_base(
self,
knowledge_base_name: str,
):
'''
Corresponds to the api.py /knowledge_base/delete_knowledge_base endpoint
@@ -527,8 +536,8 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)

def list_kb_docs(
self,
knowledge_base_name: str,
):
'''
Corresponds to the api.py /knowledge_base/list_files endpoint
@@ -542,13 +551,13 @@ class ApiRequest:
value_func=lambda r: r.get("data", []))

def search_kb_docs(
self,
knowledge_base_name: str,
query: str = "",
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: int = SCORE_THRESHOLD,
file_name: str = "",
metadata: dict = {},
) -> List:
'''
Corresponds to the api.py /knowledge_base/search_docs endpoint
@@ -569,9 +578,9 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)

def update_docs_by_id(
self,
knowledge_base_name: str,
docs: Dict[str, Dict],
) -> bool:
'''
Corresponds to the api.py /knowledge_base/update_docs_by_id endpoint
@@ -587,32 +596,33 @@ class ApiRequest:
return self._get_response_value(response)

def upload_kb_docs(
self,
files: List[Union[str, Path, bytes]],
knowledge_base_name: str,
override: bool = False,
to_vector_store: bool = True,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
docs: Dict = {},
not_refresh_vs_cache: bool = False,
):
'''
Corresponds to the api.py /knowledge_base/upload_docs endpoint
'''

def convert_file(file, filename=None):
if isinstance(file, bytes):  # raw bytes
file = BytesIO(file)
elif hasattr(file, "read"):  # a file io like object
filename = filename or file.name
else:  # a local path
file = Path(file).absolute().open("rb")
filename = filename or os.path.split(file.name)[-1]
return filename, file

files = [convert_file(file) for file in files]
-data={
+data = {
"knowledge_base_name": knowledge_base_name,
"override": override,
"to_vector_store": to_vector_store,
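`upload_kb_docs` accepts plain paths, raw bytes, or file-like objects and normalizes them through the inner `convert_file` helper before posting a multipart request. A hedged usage sketch (the knowledge-base name, file paths, and chunking values are illustrative, not project defaults):

```python
# Sketch: push two documents into an existing knowledge base via the ApiRequest wrapper.
from webui_pages.utils import ApiRequest

api = ApiRequest(base_url="http://127.0.0.1:7861")  # assumed default api.py address

resp = api.upload_kb_docs(
    files=["./docs/intro.md", open("./docs/faq.pdf", "rb").read()],  # a path and raw bytes
    knowledge_base_name="samples",
    override=True,        # replace files with the same name
    chunk_size=250,
    chunk_overlap=50,
    zh_title_enhance=False,
)
print(resp)  # e.g. {"code": 200, "msg": "...", "data": {...}}
```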
@@ -633,11 +643,11 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)

def delete_kb_docs(
self,
knowledge_base_name: str,
file_names: List[str],
delete_content: bool = False,
not_refresh_vs_cache: bool = False,
):
'''
Corresponds to the api.py /knowledge_base/delete_docs endpoint
@@ -655,8 +665,7 @@ class ApiRequest:
)
return self._get_response_value(response, as_json=True)

-def update_kb_info(self,knowledge_base_name,kb_info):
+def update_kb_info(self, knowledge_base_name, kb_info):
'''
Corresponds to the api.py /knowledge_base/update_info endpoint
'''
@@ -672,15 +681,15 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)

def update_kb_docs(
self,
knowledge_base_name: str,
file_names: List[str],
override_custom_docs: bool = False,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
docs: Dict = {},
not_refresh_vs_cache: bool = False,
):
'''
Corresponds to the api.py /knowledge_base/update_docs endpoint
@@ -706,14 +715,14 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)

def recreate_vector_store(
self,
knowledge_base_name: str,
allow_empty_kb: bool = True,
vs_type: str = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
):
'''
Corresponds to the api.py /knowledge_base/recreate_vector_store endpoint
@@ -738,8 +747,8 @@ class ApiRequest:

# LLM model operations
def list_running_models(
self,
controller_address: str = None,
):
'''
Get the list of models currently running in FastChat
@@ -755,8 +764,7 @@ class ApiRequest:
"/llm_model/list_running_models",
json=data,
)
-return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", []))
+return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", []))

def get_default_llm_model(self, local_first: bool = True) -> Tuple[str, bool]:
'''
@@ -764,6 +772,7 @@ class ApiRequest:
When local_first=True, prefer a running local model; otherwise prefer the order configured in LLM_MODELS.
Returns (model_name, is_local_model).
'''

def ret_sync():
running_models = self.list_running_models()
if not running_models:
@@ -780,7 +789,7 @@ class ApiRequest:
model = m
break

if not model:  # none of the models configured in LLM_MODELS is running
model = list(running_models)[0]
is_local = not running_models[model].get("online_api")
return model, is_local
@@ -801,7 +810,7 @@ class ApiRequest:
model = m
break

if not model:  # none of the models configured in LLM_MODELS is running
model = list(running_models)[0]
is_local = not running_models[model].get("online_api")
return model, is_local
@@ -812,8 +821,8 @@ class ApiRequest:
return ret_sync()

def list_config_models(
self,
types: List[str] = ["local", "online"],
) -> Dict[str, Dict]:
'''
Get the model list configured in the server's configs, returned as {"type": {model_name: config}, ...}.
@@ -825,23 +834,23 @@ class ApiRequest:
"/llm_model/list_config_models",
json=data,
)
-return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))
+return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))

def get_model_config(
self,
model_name: str = None,
) -> Dict:
'''
Get a model's configuration from the server
'''
-data={
+data = {
"model_name": model_name,
}
response = self.post(
"/llm_model/get_model_config",
json=data,
)
-return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))
+return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))

def list_search_engines(self) -> List[str]:
'''
@@ -850,12 +859,12 @@ class ApiRequest:
response = self.post(
"/server/list_search_engines",
)
-return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))
+return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))

def stop_llm_model(
self,
model_name: str,
controller_address: str = None,
):
'''
Stop an LLM model.
@@ -873,10 +882,10 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)

def change_llm_model(
self,
model_name: str,
new_model_name: str,
controller_address: str = None,
):
'''
Ask the fastchat controller to switch the LLM model.
@@ -959,10 +968,10 @@ class ApiRequest:
return ret_sync()

def embed_texts(
self,
texts: List[str],
embed_model: str = EMBEDDING_MODEL,
to_query: bool = False,
) -> List[List[float]]:
'''
Embed texts into vectors; available models include local embed_models and online models that support embeddings
@@ -979,10 +988,10 @@ class ApiRequest:
return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data"))

def chat_feedback(
self,
message_id: str,
score: int,
reason: str = "",
) -> int:
'''
Submit feedback for a chat message
@@ -1019,9 +1028,9 @@ def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
return error message if error occured when requests API
'''
if (isinstance(data, dict)
and key in data
and "code" in data
and data["code"] == 200):
return data[key]
return ""
|||||||
Loading…
x
Reference in New Issue
Block a user