Merge pull request #2748 from chatchat-space/dev

0.2.10 (0.2.x final dev release)
commit 09bcec542a by zR, 2024-01-22 11:57:20 +08:00, committed by GitHub
58 changed files with 1106 additions and 1291 deletions

View File

@ -42,7 +42,7 @@
🚩 本项目未涉及微调、训练过程,但可利用微调或训练对本项目效果进行优化。 🚩 本项目未涉及微调、训练过程,但可利用微调或训练对本项目效果进行优化。
🌐 [AutoDL 镜像](https://www.codewithgpu.com/i/chatchat-space/Langchain-Chatchat/Langchain-Chatchat) 中 `v11` 版本所使用代码已更新至本项目 `v0.2.7` 版本。 🌐 [AutoDL 镜像](https://www.codewithgpu.com/i/chatchat-space/Langchain-Chatchat/Langchain-Chatchat) 中 `v13` 版本所使用代码已更新至本项目 `v0.2.9` 版本。
🐳 [Docker 镜像](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.6) 已经更新到 ```0.2.7``` 版本。 🐳 [Docker 镜像](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.6) 已经更新到 ```0.2.7``` 版本。
@ -67,10 +67,10 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
### 1. 环境配置 ### 1. 环境配置
+ 首先,确保你的机器安装了 Python 3.8 - 3.10 + 首先,确保你的机器安装了 Python 3.8 - 3.11
``` ```
$ python --version $ python --version
Python 3.10.12 Python 3.11.7
``` ```
接着,创建一个虚拟环境,并在虚拟环境内安装项目的依赖 接着,创建一个虚拟环境,并在虚拟环境内安装项目的依赖
```shell ```shell
@ -88,6 +88,7 @@ $ pip install -r requirements_webui.txt
# 默认依赖包括基本运行环境FAISS向量库。如果要使用 milvus/pg_vector 等向量库,请将 requirements.txt 中相应依赖取消注释再安装。 # 默认依赖包括基本运行环境FAISS向量库。如果要使用 milvus/pg_vector 等向量库,请将 requirements.txt 中相应依赖取消注释再安装。
``` ```
请注意LangChain-Chatchat `0.2.x` 系列是针对 Langchain `0.0.x` 系列版本的,如果你使用的是 Langchain `0.1.x` 系列版本,需要降级。
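For readers currently on Langchain `0.1.x`, a minimal way to pin back to the series this release targets (version numbers taken from the requirements files changed in this commit; adjust to your own environment):
```shell
$ pip install "langchain==0.0.354" "langchain-experimental==0.0.47"
```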
### 2 模型下载 ### 2 模型下载
如需在本地或离线环境下运行本项目,需要首先将项目所需的模型下载至本地,通常开源 LLM 与 Embedding 模型可以从 [HuggingFace](https://huggingface.co/models) 下载。 如需在本地或离线环境下运行本项目,需要首先将项目所需的模型下载至本地,通常开源 LLM 与 Embedding 模型可以从 [HuggingFace](https://huggingface.co/models) 下载。
@ -141,14 +142,20 @@ $ python startup.py -a
--- ---
## 项目里程碑 ## 项目里程碑
+ `2023年4月`: `Langchain-ChatGLM 0.1.0` 发布,支持基于 ChatGLM-6B 模型的本地知识库问答。
+ `2023年8月`: `Langchain-ChatGLM` 改名为 `Langchain-Chatchat``0.2.0` 发布,使用 `fastchat` 作为模型加载方案,支持更多的模型和数据库。
+ `2023年10月`: `Langchain-Chatchat 0.2.5` 发布,推出 Agent 内容,开源项目在`Founder Park & Zhipu AI & Zilliz` 举办的黑客马拉松获得三等奖。
+ `2023年12月`: `Langchain-Chatchat` 开源项目获得超过 **20K** stars.
+ `2024年1月`: `LangChain 0.1.x` 推出,`Langchain-Chatchat 0.2.x` 停止更新和技术支持,全力研发具有更强应用性的 `Langchain-Chatchat 0.3.x`
+ 🔥 让我们一起期待未来 Chatchat 的故事 ···
--- ---
## 联系我们 ## 联系我们
### Telegram ### Telegram
[![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white "langchain-chatglm")](https://t.me/+RjliQ3jnJ1YyN2E9) [![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white "langchain-chatglm")](https://t.me/+RjliQ3jnJ1YyN2E9)
### 项目交流群 ### 项目交流群
<img src="img/qr_code_82.jpg" alt="二维码" width="300" /> <img src="img/qr_code_85.jpg" alt="二维码" width="300" />
🎉 Langchain-Chatchat 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。 🎉 Langchain-Chatchat 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
@ -156,4 +163,4 @@ $ python startup.py -a
<img src="img/official_wechat_mp_account.png" alt="二维码" width="300" /> <img src="img/official_wechat_mp_account.png" alt="二维码" width="300" />
🎉 Langchain-Chatchat 项目官方公众号,欢迎扫码关注。 🎉 Langchain-Chatchat 项目官方公众号,欢迎扫码关注。

View File

@ -55,10 +55,10 @@ The main process analysis from the aspect of document process:
🚩 Training and fine-tuning are not part of this project, but performance can always be improved by doing them.
🌐 [AutoDL image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5) is supported, and in v9 the codes are update 🌐 [AutoDL image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5) is supported, and in v13 the codes are update
to v0.2.5. to v0.2.9.
🐳 [Docker image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5) 🐳 [Docker image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.7)
## Pain Points Addressed ## Pain Points Addressed
@ -98,6 +98,7 @@ $ pip install -r requirements_webui.txt
# 默认依赖包括基本运行环境FAISS向量库。如果要使用 milvus/pg_vector 等向量库,请将 requirements.txt 中相应依赖取消注释再安装。 # 默认依赖包括基本运行环境FAISS向量库。如果要使用 milvus/pg_vector 等向量库,请将 requirements.txt 中相应依赖取消注释再安装。
``` ```
Please note that the Langchain-Chatchat `0.2.x` series targets Langchain `0.0.x`. If you are using Langchain `0.1.x`, you will need to downgrade.
### Model Download ### Model Download
@ -155,6 +156,16 @@ $ python startup.py -a
The above instructions are provided for a quick start. If you need more features or want to customize the launch method, The above instructions are provided for a quick start. If you need more features or want to customize the launch method,
please refer to the [Wiki](https://github.com/chatchat-space/Langchain-Chatchat/wiki/). please refer to the [Wiki](https://github.com/chatchat-space/Langchain-Chatchat/wiki/).
---
## Project Milestones
+ `April 2023`: `Langchain-ChatGLM 0.1.0` released, supporting local knowledge-base Q&A based on the ChatGLM-6B model.
+ `August 2023`: `Langchain-ChatGLM` was renamed `Langchain-Chatchat`; `0.2.0` was released, using `fastchat` as the model-loading solution and supporting more models and databases.
+ `October 2023`: `Langchain-Chatchat 0.2.5` was released, introducing Agent support; the open-source project won third prize in the hackathon held by `Founder Park & Zhipu AI & Zilliz`.
+ `December 2023`: the `Langchain-Chatchat` open-source project passed **20K** stars.
+ `January 2024`: with `LangChain 0.1.x` released, `Langchain-Chatchat 0.2.x` stops receiving updates and technical support, and all efforts move to developing the more applicable `Langchain-Chatchat 0.3.x`.
+ 🔥 Let's look forward to the future of Chatchat together...
--- ---
## Contact Us ## Contact Us
@ -163,9 +174,9 @@ please refer to the [Wiki](https://github.com/chatchat-space/Langchain-Chatchat/
[![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white "langchain-chatglm")](https://t.me/+RjliQ3jnJ1YyN2E9) [![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white "langchain-chatglm")](https://t.me/+RjliQ3jnJ1YyN2E9)
### WeChat Group ### WeChat Group
<img src="img/qr_code_67.jpg" alt="二维码" width="300" height="300" /> <img src="img/qr_code_85.jpg" alt="二维码" width="300" height="300" />
### WeChat Official Account ### WeChat Official Account

View File

@ -5,4 +5,4 @@ from .server_config import *
from .prompt_config import * from .prompt_config import *
VERSION = "v0.2.9-preview" VERSION = "v0.2.10"

View File

@ -55,6 +55,9 @@ METAPHOR_API_KEY = ""
# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。 # 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
ZH_TITLE_ENHANCE = False ZH_TITLE_ENHANCE = False
# PDF OCR 控制:只对宽高超过页面一定比例(图片宽/页面宽,图片高/页面高)的图片进行 OCR。
# 这样可以避免 PDF 中一些小图片的干扰,提高非扫描版 PDF 处理速度
PDF_OCR_THRESHOLD = (0.6, 0.6)
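As a rough sketch of how this threshold is applied (the real check is the RapidOCRPDFLoader change later in this commit; `bbox`, `page_w` and `page_h` are illustrative names):
```python
PDF_OCR_THRESHOLD = (0.6, 0.6)

def should_ocr_image(bbox, page_w, page_h, threshold=PDF_OCR_THRESHOLD):
    # OCR only images that cover a large enough fraction of the page,
    # which skips small decorative pictures in non-scanned PDFs
    width_ratio = (bbox[2] - bbox[0]) / page_w
    height_ratio = (bbox[3] - bbox[1]) / page_h
    return width_ratio >= threshold[0] and height_ratio >= threshold[1]
```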
# 每个知识库的初始化介绍用于在初始化知识库时显示和Agent调用没写则没有介绍不会被Agent调用。 # 每个知识库的初始化介绍用于在初始化知识库时显示和Agent调用没写则没有介绍不会被Agent调用。
KB_INFO = { KB_INFO = {
@ -102,6 +105,10 @@ kbs_config = {
"index_name": "test_index", "index_name": "test_index",
"user": "", "user": "",
"password": "" "password": ""
},
"milvus_kwargs":{
"search_params":{"metric_type": "L2"}, #在此处增加search_params
"index_params":{"metric_type": "L2","index_type": "HNSW"} # 在此处增加index_params
} }
} }
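Illustration only: with the layout added above, the new Milvus tuning parameters can be read straight from `kbs_config` (the import path is an assumption based on this repo's configs package):
```python
from configs import kbs_config  # assumed re-export from configs/__init__.py

search_params = kbs_config["milvus_kwargs"]["search_params"]  # {"metric_type": "L2"}
index_params = kbs_config["milvus_kwargs"]["index_params"]    # {"metric_type": "L2", "index_type": "HNSW"}
```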

View File

@ -6,9 +6,9 @@ import os
MODEL_ROOT_PATH = "" MODEL_ROOT_PATH = ""
# 选用的 Embedding 名称 # 选用的 Embedding 名称
EMBEDDING_MODEL = "bge-large-zh" EMBEDDING_MODEL = "bge-large-zh-v1.5"
# Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。 # Embedding 模型运行设备。设为"auto"会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。
EMBEDDING_DEVICE = "auto" EMBEDDING_DEVICE = "auto"
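For reference, `"auto"` resolves to a device roughly as sketched below; this is not the project's exact detection code, and `"xpu"` support depends on the installed torch build:
```python
import torch

def detect_device() -> str:
    # Rough equivalent of the "auto" setting above
    if torch.cuda.is_available():
        return "cuda"
    if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```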
# 选用的reranker模型 # 选用的reranker模型
@ -26,44 +26,33 @@ EMBEDDING_MODEL_OUTPUT_PATH = "output"
# 在这里我们使用目前主流的两个离线模型其中chatglm3-6b 为默认加载模型。 # 在这里我们使用目前主流的两个离线模型其中chatglm3-6b 为默认加载模型。
# 如果你的显存不足,可使用 Qwen-1_8B-Chat, 该模型 FP16 仅需 3.8G显存。 # 如果你的显存不足,可使用 Qwen-1_8B-Chat, 该模型 FP16 仅需 3.8G显存。
# chatglm3-6b输出角色标签<|user|>及自问自答的问题详见项目wiki->常见问题->Q20. LLM_MODELS = ["zhipu-api"]
LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"] # "Qwen-1_8B-Chat",
# AgentLM模型的名称 (可以不指定指定之后就锁定进入Agent之后的Chain的模型不指定就是LLM_MODELS[0])
Agent_MODEL = None Agent_MODEL = None
# LLM 运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。 # LLM 模型运行设备。设为"auto"会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。
LLM_DEVICE = "auto" LLM_DEVICE = "cuda"
# 历史对话轮数
HISTORY_LEN = 3 HISTORY_LEN = 3
# 大模型最长支持的长度,如果不填写,则使用模型默认的最大长度,如果填写,则为用户设定的最大长度 MAX_TOKENS = 2048
MAX_TOKENS = None
# LLM通用对话参数
TEMPERATURE = 0.7 TEMPERATURE = 0.7
# TOP_P = 0.95 # ChatOpenAI暂不支持该参数
ONLINE_LLM_MODEL = { ONLINE_LLM_MODEL = {
# 线上模型。请在server_config中为每个在线API设置不同的端口
"openai-api": { "openai-api": {
"model_name": "gpt-3.5-turbo", "model_name": "gpt-4",
"api_base_url": "https://api.openai.com/v1", "api_base_url": "https://api.openai.com/v1",
"api_key": "", "api_key": "",
"openai_proxy": "", "openai_proxy": "",
}, },
# 具体注册及api key获取请前往 http://open.bigmodel.cn # 智谱AI API,具体注册及api key获取请前往 http://open.bigmodel.cn
"zhipu-api": { "zhipu-api": {
"api_key": "", "api_key": "",
"version": "chatglm_turbo", # 可选包括 "chatglm_turbo" "version": "glm-4",
"provider": "ChatGLMWorker", "provider": "ChatGLMWorker",
}, },
# 具体注册及api key获取请前往 https://api.minimax.chat/ # 具体注册及api key获取请前往 https://api.minimax.chat/
"minimax-api": { "minimax-api": {
"group_id": "", "group_id": "",
@ -72,7 +61,6 @@ ONLINE_LLM_MODEL = {
"provider": "MiniMaxWorker", "provider": "MiniMaxWorker",
}, },
# 具体注册及api key获取请前往 https://xinghuo.xfyun.cn/ # 具体注册及api key获取请前往 https://xinghuo.xfyun.cn/
"xinghuo-api": { "xinghuo-api": {
"APPID": "", "APPID": "",
@ -93,8 +81,8 @@ ONLINE_LLM_MODEL = {
# 火山方舟 API文档参考 https://www.volcengine.com/docs/82379 # 火山方舟 API文档参考 https://www.volcengine.com/docs/82379
"fangzhou-api": { "fangzhou-api": {
"version": "chatglm-6b-model", # 当前支持 "chatglm-6b-model" 更多的见文档模型支持列表中方舟部分。 "version": "chatglm-6b-model",
"version_url": "", # 可以不填写version直接填写在方舟申请模型发布的API地址 "version_url": "",
"api_key": "", "api_key": "",
"secret_key": "", "secret_key": "",
"provider": "FangZhouWorker", "provider": "FangZhouWorker",
@ -102,15 +90,15 @@ ONLINE_LLM_MODEL = {
# 阿里云通义千问 API文档参考 https://help.aliyun.com/zh/dashscope/developer-reference/api-details # 阿里云通义千问 API文档参考 https://help.aliyun.com/zh/dashscope/developer-reference/api-details
"qwen-api": { "qwen-api": {
"version": "qwen-turbo", # 可选包括 "qwen-turbo", "qwen-plus" "version": "qwen-max",
"api_key": "", # 请在阿里云控制台模型服务灵积API-KEY管理页面创建 "api_key": "",
"provider": "QwenWorker", "provider": "QwenWorker",
"embed_model": "text-embedding-v1" # embedding 模型名称 "embed_model": "text-embedding-v1" # embedding 模型名称
}, },
# 百川 API申请方式请参考 https://www.baichuan-ai.com/home#api-enter # 百川 API申请方式请参考 https://www.baichuan-ai.com/home#api-enter
"baichuan-api": { "baichuan-api": {
"version": "Baichuan2-53B", # 当前支持 "Baichuan2-53B" 见官方文档。 "version": "Baichuan2-53B",
"api_key": "", "api_key": "",
"secret_key": "", "secret_key": "",
"provider": "BaiChuanWorker", "provider": "BaiChuanWorker",
@ -132,6 +120,11 @@ ONLINE_LLM_MODEL = {
"secret_key": "", "secret_key": "",
"provider": "TianGongWorker", "provider": "TianGongWorker",
}, },
# Gemini API（开发组未测试，由社群提供，只支持 Pro）：https://makersuite.google.com/ 或 Google Cloud。使用前先确认网络正常；如需代理，请在项目启动（python startup.py -a）的环境内设置 https_proxy 环境变量。
"gemini-api": {
"api_key": "",
"provider": "GeminiWorker",
}
} }
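As the Gemini entry above notes, the proxy must be set in the same environment that launches the services. A minimal sketch (the proxy address is a placeholder):
```shell
$ export https_proxy=http://127.0.0.1:7890  # placeholder address, use your own proxy
$ python startup.py -a
```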
@ -143,6 +136,7 @@ ONLINE_LLM_MODEL = {
# - GanymedeNil/text2vec-large-chinese # - GanymedeNil/text2vec-large-chinese
# - text2vec-large-chinese # - text2vec-large-chinese
# 2.2 如果以上本地路径不存在则使用huggingface模型 # 2.2 如果以上本地路径不存在则使用huggingface模型
MODEL_PATH = { MODEL_PATH = {
"embed_model": { "embed_model": {
"ernie-tiny": "nghuyong/ernie-3.0-nano-zh", "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
@ -161,7 +155,7 @@ MODEL_PATH = {
"bge-large-zh": "BAAI/bge-large-zh", "bge-large-zh": "BAAI/bge-large-zh",
"bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct", "bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
"bge-base-zh-v1.5": "BAAI/bge-base-zh-v1.5", "bge-base-zh-v1.5": "BAAI/bge-base-zh-v1.5",
"bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5", "bge-large-zh-v1.5": "/share/home/zyx/Models/bge-large-zh-v1.5",
"piccolo-base-zh": "sensenova/piccolo-base-zh", "piccolo-base-zh": "sensenova/piccolo-base-zh",
"piccolo-large-zh": "sensenova/piccolo-large-zh", "piccolo-large-zh": "sensenova/piccolo-large-zh",
"nlp_gte_sentence-embedding_chinese-large": "damo/nlp_gte_sentence-embedding_chinese-large", "nlp_gte_sentence-embedding_chinese-large": "damo/nlp_gte_sentence-embedding_chinese-large",
@ -169,55 +163,55 @@ MODEL_PATH = {
}, },
"llm_model": { "llm_model": {
# 以下部分模型并未完全测试仅根据fastchat和vllm模型的模型列表推定支持
"chatglm2-6b": "THUDM/chatglm2-6b", "chatglm2-6b": "THUDM/chatglm2-6b",
"chatglm2-6b-32k": "THUDM/chatglm2-6b-32k", "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
"chatglm3-6b": "THUDM/chatglm3-6b", "chatglm3-6b": "THUDM/chatglm3-6b",
"chatglm3-6b-32k": "THUDM/chatglm3-6b-32k", "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
"chatglm3-6b-base": "THUDM/chatglm3-6b-base",
"Qwen-1_8B": "Qwen/Qwen-1_8B", "Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
"Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat", "Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
"Qwen-1_8B-Chat-Int8": "Qwen/Qwen-1_8B-Chat-Int8", "Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",
"Qwen-1_8B-Chat-Int4": "Qwen/Qwen-1_8B-Chat-Int4",
"Qwen-7B": "Qwen/Qwen-7B", "Qwen-1_8B-Chat": "/media/checkpoint/Qwen-1_8B-Chat",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat", "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B": "Qwen/Qwen-14B",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat", "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
"Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8",
# 在新版的transformers下需要手动修改模型的config.json文件在quantization_config字典中
# 增加`disable_exllama:true` 字段才能启动qwen的量化模型
"Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4",
"Qwen-72B": "Qwen/Qwen-72B",
"Qwen-72B-Chat": "Qwen/Qwen-72B-Chat", "Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
"Qwen-72B-Chat-Int8": "Qwen/Qwen-72B-Chat-Int8",
"Qwen-72B-Chat-Int4": "Qwen/Qwen-72B-Chat-Int4",
"baichuan2-13b": "baichuan-inc/Baichuan2-13B-Chat", "baichuan-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
"baichuan2-7b": "baichuan-inc/Baichuan2-7B-Chat",
"baichuan-7b": "baichuan-inc/Baichuan-7B",
"baichuan-13b": "baichuan-inc/Baichuan-13B",
"baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat", "baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
"baichuan2-7b-chat": "baichuan-inc/Baichuan2-7B-Chat",
"aquila-7b": "BAAI/Aquila-7B", "baichuan2-13b-chat": "baichuan-inc/Baichuan2-13B-Chat",
"aquilachat-7b": "BAAI/AquilaChat-7B",
"internlm-7b": "internlm/internlm-7b", "internlm-7b": "internlm/internlm-7b",
"internlm-chat-7b": "internlm/internlm-chat-7b", "internlm-chat-7b": "internlm/internlm-chat-7b",
"internlm2-chat-7b": "internlm/internlm2-chat-7b",
"internlm2-chat-20b": "internlm/internlm2-chat-20b",
"BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat",
"BlueLM-7B-Chat-32k": "vivo-ai/BlueLM-7B-Chat-32k",
"Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat",
"agentlm-7b": "THUDM/agentlm-7b",
"agentlm-13b": "THUDM/agentlm-13b",
"agentlm-70b": "THUDM/agentlm-70b",
"falcon-7b": "tiiuae/falcon-7b", "falcon-7b": "tiiuae/falcon-7b",
"falcon-40b": "tiiuae/falcon-40b", "falcon-40b": "tiiuae/falcon-40b",
"falcon-rw-7b": "tiiuae/falcon-rw-7b", "falcon-rw-7b": "tiiuae/falcon-rw-7b",
"aquila-7b": "BAAI/Aquila-7B",
"aquilachat-7b": "BAAI/AquilaChat-7B",
"open_llama_13b": "openlm-research/open_llama_13b",
"vicuna-13b-v1.5": "lmsys/vicuna-13b-v1.5",
"koala": "young-geng/koala",
"mpt-7b": "mosaicml/mpt-7b",
"mpt-7b-storywriter": "mosaicml/mpt-7b-storywriter",
"mpt-30b": "mosaicml/mpt-30b",
"opt-66b": "facebook/opt-66b",
"opt-iml-max-30b": "facebook/opt-iml-max-30b",
"gpt2": "gpt2", "gpt2": "gpt2",
"gpt2-xl": "gpt2-xl", "gpt2-xl": "gpt2-xl",
"gpt-j-6b": "EleutherAI/gpt-j-6b", "gpt-j-6b": "EleutherAI/gpt-j-6b",
"gpt4all-j": "nomic-ai/gpt4all-j", "gpt4all-j": "nomic-ai/gpt4all-j",
"gpt-neox-20b": "EleutherAI/gpt-neox-20b", "gpt-neox-20b": "EleutherAI/gpt-neox-20b",
@ -225,63 +219,50 @@ MODEL_PATH = {
"oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5", "oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
"dolly-v2-12b": "databricks/dolly-v2-12b", "dolly-v2-12b": "databricks/dolly-v2-12b",
"stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b", "stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b",
"Llama-2-13b-hf": "meta-llama/Llama-2-13b-hf",
"Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
"open_llama_13b": "openlm-research/open_llama_13b",
"vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3",
"koala": "young-geng/koala",
"mpt-7b": "mosaicml/mpt-7b",
"mpt-7b-storywriter": "mosaicml/mpt-7b-storywriter",
"mpt-30b": "mosaicml/mpt-30b",
"opt-66b": "facebook/opt-66b",
"opt-iml-max-30b": "facebook/opt-iml-max-30b",
"agentlm-7b": "THUDM/agentlm-7b",
"agentlm-13b": "THUDM/agentlm-13b",
"agentlm-70b": "THUDM/agentlm-70b",
"Yi-34B-Chat": "01-ai/Yi-34B-Chat",
}, },
"reranker":{ "reranker": {
"bge-reranker-large":"BAAI/bge-reranker-large", "bge-reranker-large": "BAAI/bge-reranker-large",
"bge-reranker-base":"BAAI/bge-reranker-base", "bge-reranker-base": "BAAI/bge-reranker-base",
#TODO 增加在线reranker如cohere
} }
} }
# 通常情况下不需要更改以下内容 # 通常情况下不需要更改以下内容
# nltk 模型存储路径 # nltk 模型存储路径
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data") NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
# 使用VLLM可能导致模型推理能力下降无法完成Agent任务
VLLM_MODEL_DICT = { VLLM_MODEL_DICT = {
"aquila-7b": "BAAI/Aquila-7B",
"aquilachat-7b": "BAAI/AquilaChat-7B",
"baichuan-7b": "baichuan-inc/Baichuan-7B",
"baichuan-13b": "baichuan-inc/Baichuan-13B",
"baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
"chatglm2-6b": "THUDM/chatglm2-6b", "chatglm2-6b": "THUDM/chatglm2-6b",
"chatglm2-6b-32k": "THUDM/chatglm2-6b-32k", "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
"chatglm3-6b": "THUDM/chatglm3-6b", "chatglm3-6b": "THUDM/chatglm3-6b",
"chatglm3-6b-32k": "THUDM/chatglm3-6b-32k", "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
"Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
"Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
"Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",
"Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
"Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
"baichuan-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
"baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
"baichuan2-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
"baichuan2-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
"BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat", "BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat",
"BlueLM-7B-Chat-32k": "vivo-ai/BlueLM-7B-Chat-32k", "BlueLM-7B-Chat-32k": "vivo-ai/BlueLM-7B-Chat-32k",
# 注意bloom系列的tokenizer与model是分离的因此虽然vllm支持但与fschat框架不兼容
# "bloom": "bigscience/bloom",
# "bloomz": "bigscience/bloomz",
# "bloomz-560m": "bigscience/bloomz-560m",
# "bloomz-7b1": "bigscience/bloomz-7b1",
# "bloomz-1b7": "bigscience/bloomz-1b7",
"internlm-7b": "internlm/internlm-7b", "internlm-7b": "internlm/internlm-7b",
"internlm-chat-7b": "internlm/internlm-chat-7b", "internlm-chat-7b": "internlm/internlm-chat-7b",
"internlm2-chat-7b": "internlm/Models/internlm2-chat-7b",
"internlm2-chat-20b": "internlm/Models/internlm2-chat-20b",
"aquila-7b": "BAAI/Aquila-7B",
"aquilachat-7b": "BAAI/AquilaChat-7B",
"falcon-7b": "tiiuae/falcon-7b", "falcon-7b": "tiiuae/falcon-7b",
"falcon-40b": "tiiuae/falcon-40b", "falcon-40b": "tiiuae/falcon-40b",
"falcon-rw-7b": "tiiuae/falcon-rw-7b", "falcon-rw-7b": "tiiuae/falcon-rw-7b",
@ -294,8 +275,6 @@ VLLM_MODEL_DICT = {
"oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5", "oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
"dolly-v2-12b": "databricks/dolly-v2-12b", "dolly-v2-12b": "databricks/dolly-v2-12b",
"stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b", "stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b",
"Llama-2-13b-hf": "meta-llama/Llama-2-13b-hf",
"Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
"open_llama_13b": "openlm-research/open_llama_13b", "open_llama_13b": "openlm-research/open_llama_13b",
"vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3", "vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3",
"koala": "young-geng/koala", "koala": "young-geng/koala",
@ -305,37 +284,12 @@ VLLM_MODEL_DICT = {
"opt-66b": "facebook/opt-66b", "opt-66b": "facebook/opt-66b",
"opt-iml-max-30b": "facebook/opt-iml-max-30b", "opt-iml-max-30b": "facebook/opt-iml-max-30b",
"Qwen-1_8B": "Qwen/Qwen-1_8B",
"Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
"Qwen-1_8B-Chat-Int8": "Qwen/Qwen-1_8B-Chat-Int8",
"Qwen-1_8B-Chat-Int4": "Qwen/Qwen-1_8B-Chat-Int4",
"Qwen-7B": "Qwen/Qwen-7B",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B": "Qwen/Qwen-14B",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
"Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8",
"Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4",
"Qwen-72B": "Qwen/Qwen-72B",
"Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
"Qwen-72B-Chat-Int8": "Qwen/Qwen-72B-Chat-Int8",
"Qwen-72B-Chat-Int4": "Qwen/Qwen-72B-Chat-Int4",
"agentlm-7b": "THUDM/agentlm-7b",
"agentlm-13b": "THUDM/agentlm-13b",
"agentlm-70b": "THUDM/agentlm-70b",
} }
# 你认为支持Agent能力的模型可以在这里添加添加后不会出现可视化界面的警告
# 经过我们测试原生支持Agent的模型仅有以下几个
SUPPORT_AGENT_MODEL = [ SUPPORT_AGENT_MODEL = [
"azure-api", "azure-api",
"openai-api", "openai-api",
"qwen-api", "qwen-api",
"Qwen", "Qwen",
"chatglm3", "chatglm3",
"xinghuo-api",
] ]

View File

@ -128,6 +128,9 @@ FSCHAT_MODEL_WORKERS = {
"tiangong-api": { "tiangong-api": {
"port": 21009, "port": 21009,
}, },
"gemini-api": {
"port": 21012,
},
} }
# fastchat multi model worker server # fastchat multi model worker server

View File

@ -1,2 +1,4 @@
from .mypdfloader import RapidOCRPDFLoader from .mypdfloader import RapidOCRPDFLoader
from .myimgloader import RapidOCRLoader from .myimgloader import RapidOCRLoader
from .mydocloader import RapidOCRDocLoader
from .mypptloader import RapidOCRPPTLoader

View File

@ -0,0 +1,71 @@
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from typing import List
import tqdm
class RapidOCRDocLoader(UnstructuredFileLoader):
def _get_elements(self) -> List:
def doc2text(filepath):
from docx.table import _Cell, Table
from docx.oxml.table import CT_Tbl
from docx.oxml.text.paragraph import CT_P
from docx.text.paragraph import Paragraph
from docx import Document, ImagePart
from PIL import Image
from io import BytesIO
import numpy as np
from rapidocr_onnxruntime import RapidOCR
ocr = RapidOCR()
doc = Document(filepath)
resp = ""
def iter_block_items(parent):
from docx.document import Document
if isinstance(parent, Document):
parent_elm = parent.element.body
elif isinstance(parent, _Cell):
parent_elm = parent._tc
else:
raise ValueError("RapidOCRDocLoader parse fail")
for child in parent_elm.iterchildren():
if isinstance(child, CT_P):
yield Paragraph(child, parent)
elif isinstance(child, CT_Tbl):
yield Table(child, parent)
b_unit = tqdm.tqdm(total=len(doc.paragraphs)+len(doc.tables),
desc="RapidOCRDocLoader block index: 0")
for i, block in enumerate(iter_block_items(doc)):
b_unit.set_description(
"RapidOCRDocLoader block index: {}".format(i))
b_unit.refresh()
if isinstance(block, Paragraph):
resp += block.text.strip() + "\n"
images = block._element.xpath('.//pic:pic') # 获取所有图片
for image in images:
for img_id in image.xpath('.//a:blip/@r:embed'): # 获取图片id
part = doc.part.related_parts[img_id] # 根据图片id获取对应的图片
if isinstance(part, ImagePart):
image = Image.open(BytesIO(part._blob))
result, _ = ocr(np.array(image))
if result:
ocr_result = [line[1] for line in result]
resp += "\n".join(ocr_result)
elif isinstance(block, Table):
for row in block.rows:
for cell in row.cells:
for paragraph in cell.paragraphs:
resp += paragraph.text.strip() + "\n"
b_unit.update(1)
return resp
text = doc2text(self.file_path)
from unstructured.partition.text import partition_text
return partition_text(text=text, **self.unstructured_kwargs)
if __name__ == '__main__':
loader = RapidOCRDocLoader(file_path="../tests/samples/ocr_test.docx")
docs = loader.load()
print(docs)

View File

@ -1,5 +1,6 @@
from typing import List from typing import List
from langchain.document_loaders.unstructured import UnstructuredFileLoader from langchain.document_loaders.unstructured import UnstructuredFileLoader
from configs import PDF_OCR_THRESHOLD
from document_loaders.ocr import get_ocr from document_loaders.ocr import get_ocr
import tqdm import tqdm
@ -15,7 +16,6 @@ class RapidOCRPDFLoader(UnstructuredFileLoader):
b_unit = tqdm.tqdm(total=doc.page_count, desc="RapidOCRPDFLoader context page index: 0") b_unit = tqdm.tqdm(total=doc.page_count, desc="RapidOCRPDFLoader context page index: 0")
for i, page in enumerate(doc): for i, page in enumerate(doc):
# 更新描述 # 更新描述
b_unit.set_description("RapidOCRPDFLoader context page index: {}".format(i)) b_unit.set_description("RapidOCRPDFLoader context page index: {}".format(i))
# 立即显示进度条更新结果 # 立即显示进度条更新结果
@ -24,14 +24,20 @@ class RapidOCRPDFLoader(UnstructuredFileLoader):
text = page.get_text("") text = page.get_text("")
resp += text + "\n" resp += text + "\n"
-                img_list = page.get_images()
-                for img in img_list:
-                    pix = fitz.Pixmap(doc, img[0])
-                    img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1)
-                    result, _ = ocr(img_array)
-                    if result:
-                        ocr_result = [line[1] for line in result]
-                        resp += "\n".join(ocr_result)
+                img_list = page.get_image_info(xrefs=True)
+                for img in img_list:
+                    if xref := img.get("xref"):
+                        bbox = img["bbox"]
+                        # 检查图片尺寸是否超过设定的阈值
+                        if ((bbox[2] - bbox[0]) / (page.rect.width) < PDF_OCR_THRESHOLD[0]
+                                or (bbox[3] - bbox[1]) / (page.rect.height) < PDF_OCR_THRESHOLD[1]):
+                            continue
+                        pix = fitz.Pixmap(doc, xref)
+                        img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1)
+                        result, _ = ocr(img_array)
+                        if result:
+                            ocr_result = [line[1] for line in result]
+                            resp += "\n".join(ocr_result)
# 更新进度 # 更新进度
b_unit.update(1) b_unit.update(1)

View File

@ -0,0 +1,59 @@
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from typing import List
import tqdm
class RapidOCRPPTLoader(UnstructuredFileLoader):
def _get_elements(self) -> List:
def ppt2text(filepath):
from pptx import Presentation
from PIL import Image
import numpy as np
from io import BytesIO
from rapidocr_onnxruntime import RapidOCR
ocr = RapidOCR()
prs = Presentation(filepath)
resp = ""
def extract_text(shape):
nonlocal resp
if shape.has_text_frame:
resp += shape.text.strip() + "\n"
if shape.has_table:
for row in shape.table.rows:
for cell in row.cells:
for paragraph in cell.text_frame.paragraphs:
resp += paragraph.text.strip() + "\n"
if shape.shape_type == 13: # 13 表示图片
image = Image.open(BytesIO(shape.image.blob))
result, _ = ocr(np.array(image))
if result:
ocr_result = [line[1] for line in result]
resp += "\n".join(ocr_result)
elif shape.shape_type == 6: # 6 表示组合
for child_shape in shape.shapes:
extract_text(child_shape)
b_unit = tqdm.tqdm(total=len(prs.slides),
desc="RapidOCRPPTLoader slide index: 1")
# 遍历所有幻灯片
for slide_number, slide in enumerate(prs.slides, start=1):
b_unit.set_description(
"RapidOCRPPTLoader slide index: {}".format(slide_number))
b_unit.refresh()
sorted_shapes = sorted(slide.shapes,
key=lambda x: (x.top, x.left)) # 从上到下、从左到右遍历
for shape in sorted_shapes:
extract_text(shape)
b_unit.update(1)
return resp
text = ppt2text(self.file_path)
from unstructured.partition.text import partition_text
return partition_text(text=text, **self.unstructured_kwargs)
if __name__ == '__main__':
loader = RapidOCRPPTLoader(file_path="../tests/samples/ocr_test.pptx")
docs = loader.load()
print(docs)

View File

@ -7,31 +7,35 @@
保存的模型的位置位于原本嵌入模型的目录下模型的名称为原模型名称+Merge_Keywords_时间戳 保存的模型的位置位于原本嵌入模型的目录下模型的名称为原模型名称+Merge_Keywords_时间戳
''' '''
import sys import sys
sys.path.append("..") sys.path.append("..")
import os
import torch
from datetime import datetime from datetime import datetime
from configs import ( from configs import (
MODEL_PATH, MODEL_PATH,
EMBEDDING_MODEL, EMBEDDING_MODEL,
EMBEDDING_KEYWORD_FILE, EMBEDDING_KEYWORD_FILE,
) )
import os
import torch
from safetensors.torch import save_model from safetensors.torch import save_model
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
from langchain_core._api import deprecated
@deprecated(
since="0.3.0",
message="自定义关键词 Langchain-Chatchat 0.3.x 重写, 0.2.x中相关功能将废弃",
removal="0.3.0"
)
def get_keyword_embedding(bert_model, tokenizer, key_words): def get_keyword_embedding(bert_model, tokenizer, key_words):
tokenizer_output = tokenizer(key_words, return_tensors="pt", padding=True, truncation=True) tokenizer_output = tokenizer(key_words, return_tensors="pt", padding=True, truncation=True)
# No need to manually convert to tensor as we've set return_tensors="pt"
input_ids = tokenizer_output['input_ids'] input_ids = tokenizer_output['input_ids']
# Remove the first and last token for each sequence in the batch
input_ids = input_ids[:, 1:-1] input_ids = input_ids[:, 1:-1]
keyword_embedding = bert_model.embeddings.word_embeddings(input_ids) keyword_embedding = bert_model.embeddings.word_embeddings(input_ids)
keyword_embedding = torch.mean(keyword_embedding, 1) keyword_embedding = torch.mean(keyword_embedding, 1)
return keyword_embedding return keyword_embedding
@ -47,14 +51,11 @@ def add_keyword_to_model(model_name=EMBEDDING_MODEL, keyword_file: str = "", out
bert_model = word_embedding_model.auto_model bert_model = word_embedding_model.auto_model
tokenizer = word_embedding_model.tokenizer tokenizer = word_embedding_model.tokenizer
key_words_embedding = get_keyword_embedding(bert_model, tokenizer, key_words) key_words_embedding = get_keyword_embedding(bert_model, tokenizer, key_words)
# key_words_embedding = st_model.encode(key_words)
embedding_weight = bert_model.embeddings.word_embeddings.weight embedding_weight = bert_model.embeddings.word_embeddings.weight
embedding_weight_len = len(embedding_weight) embedding_weight_len = len(embedding_weight)
tokenizer.add_tokens(key_words) tokenizer.add_tokens(key_words)
bert_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32) bert_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32)
# key_words_embedding_tensor = torch.from_numpy(key_words_embedding)
embedding_weight = bert_model.embeddings.word_embeddings.weight embedding_weight = bert_model.embeddings.word_embeddings.weight
with torch.no_grad(): with torch.no_grad():
embedding_weight[embedding_weight_len:embedding_weight_len + key_words_len, :] = key_words_embedding embedding_weight[embedding_weight_len:embedding_weight_len + key_words_len, :] = key_words_embedding
@ -76,46 +77,3 @@ def add_keyword_to_embedding_model(path: str = EMBEDDING_KEYWORD_FILE):
output_model_name = "{}_Merge_Keywords_{}".format(EMBEDDING_MODEL, current_time) output_model_name = "{}_Merge_Keywords_{}".format(EMBEDDING_MODEL, current_time)
output_model_path = os.path.join(model_parent_directory, output_model_name) output_model_path = os.path.join(model_parent_directory, output_model_name)
add_keyword_to_model(model_name, keyword_file, output_model_path) add_keyword_to_model(model_name, keyword_file, output_model_path)
if __name__ == '__main__':
add_keyword_to_embedding_model(EMBEDDING_KEYWORD_FILE)
# input_model_name = ""
# output_model_path = ""
# # 以下为加入关键字前后tokenizer的测试用例对比
# def print_token_ids(output, tokenizer, sentences):
# for idx, ids in enumerate(output['input_ids']):
# print(f'sentence={sentences[idx]}')
# print(f'ids={ids}')
# for id in ids:
# decoded_id = tokenizer.decode(id)
# print(f' {decoded_id}->{id}')
#
# sentences = [
# '数据科学与大数据技术',
# 'Langchain-Chatchat'
# ]
#
# st_no_keywords = SentenceTransformer(input_model_name)
# tokenizer_without_keywords = st_no_keywords.tokenizer
# print("===== tokenizer with no keywords added =====")
# output = tokenizer_without_keywords(sentences)
# print_token_ids(output, tokenizer_without_keywords, sentences)
# print(f'-------- embedding with no keywords added -----')
# embeddings = st_no_keywords.encode(sentences)
# print(embeddings)
#
# print("--------------------------------------------")
# print("--------------------------------------------")
# print("--------------------------------------------")
#
# st_with_keywords = SentenceTransformer(output_model_path)
# tokenizer_with_keywords = st_with_keywords.tokenizer
# print("===== tokenizer with keyword added =====")
# output = tokenizer_with_keywords(sentences)
# print_token_ids(output, tokenizer_with_keywords, sentences)
#
# print(f'-------- embedding with keywords added -----')
# embeddings = st_with_keywords.encode(sentences)
# print(embeddings)

Binary image change: img/qr_code_85.jpg added (272 KiB, new file) and the previous group QR code image img/qr_code_82.jpg (266 KiB) removed; binary contents are not shown.

View File

@ -6,7 +6,6 @@ from configs.model_config import NLTK_DATA_PATH, EMBEDDING_MODEL
import nltk import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
from datetime import datetime from datetime import datetime
import sys
if __name__ == "__main__": if __name__ == "__main__":
@ -50,11 +49,11 @@ if __name__ == "__main__":
) )
parser.add_argument( parser.add_argument(
"-i", "-i",
"--increament", "--increment",
action="store_true", action="store_true",
help=(''' help=('''
update vector store for files exist in local folder and not exist in database. update vector store for files exist in local folder and not exist in database.
use this option if you want to create vectors increamentally. use this option if you want to create vectors incrementally.
''' '''
) )
) )
@ -100,7 +99,7 @@ if __name__ == "__main__":
if args.clear_tables: if args.clear_tables:
reset_tables() reset_tables()
print("database talbes reseted") print("database tables reset")
if args.recreate_vs: if args.recreate_vs:
create_tables() create_tables()
@ -110,8 +109,8 @@ if __name__ == "__main__":
import_from_db(args.import_db) import_from_db(args.import_db)
elif args.update_in_db: elif args.update_in_db:
folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model) folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model)
elif args.increament: elif args.increment:
folder2db(kb_names=args.kb_name, mode="increament", embed_model=args.embed_model) folder2db(kb_names=args.kb_name, mode="increment", embed_model=args.embed_model)
elif args.prune_db: elif args.prune_db:
prune_db_docs(args.kb_name) prune_db_docs(args.kb_name)
elif args.prune_folder: elif args.prune_folder:
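For reference, the renamed flag is invoked as follows (a sketch; the other options of init_database.py are unchanged):
```shell
# incrementally vectorize files that exist in the local folder but not yet in the database
$ python init_database.py -i
# long form of the same option after this commit
$ python init_database.py --increment
```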

View File

@ -1,6 +1,5 @@
# API requirements # API requirements
# On Windows system, install the cuda version manually from https://pytorch.org/
torch~=2.1.2 torch~=2.1.2
torchvision~=0.16.2 torchvision~=0.16.2
torchaudio~=2.1.2 torchaudio~=2.1.2
@ -8,30 +7,30 @@ xformers==0.0.23.post1
transformers==4.36.2 transformers==4.36.2
sentence_transformers==2.2.2 sentence_transformers==2.2.2
langchain==0.0.352 langchain==0.0.354
langchain-experimental==0.0.47 langchain-experimental==0.0.47
pydantic==1.10.13 pydantic==1.10.13
fschat==0.2.34 fschat==0.2.35
openai~=1.6.0 openai~=1.7.1
fastapi>=0.105 fastapi~=0.108.0
sse_starlette sse_starlette==1.8.2
nltk>=3.8.1 nltk>=3.8.1
uvicorn>=0.24.0.post1 uvicorn>=0.24.0.post1
starlette~=0.27.0 starlette~=0.32.0
unstructured[all-docs]==0.11.0 unstructured[all-docs]==0.11.0
python-magic-bin; sys_platform == 'win32' python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19 SQLAlchemy==2.0.19
faiss-cpu~=1.7.4 # `conda install faiss-gpu -c conda-forge` if you want to accelerate with gpus faiss-cpu~=1.7.4 # `conda install faiss-gpu -c conda-forge` if you want to accelerate with gpus
accelerate==0.24.1 accelerate~=0.24.1
spacy~=3.7.2 spacy~=3.7.2
PyMuPDF~=1.23.8 PyMuPDF~=1.23.8
rapidocr_onnxruntime==1.3.8 rapidocr_onnxruntime==1.3.8
requests>=2.31.0 requests~=2.31.0
pathlib>=1.0.1 pathlib~=1.0.1
pytest>=7.4.3 pytest~=7.4.3
numexpr>=2.8.6 # max version for py38 numexpr~=2.8.6 # max version for py38
strsimpy>=0.2.1 strsimpy~=0.2.1
markdownify>=0.11.6 markdownify~=0.11.6
tiktoken~=0.5.2 tiktoken~=0.5.2
tqdm>=4.66.1 tqdm>=4.66.1
websockets>=12.0 websockets>=12.0
@ -39,22 +38,18 @@ numpy~=1.24.4
pandas~=2.0.3 pandas~=2.0.3
einops>=0.7.0 einops>=0.7.0
transformers_stream_generator==0.0.4 transformers_stream_generator==0.0.4
vllm==0.2.6; sys_platform == "linux" vllm==0.2.7; sys_platform == "linux"
httpx[brotli,http2,socks]==0.25.2
llama-index
# optional document loaders # optional document loaders
# rapidocr_paddle[gpu]>=1.3.0.post5 # gpu accelleration for ocr of pdf and image files #rapidocr_paddle[gpu]>=1.3.0.post5 # gpu accelleration for ocr of pdf and image files
jq>=1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows jq==1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows
# html2text # for .enex files
beautifulsoup4~=4.12.2 # for .mhtml files beautifulsoup4~=4.12.2 # for .mhtml files
pysrt~=1.1.2 pysrt~=1.1.2
# Online api libs dependencies # Online api libs dependencies
# zhipuAI sdk is not supported on our platform, so use http instead
zhipuai>=1.0.7, <=2.0.0 # zhipu dashscope==1.13.6 # qwen
dashscope>=1.13.6 # qwen
# volcengine>=1.0.119 # fangzhou # volcengine>=1.0.119 # fangzhou
# uncomment libs if you want to use corresponding vector store # uncomment libs if you want to use corresponding vector store
@ -64,16 +59,18 @@ dashscope>=1.13.6 # qwen
# Agent and Search Tools # Agent and Search Tools
arxiv>=2.0.0 arxiv~=2.1.0
youtube-search>=2.1.2 youtube-search~=2.1.2
duckduckgo-search>=3.9.9 duckduckgo-search~=3.9.9
metaphor-python>=0.1.23 metaphor-python~=0.1.23
# WebUI requirements # WebUI requirements
streamlit~=1.29.0 # do remember to add streamlit to environment variables if you use windows streamlit==1.30.0
streamlit-option-menu>=0.3.6 streamlit-option-menu==0.3.6
streamlit-antd-components==0.3.1
streamlit-chatbox==1.1.11 streamlit-chatbox==1.1.11
streamlit-modal>=0.1.0 streamlit-modal==0.1.0
streamlit-aggrid>=0.3.4.post3 streamlit-aggrid==0.3.4.post3
watchdog>=3.0.0 httpx==0.26.0
watchdog==3.0.0

View File

@ -1,6 +1,3 @@
# API requirements
# On Windows system, install the cuda version manually from https://pytorch.org/
torch~=2.1.2 torch~=2.1.2
torchvision~=0.16.2 torchvision~=0.16.2
torchaudio~=2.1.2 torchaudio~=2.1.2
@ -8,52 +5,52 @@ xformers==0.0.23.post1
transformers==4.36.2 transformers==4.36.2
sentence_transformers==2.2.2 sentence_transformers==2.2.2
langchain==0.0.352 langchain==0.0.354
langchain-experimental==0.0.47 langchain-experimental==0.0.47
pydantic==1.10.13 pydantic==1.10.13
fschat==0.2.34 fschat==0.2.35
openai~=1.6.0 openai~=1.7.1
fastapi>=0.105 fastapi~=0.108.0
sse_starlette sse_starlette==1.8.2
nltk>=3.8.1 nltk>=3.8.1
uvicorn>=0.24.0.post1 uvicorn>=0.24.0.post1
starlette~=0.27.0 starlette~=0.32.0
unstructured[all-docs]==0.11.0 unstructured[all-docs]==0.11.0
python-magic-bin; sys_platform == 'win32' python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19 SQLAlchemy==2.0.19
faiss-cpu~=1.7.4 # `conda install faiss-gpu -c conda-forge` if you want to accelerate with gpus faiss-cpu~=1.7.4 # `conda install faiss-gpu -c conda-forge` if you want to accelerate with gpus
accelerate==0.24.1 accelerate~=0.24.1
spacy~=3.7.2 spacy~=3.7.2
PyMuPDF~=1.23.8 PyMuPDF~=1.23.8
rapidocr_onnxruntime~=1.3.8 rapidocr_onnxruntime==1.3.8
requests>=2.31.0 requests~=2.31.0
pathlib>=1.0.1 pathlib~=1.0.1
pytest>=7.4.3 pytest~=7.4.3
numexpr>=2.8.6 # max version for py38 numexpr~=2.8.6 # max version for py38
strsimpy>=0.2.1 strsimpy~=0.2.1
markdownify>=0.11.6 markdownify~=0.11.6
tiktoken~=0.5.2 tiktoken~=0.5.2
tqdm>=4.66.1 tqdm>=4.66.1
websockets>=12.0 websockets>=12.0
numpy~=1.26.2 numpy~=1.24.4
pandas~=2.1.4 pandas~=2.0.3
einops>=0.7.0 einops>=0.7.0
transformers_stream_generator==0.0.4 transformers_stream_generator==0.0.4
vllm==0.2.6; sys_platform == "linux" vllm==0.2.7; sys_platform == "linux"
httpx[brotli,http2,socks]~=0.25.2 httpx==0.26.0
llama-index
# optional document loaders # optional document loaders
rapidocr_paddle[gpu]>=1.3.0.post5 # gpu accelleration for ocr of pdf and image files # rapidocr_paddle[gpu]>=1.3.0.post5 # gpu accelleration for ocr of pdf and image files
jq>=1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows jq==1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows
# html2text # for .enex files
beautifulsoup4~=4.12.2 # for .mhtml files beautifulsoup4~=4.12.2 # for .mhtml files
pysrt~=1.1.2 pysrt~=1.1.2
# Online api libs dependencies # Online api libs dependencies
zhipuai>=1.0.7<=2.0.0 # zhipu # zhipuAI sdk is not supported on our platform, so use http instead
dashscope>=1.13.6 # qwen dashscope==1.13.6 # qwen
# volcengine>=1.0.119 # fangzhou # volcengine>=1.0.119 # fangzhou
# uncomment libs if you want to use corresponding vector store # uncomment libs if you want to use corresponding vector store
@ -63,7 +60,7 @@ dashscope>=1.13.6 # qwen
# Agent and Search Tools # Agent and Search Tools
arxiv>=2.0.0 arxiv~=2.1.0
youtube-search>=2.1.2 youtube-search~=2.1.2
duckduckgo-search>=3.9.9 duckduckgo-search~=3.9.9
metaphor-python>=0.1.23 metaphor-python~=0.1.23

View File

@ -1,60 +1,44 @@
# API requirements # API requirements
# On Windows system, install the cuda version manually from https://pytorch.org/ langchain==0.0.354
# torch~=2.1.2
# torchvision~=0.16.2
# torchaudio~=2.1.2
# xformers==0.0.23.post1
# transformers==4.36.2
# sentence_transformers==2.2.2
langchain==0.0.352
langchain-experimental==0.0.47 langchain-experimental==0.0.47
pydantic==1.10.13 pydantic==1.10.13
fschat==0.2.34 fschat==0.2.35
openai~=1.6.0 openai~=1.7.1
fastapi>=0.105 fastapi~=0.108.0
sse_starlette sse_starlette==1.8.2
nltk>=3.8.1 nltk>=3.8.1
uvicorn>=0.24.0.post1 uvicorn>=0.24.0.post1
starlette~=0.27.0 starlette~=0.32.0
unstructured[docx,csv]==0.11.0 # add pdf if need unstructured[all-docs]==0.11.0
python-magic-bin; sys_platform == 'win32' python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19 SQLAlchemy==2.0.19
faiss-cpu~=1.7.4 # `conda install faiss-gpu -c conda-forge` if you want to accelerate with gpus faiss-cpu~=1.7.4
# accelerate==0.24.1 requests~=2.31.0
# spacy~=3.7.2 pathlib~=1.0.1
# PyMuPDF~=1.23.8 pytest~=7.4.3
# rapidocr_onnxruntime~=1.3.8 numexpr~=2.8.6 # max version for py38
requests>=2.31.0 strsimpy~=0.2.1
pathlib>=1.0.1 markdownify~=0.11.6
pytest>=7.4.3 tiktoken~=0.5.2
numexpr>=2.8.6 # max version for py38
strsimpy>=0.2.1
markdownify>=0.11.6
# tiktoken~=0.5.2
tqdm>=4.66.1 tqdm>=4.66.1
websockets>=12.0 websockets>=12.0
numpy~=1.26.2 numpy~=1.24.4
pandas~=2.1.4 pandas~=2.0.3
# einops>=0.7.0 einops>=0.7.0
# transformers_stream_generator==0.0.4 transformers_stream_generator==0.0.4
# vllm==0.2.6; sys_platform == "linux" vllm==0.2.7; sys_platform == "linux"
httpx[brotli,http2,socks]~=0.25.2 httpx[brotli,http2,socks]==0.25.2
requests
pathlib
pytest
# optional document loaders
rapidocr_paddle[gpu]>=1.3.0.post5 # gpu accelleration for ocr of pdf and image files
jq>=1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows
# html2text # for .enex files
beautifulsoup4~=4.12.2 # for .mhtml files
pysrt~=1.1.2
# Online api libs dependencies # Online api libs dependencies
zhipuai>=1.0.7<=2.0.0 # zhipu # zhipuAI sdk is not supported on our platform, so use http instead
dashscope>=1.13.6 # qwen dashscope==1.13.6
# volcengine>=1.0.119 # fangzhou # volcengine>=1.0.119
# uncomment libs if you want to use corresponding vector store # uncomment libs if you want to use corresponding vector store
# pymilvus>=2.3.4 # pymilvus>=2.3.4
@ -63,17 +47,18 @@ dashscope>=1.13.6 # qwen
# Agent and Search Tools # Agent and Search Tools
arxiv>=2.0.0 arxiv~=2.1.0
youtube-search>=2.1.2 youtube-search~=2.1.2
duckduckgo-search>=3.9.9 duckduckgo-search~=3.9.9
metaphor-python>=0.1.23 metaphor-python~=0.1.23
# WebUI requirements # WebUI requirements
streamlit~=1.29.0 # do remember to add streamlit to environment variables if you use windows streamlit==1.30.0
streamlit-option-menu>=0.3.6 streamlit-option-menu==0.3.6
streamlit-antd-components==0.3.1
streamlit-chatbox==1.1.11 streamlit-chatbox==1.1.11
streamlit-modal>=0.1.0 streamlit-modal==0.1.0
streamlit-aggrid>=0.3.4.post3 streamlit-aggrid==0.3.4.post3
httpx[brotli,http2,socks]>=0.25.2 httpx==0.26.0
watchdog>=3.0.0 watchdog==3.0.0

View File

@ -1,9 +1,10 @@
# WebUI requirements # WebUI requirements
streamlit~=1.29.0 # do remember to add streamlit to environment variables if you use windows streamlit==1.30.0
streamlit-option-menu>=0.3.6 streamlit-option-menu==0.3.6
streamlit-antd-components==0.3.1
streamlit-chatbox==1.1.11 streamlit-chatbox==1.1.11
streamlit-modal>=0.1.0 streamlit-modal==0.1.0
streamlit-aggrid>=0.3.4.post3 streamlit-aggrid==0.3.4.post3
httpx[brotli,http2,socks]>=0.25.2 httpx==0.26.0
watchdog>=3.0.0 watchdog==3.0.0

View File

@ -1,22 +1,19 @@
""" """
This file is a modified version for ChatGLM3-6B the original ChatGLM3Agent.py file from the langchain repo. This file is a modified version for ChatGLM3-6B the original glm3_agent.py file from the langchain repo.
""" """
from __future__ import annotations from __future__ import annotations
import yaml
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.memory import ConversationBufferWindowMemory
from typing import Any, List, Sequence, Tuple, Optional, Union
import os
from langchain.agents.agent import Agent
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate, MessagesPlaceholder,
)
import json import json
import logging import logging
from typing import Any, List, Sequence, Tuple, Optional, Union
from pydantic.schema import model_schema
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.memory import ConversationBufferWindowMemory
from langchain.agents.agent import Agent
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
from langchain.agents.agent import AgentOutputParser from langchain.agents.agent import AgentOutputParser
from langchain.output_parsers import OutputFixingParser from langchain.output_parsers import OutputFixingParser
from langchain.pydantic_v1 import Field from langchain.pydantic_v1 import Field
@ -43,12 +40,18 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
first_index = min([text.find(token) if token in text else len(text) for token in special_tokens]) first_index = min([text.find(token) if token in text else len(text) for token in special_tokens])
text = text[:first_index] text = text[:first_index]
if "tool_call" in text: if "tool_call" in text:
tool_name_end = text.find("```") action_end = text.find("```")
tool_name = text[:tool_name_end].strip() action = text[:action_end].strip()
input_para = text.split("='")[-1].split("'")[0] params_str_start = text.find("(") + 1
params_str_end = text.rfind(")")
params_str = text[params_str_start:params_str_end]
params_pairs = [param.split("=") for param in params_str.split(",") if "=" in param]
params = {pair[0].strip(): pair[1].strip().strip("'\"") for pair in params_pairs}
action_json = { action_json = {
"action": tool_name, "action": action,
"action_input": input_para "action_input": params
} }
else: else:
action_json = { action_json = {
@ -109,10 +112,6 @@ class StructuredGLM3ChatAgent(Agent):
else: else:
return agent_scratchpad return agent_scratchpad
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
pass
@classmethod @classmethod
def _get_default_output_parser( def _get_default_output_parser(
cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
@ -121,7 +120,7 @@ class StructuredGLM3ChatAgent(Agent):
@property @property
def _stop(self) -> List[str]: def _stop(self) -> List[str]:
return ["```<observation>"] return ["<|observation|>"]
@classmethod @classmethod
def create_prompt( def create_prompt(
@ -131,44 +130,25 @@ class StructuredGLM3ChatAgent(Agent):
input_variables: Optional[List[str]] = None, input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None, memory_prompts: Optional[List[BasePromptTemplate]] = None,
) -> BasePromptTemplate: ) -> BasePromptTemplate:
def tool_config_from_file(tool_name, directory="server/agent/tools/"):
"""search tool yaml and return simplified json format"""
file_path = os.path.join(directory, f"{tool_name.lower()}.yaml")
try:
with open(file_path, 'r', encoding='utf-8') as file:
tool_config = yaml.safe_load(file)
# Simplify the structure if needed
simplified_config = {
"name": tool_config.get("name", ""),
"description": tool_config.get("description", ""),
"parameters": tool_config.get("parameters", {})
}
return simplified_config
except FileNotFoundError:
logger.error(f"File not found: {file_path}")
return None
except Exception as e:
logger.error(f"An error occurred while reading {file_path}: {e}")
return None
tools_json = [] tools_json = []
tool_names = [] tool_names = []
for tool in tools: for tool in tools:
-                tool_config = tool_config_from_file(tool.name)
-                if tool_config:
-                    tools_json.append(tool_config)
-                    tool_names.append(tool.name)
-
-            # Format the tools for output
+                tool_schema = model_schema(tool.args_schema) if tool.args_schema else {}
+                simplified_config_langchain = {
+                    "name": tool.name,
+                    "description": tool.description,
+                    "parameters": tool_schema.get("properties", {})
+                }
+                tools_json.append(simplified_config_langchain)
+                tool_names.append(tool.name)
formatted_tools = "\n".join([ formatted_tools = "\n".join([
f"{tool['name']}: {tool['description']}, args: {tool['parameters']}" f"{tool['name']}: {tool['description']}, args: {tool['parameters']}"
for tool in tools_json for tool in tools_json
]) ])
formatted_tools = formatted_tools.replace("'", "\\'").replace("{", "{{").replace("}", "}}") formatted_tools = formatted_tools.replace("'", "\\'").replace("{", "{{").replace("}", "}}")
template = prompt.format(tool_names=tool_names, template = prompt.format(tool_names=tool_names,
tools=formatted_tools, tools=formatted_tools,
history="{history}", history="None",
input="{input}", input="{input}",
agent_scratchpad="{agent_scratchpad}") agent_scratchpad="{agent_scratchpad}")
@ -225,7 +205,6 @@ def initialize_glm3_agent(
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
llm: BaseLanguageModel, llm: BaseLanguageModel,
prompt: str = None, prompt: str = None,
callback_manager: Optional[BaseCallbackManager] = None,
memory: Optional[ConversationBufferWindowMemory] = None, memory: Optional[ConversationBufferWindowMemory] = None,
agent_kwargs: Optional[dict] = None, agent_kwargs: Optional[dict] = None,
*, *,
@ -238,14 +217,12 @@ def initialize_glm3_agent(
llm=llm, llm=llm,
tools=tools, tools=tools,
prompt=prompt, prompt=prompt,
callback_manager=callback_manager, **agent_kwargs **agent_kwargs
) )
return AgentExecutor.from_agent_and_tools( return AgentExecutor.from_agent_and_tools(
agent=agent_obj, agent=agent_obj,
tools=tools, tools=tools,
callback_manager=callback_manager,
memory=memory, memory=memory,
tags=tags_, tags=tags_,
**kwargs, **kwargs,
) )

View File

@ -1,5 +1,3 @@
## 由于工具类无法传参,所以使用全局变量来传递模型和对应的知识库介绍
class ModelContainer: class ModelContainer:
def __init__(self): def __init__(self):
self.MODEL = None self.MODEL = None

View File

@ -3,7 +3,7 @@ from .search_knowledgebase_simple import search_knowledgebase_simple
from .search_knowledgebase_once import search_knowledgebase_once, KnowledgeSearchInput from .search_knowledgebase_once import search_knowledgebase_once, KnowledgeSearchInput
from .search_knowledgebase_complex import search_knowledgebase_complex, KnowledgeSearchInput from .search_knowledgebase_complex import search_knowledgebase_complex, KnowledgeSearchInput
from .calculate import calculate, CalculatorInput from .calculate import calculate, CalculatorInput
from .weather_check import weathercheck, WhetherSchema from .weather_check import weathercheck, WeatherInput
from .shell import shell, ShellInput from .shell import shell, ShellInput
from .search_internet import search_internet, SearchInternetInput from .search_internet import search_internet, SearchInternetInput
from .wolfram import wolfram, WolframInput from .wolfram import wolfram, WolframInput

View File

@ -1,10 +0,0 @@
name: arxiv
description: A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.
parameters:
type: object
properties:
query:
type: string
description: The search query title
required:
- query

View File

@ -1,10 +0,0 @@
name: calculate
description: Useful for when you need to answer questions about simple calculations
parameters:
type: object
properties:
query:
type: string
description: The formula to be calculated
required:
- query

View File

@ -1,10 +0,0 @@
name: search_internet
description: Use this tool to surf internet and get information
parameters:
type: object
properties:
query:
type: string
description: Query for Internet search
required:
- query

View File

@ -1,10 +0,0 @@
name: search_knowledgebase_complex
description: Use this tool to search local knowledgebase and get information
parameters:
type: object
properties:
query:
type: string
description: The query to be searched
required:
- query

View File

@ -1,10 +0,0 @@
name: search_youtube
description: Use this tools to search youtube videos
parameters:
type: object
properties:
query:
type: string
description: Query for Videos search
required:
- query

View File

@ -1,10 +0,0 @@
name: shell
description: Use Linux Shell to execute Linux commands
parameters:
type: object
properties:
query:
type: string
description: The command to execute
required:
- query

View File

@ -1,338 +1,25 @@
from __future__ import annotations
## 单独运行的时候需要添加
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
import re
import warnings
from typing import Dict
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.pydantic_v1 import Extra, root_validator
from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
import requests
from typing import List, Any, Optional
from datetime import datetime
from langchain.prompts import PromptTemplate
from server.agent import model_container
from pydantic import BaseModel, Field
## 使用和风天气API查询天气
KEY = "ac880e5a877042809ac7ffdd19d95b0d"
# key长这样这里提供了示例的key这个key没法使用你需要自己去注册和风天气的账号然后在这里填入你的key
_PROMPT_TEMPLATE = """
用户会提出一个关于天气的问题你的目标是拆分出用户问题中的区 并按照我提供的工具回答
例如 用户提出的问题是: 上海浦东未来1小时天气情况
提取的市和区是: 上海 浦东
如果用户提出的问题是: 上海未来1小时天气情况
提取的市和区是: 上海 None
请注意以下内容:
1. 如果你没有找到区的内容,则一定要使用 None 替代否则程序无法运行
2. 如果用户没有指定市 则直接返回缺少信息
问题: ${{用户的问题}}
你的回答格式应该按照下面的内容请注意格式内的```text 等标记都必须输出这是我用来提取答案的标记
```text
${{拆分的市和区中间用空格隔开}}
```
... weathercheck( )...
```output
${{提取后的答案}}
```
答案: ${{答案}}
这是一个例子
问题: 上海浦东未来1小时天气情况
```text
上海 浦东
```
...weathercheck(上海 浦东)...
```output
预报时间: 1小时后
具体时间: 今天 18:00
温度: 24°C
天气: 多云
风向: 西南风
风速: 7
湿度: 88%
降水概率: 16%
Answer: 上海浦东一小时后的天气是多云
现在这是我的问题
问题: {question}
""" """
PROMPT = PromptTemplate( 更简单的单参数输入工具实现用于查询现在天气的情况
input_variables=["question"], """
template=_PROMPT_TEMPLATE, from pydantic import BaseModel, Field
) import requests
def weather(location: str, api_key: str):
url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c"
response = requests.get(url)
if response.status_code == 200:
data = response.json()
weather = {
"temperature": data["results"][0]["now"]["temperature"],
"description": data["results"][0]["now"]["text"],
}
return weather
else:
raise Exception(
f"Failed to retrieve weather: {response.status_code}")
def get_city_info(location, adm, key): def weathercheck(location: str):
base_url = 'https://geoapi.qweather.com/v2/city/lookup?' return weather(location, "your keys")
params = {'location': location, 'adm': adm, 'key': key} class WeatherInput(BaseModel):
response = requests.get(base_url, params=params) location: str = Field(description="City name,include city and county")
data = response.json()
return data
def format_weather_data(data, place):
hourly_forecast = data['hourly']
formatted_data = f"\n 这是查询到的关于{place}未来24小时的天气信息: \n"
for forecast in hourly_forecast:
# 将预报时间转换为datetime对象
forecast_time = datetime.strptime(forecast['fxTime'], '%Y-%m-%dT%H:%M%z')
# 获取预报时间的时区
forecast_tz = forecast_time.tzinfo
# 获取当前时间(使用预报时间的时区)
now = datetime.now(forecast_tz)
# 计算预报日期与当前日期的差值
days_diff = (forecast_time.date() - now.date()).days
if days_diff == 0:
forecast_date_str = '今天'
elif days_diff == 1:
forecast_date_str = '明天'
elif days_diff == 2:
forecast_date_str = '后天'
else:
forecast_date_str = str(days_diff) + '天后'
forecast_time_str = forecast_date_str + ' ' + forecast_time.strftime('%H:%M')
# 计算预报时间与当前时间的差值
time_diff = forecast_time - now
# 将差值转换为小时
hours_diff = time_diff.total_seconds() // 3600
if hours_diff < 1:
hours_diff_str = '1小时后'
elif hours_diff >= 24:
# 如果超过24小时转换为天数
days_diff = hours_diff // 24
hours_diff_str = str(int(days_diff)) + '天'
else:
hours_diff_str = str(int(hours_diff)) + '小时'
# 将预报时间和当前时间的差值添加到输出中
formatted_data += '预报时间: ' + forecast_time_str + ' 距离现在有: ' + hours_diff_str + '\n'
formatted_data += '温度: ' + forecast['temp'] + '°C\n'
formatted_data += '天气: ' + forecast['text'] + '\n'
formatted_data += '风向: ' + forecast['windDir'] + '\n'
formatted_data += '风速: ' + forecast['windSpeed'] + '\n'
formatted_data += '湿度: ' + forecast['humidity'] + '%\n'
formatted_data += '降水概率: ' + forecast['pop'] + '%\n'
# formatted_data += '降水量: ' + forecast['precip'] + 'mm\n'
formatted_data += '\n'
return formatted_data
def get_weather(key, location_id, place):
url = "https://devapi.qweather.com/v7/weather/24h?"
params = {
'location': location_id,
'key': key,
}
response = requests.get(url, params=params)
data = response.json()
return format_weather_data(data, place)
def split_query(query):
parts = query.split()
adm = parts[0]
if len(parts) == 1:
return adm, adm
location = parts[1] if parts[1] != 'None' else adm
return location, adm
def weather(query):
location, adm = split_query(query)
key = KEY
if key == "":
return "请先在代码中填入和风天气API Key"
try:
city_info = get_city_info(location=location, adm=adm, key=key)
location_id = city_info['location'][0]['id']
place = adm + "" + location + ""
weather_data = get_weather(key=key, location_id=location_id, place=place)
return weather_data + "以上是查询到的天气信息,请你查收\n"
except KeyError:
try:
city_info = get_city_info(location=adm, adm=adm, key=key)
location_id = city_info['location'][0]['id']
place = adm + "市"
weather_data = get_weather(key=key, location_id=location_id, place=place)
return weather_data + "重要提醒:用户提供的市和区中,区的信息不存在,或者出现错别字,因此该信息是关于市的天气,请你查收\n"
except KeyError:
return "输入的地区不存在,无法提供天气预报"
class LLMWeatherChain(Chain):
llm_chain: LLMChain
llm: Optional[BaseLanguageModel] = None
"""[Deprecated] LLM wrapper to use."""
prompt: BasePromptTemplate = PROMPT
"""[Deprecated] Prompt to use to translate to python if necessary."""
input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
if "llm" in values:
warnings.warn(
"Directly instantiating an LLMWeatherChain with an llm is deprecated. "
"Please instantiate with llm_chain argument or using the from_llm "
"class method."
)
if "llm_chain" not in values and values["llm"] is not None:
prompt = values.get("prompt", PROMPT)
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
return values
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Expect output key.
:meta private:
"""
return [self.output_key]
def _evaluate_expression(self, expression: str) -> str:
try:
output = weather(expression)
except Exception as e:
output = "输入的信息有误,请再次尝试"
return output
def _process_llm_result(
self, llm_output: str, run_manager: CallbackManagerForChainRun
) -> Dict[str, str]:
run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
if text_match:
expression = text_match.group(1)
output = self._evaluate_expression(expression)
run_manager.on_text("\nAnswer: ", verbose=self.verbose)
run_manager.on_text(output, color="yellow", verbose=self.verbose)
answer = "Answer: " + output
elif llm_output.startswith("Answer:"):
answer = llm_output
elif "Answer:" in llm_output:
answer = "Answer: " + llm_output.split("Answer:")[-1]
else:
return {self.output_key: f"输入的格式不对: {llm_output},应该输入 (市 区)的组合"}
return {self.output_key: answer}
async def _aprocess_llm_result(
self,
llm_output: str,
run_manager: AsyncCallbackManagerForChainRun,
) -> Dict[str, str]:
await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
if text_match:
expression = text_match.group(1)
output = self._evaluate_expression(expression)
await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
await run_manager.on_text(output, color="yellow", verbose=self.verbose)
answer = "Answer: " + output
elif llm_output.startswith("Answer:"):
answer = llm_output
elif "Answer:" in llm_output:
answer = "Answer: " + llm_output.split("Answer:")[-1]
else:
raise ValueError(f"unknown format from LLM: {llm_output}")
return {self.output_key: answer}
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
_run_manager.on_text(inputs[self.input_key])
llm_output = self.llm_chain.predict(
question=inputs[self.input_key],
stop=["```output"],
callbacks=_run_manager.get_child(),
)
return self._process_llm_result(llm_output, _run_manager)
async def _acall(
self,
inputs: Dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
await _run_manager.on_text(inputs[self.input_key])
llm_output = await self.llm_chain.apredict(
question=inputs[self.input_key],
stop=["```output"],
callbacks=_run_manager.get_child(),
)
return await self._aprocess_llm_result(llm_output, _run_manager)
@property
def _chain_type(self) -> str:
return "llm_weather_chain"
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
prompt: BasePromptTemplate = PROMPT,
**kwargs: Any,
) -> LLMWeatherChain:
llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(llm_chain=llm_chain, **kwargs)
def weathercheck(query: str):
model = model_container.MODEL
llm_weather = LLMWeatherChain.from_llm(model, verbose=True, prompt=PROMPT)
ans = llm_weather.run(query)
return ans
class WhetherSchema(BaseModel):
location: str = Field(description="应该是一个地区的名称,用空格隔开,例如:上海 浦东,如果没有区的信息,可以只输入上海")
if __name__ == '__main__':
result = weathercheck("苏州姑苏区今晚热不热?")
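The replacement shown on the right-hand side of this hunk drops the LLM chain entirely and queries the Seniverse API directly with a single `location` argument. A minimal, self-contained sketch of that pattern (the API key below is a placeholder, not a working credential):

```python
# Hedged sketch of the simplified single-input weather tool shown above.
# "your-seniverse-key" is a placeholder; substitute a real Seniverse API key.
import requests
from pydantic import BaseModel, Field


def weather(location: str, api_key: str) -> dict:
    url = (
        "https://api.seniverse.com/v3/weather/now.json"
        f"?key={api_key}&location={location}&language=zh-Hans&unit=c"
    )
    response = requests.get(url)
    if response.status_code != 200:
        raise RuntimeError(f"Failed to retrieve weather: {response.status_code}")
    now = response.json()["results"][0]["now"]
    return {"temperature": now["temperature"], "description": now["text"]}


class WeatherInput(BaseModel):
    location: str = Field(description="City name, include city and county")


if __name__ == "__main__":
    print(weather("上海", "your-seniverse-key"))
```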

View File

@ -1,10 +0,0 @@
name: weather_check
description: Use Weather API to get weather information
parameters:
type: object
properties:
query:
type: string
description: City name,include city and county,like "厦门市思明区"
required:
- query

View File

@ -1,10 +0,0 @@
name: wolfram
description: Useful for when you need to calculate difficult math formulas
parameters:
type: object
properties:
query:
type: string
description: The formula to be calculated
required:
- query

View File

@ -1,8 +1,6 @@
from langchain.tools import Tool from langchain.tools import Tool
from server.agent.tools import * from server.agent.tools import *
## 请注意如果你是为了使用AgentLM在这里你应该使用英文版本。
tools = [ tools = [
Tool.from_function( Tool.from_function(
func=calculate, func=calculate,
@ -20,7 +18,7 @@ tools = [
func=weathercheck, func=weathercheck,
name="weather_check", name="weather_check",
description="", description="",
args_schema=WhetherSchema, args_schema=WeatherInput,
), ),
Tool.from_function( Tool.from_function(
func=shell, func=shell,
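The hunk above swaps the misspelled `WhetherSchema` for the new `WeatherInput` model. Roughly, a single-input tool is registered like this (the description text and placeholder body are illustrative, not the project's exact wording):

```python
# Sketch: wiring a single-input function into the agent's tool list with an
# explicit args_schema, following the Tool.from_function pattern used above.
from langchain.tools import Tool
from pydantic import BaseModel, Field


class WeatherInput(BaseModel):
    location: str = Field(description="City name, include city and county")


def weathercheck(location: str) -> str:
    # placeholder body; the real implementation calls the Seniverse API
    return f"weather for {location}"


weather_tool = Tool.from_function(
    func=weathercheck,
    name="weather_check",
    description="Query current weather for a given city/county",  # illustrative text
    args_schema=WeatherInput,
)
```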

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Union, Optional from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult from langchain.schema import LLMResult

View File

@ -1,23 +1,23 @@
from langchain.memory import ConversationBufferWindowMemory import json
import asyncio
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
from server.agent.tools_select import tools, tool_names
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
from langchain.agents import LLMSingleActionAgent, AgentExecutor
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
from fastapi import Body from fastapi import Body
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from langchain.chains import LLMChain
from typing import AsyncIterable, Optional
import asyncio
from typing import List
from server.chat.utils import History
import json
from server.agent import model_container
from server.knowledge_base.kb_service.base import get_kb_details
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.agents import LLMSingleActionAgent, AgentExecutor
from typing import AsyncIterable, Optional, List
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from server.knowledge_base.kb_service.base import get_kb_details
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
from server.agent.tools_select import tools, tool_names
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
from server.chat.utils import History
from server.agent import model_container
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
history: List[History] = Body([], history: List[History] = Body([],
@ -33,7 +33,6 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"), max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default", prompt_name: str = Body("default",
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
): ):
history = [History.from_data(h) for h in history] history = [History.from_data(h) for h in history]
@ -55,12 +54,10 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
callbacks=[callback], callbacks=[callback],
) )
## 传入全局变量来实现agent调用
kb_list = {x["kb_name"]: x for x in get_kb_details()} kb_list = {x["kb_name"]: x for x in get_kb_details()}
model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()} model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()}
if Agent_MODEL: if Agent_MODEL:
## 如果有指定使用Agent模型来完成任务
model_agent = get_ChatOpenAI( model_agent = get_ChatOpenAI(
model_name=Agent_MODEL, model_name=Agent_MODEL,
temperature=temperature, temperature=temperature,
@ -79,15 +76,11 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
) )
output_parser = CustomOutputParser() output_parser = CustomOutputParser()
llm_chain = LLMChain(llm=model, prompt=prompt_template_agent) llm_chain = LLMChain(llm=model, prompt=prompt_template_agent)
# 把history转成agent的memory
memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2) memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
for message in history: for message in history:
# 检查消息的角色
if message.role == 'user': if message.role == 'user':
# 添加用户消息
memory.chat_memory.add_user_message(message.content) memory.chat_memory.add_user_message(message.content)
else: else:
# 添加AI消息
memory.chat_memory.add_ai_message(message.content) memory.chat_memory.add_ai_message(message.content)
if "chatglm3" in model_container.MODEL.model_name: if "chatglm3" in model_container.MODEL.model_name:
@ -95,7 +88,6 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
llm=model, llm=model,
tools=tools, tools=tools,
callback_manager=None, callback_manager=None,
# Langchain Prompt is not constructed directly here, it is constructed inside the GLM3 agent.
prompt=prompt_template, prompt=prompt_template,
input_variables=["input", "intermediate_steps", "history"], input_variables=["input", "intermediate_steps", "history"],
memory=memory, memory=memory,
@ -155,7 +147,6 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
answer = "" answer = ""
final_answer = "" final_answer = ""
async for chunk in callback.aiter(): async for chunk in callback.aiter():
# Use server-sent-events to stream the response
data = json.loads(chunk) data = json.loads(chunk)
if data["status"] == Status.start or data["status"] == Status.complete: if data["status"] == Status.start or data["status"] == Status.complete:
continue continue
@ -181,7 +172,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
await task await task
return EventSourceResponse(agent_chat_iterator(query=query, return EventSourceResponse(agent_chat_iterator(query=query,
history=history, history=history,
model_name=model_name, model_name=model_name,
prompt_name=prompt_name), prompt_name=prompt_name),
) )
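For context, the endpoint defined above streams its answer over SSE. A hedged client-side sketch — the host, port and route are assumptions based on the project's default API server, and the model name stands in for whatever is configured in `LLM_MODELS`:

```python
# Hedged sketch: calling the agent chat endpoint and printing streamed chunks.
# 127.0.0.1:7861 and /chat/agent_chat are assumptions; adjust to your deployment.
import json
import requests

payload = {
    "query": "北京今天天气怎么样?",
    "history": [],
    "stream": True,
    "model_name": "chatglm3-6b",   # assumption: any model configured in LLM_MODELS
    "temperature": 0.7,
    "prompt_name": "default",
}
with requests.post("http://127.0.0.1:7861/chat/agent_chat",
                   json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line.startswith("data: "):
            print(json.loads(line[len("data: "):]))
```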

View File

@ -1,23 +1,23 @@
from langchain.utilities.bing_search import BingSearchAPIWrapper from langchain.utilities.bing_search import BingSearchAPIWrapper
from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY, from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE, LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE, OVERLAP_SIZE)
TEXT_SPLITTER_NAME, OVERLAP_SIZE)
from fastapi import Body
from sse_starlette import EventSourceResponse
from fastapi.concurrency import run_in_threadpool
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from langchain.chains import LLMChain from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts.chat import ChatPromptTemplate from langchain.prompts.chat import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter
from typing import List, Optional, Dict
from server.chat.utils import History
from langchain.docstore.document import Document from langchain.docstore.document import Document
from fastapi import Body
from fastapi.concurrency import run_in_threadpool
from sse_starlette import EventSourceResponse
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from server.chat.utils import History
from typing import AsyncIterable
import asyncio
import json import json
from typing import List, Optional, Dict
from strsimpy.normalized_levenshtein import NormalizedLevenshtein from strsimpy.normalized_levenshtein import NormalizedLevenshtein
from markdownify import markdownify from markdownify import markdownify
@ -38,11 +38,11 @@ def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
def metaphor_search( def metaphor_search(
text: str, text: str,
result_len: int = SEARCH_ENGINE_TOP_K, result_len: int = SEARCH_ENGINE_TOP_K,
split_result: bool = False, split_result: bool = False,
chunk_size: int = 500, chunk_size: int = 500,
chunk_overlap: int = OVERLAP_SIZE, chunk_overlap: int = OVERLAP_SIZE,
) -> List[Dict]: ) -> List[Dict]:
from metaphor_python import Metaphor from metaphor_python import Metaphor
@ -58,13 +58,13 @@ def metaphor_search(
# metaphor 返回的内容都是长文本,需要分词再检索 # metaphor 返回的内容都是长文本,需要分词再检索
if split_result: if split_result:
docs = [Document(page_content=x.extract, docs = [Document(page_content=x.extract,
metadata={"link": x.url, "title": x.title}) metadata={"link": x.url, "title": x.title})
for x in contents] for x in contents]
text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "], text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
chunk_size=chunk_size, chunk_size=chunk_size,
chunk_overlap=chunk_overlap) chunk_overlap=chunk_overlap)
splitted_docs = text_splitter.split_documents(docs) splitted_docs = text_splitter.split_documents(docs)
# 将切分好的文档放入临时向量库重新筛选出TOP_K个文档 # 将切分好的文档放入临时向量库重新筛选出TOP_K个文档
if len(splitted_docs) > result_len: if len(splitted_docs) > result_len:
normal = NormalizedLevenshtein() normal = NormalizedLevenshtein()
@ -74,13 +74,13 @@ def metaphor_search(
splitted_docs = splitted_docs[:result_len] splitted_docs = splitted_docs[:result_len]
docs = [{"snippet": x.page_content, docs = [{"snippet": x.page_content,
"link": x.metadata["link"], "link": x.metadata["link"],
"title": x.metadata["title"]} "title": x.metadata["title"]}
for x in splitted_docs] for x in splitted_docs]
else: else:
docs = [{"snippet": x.extract, docs = [{"snippet": x.extract,
"link": x.url, "link": x.url,
"title": x.title} "title": x.title}
for x in contents] for x in contents]
return docs return docs
@ -113,25 +113,27 @@ async def lookup_search_engine(
docs = search_result2docs(results) docs = search_result2docs(results)
return docs return docs
async def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]), async def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]), search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"), top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
history: List[History] = Body([], history: List[History] = Body([],
description="历史对话", description="历史对话",
examples=[[ examples=[[
{"role": "user", {"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"}, "content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant", {"role": "assistant",
"content": "虎头虎脑"}]] "content": "虎头虎脑"}]]
), ),
stream: bool = Body(False, description="流式输出"), stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"), model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"), max_tokens: Optional[int] = Body(None,
prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), description="限制LLM生成Token数量默认None代表模型最大值"),
split_result: bool = Body(False, description="是否对搜索结果进行拆分主要用于metaphor搜索引擎") prompt_name: str = Body("default",
): description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
split_result: bool = Body(False,
description="是否对搜索结果进行拆分主要用于metaphor搜索引擎")
):
if search_engine_name not in SEARCH_ENGINES.keys(): if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}") return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@ -198,9 +200,9 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
await task await task
return EventSourceResponse(search_engine_chat_iterator(query=query, return EventSourceResponse(search_engine_chat_iterator(query=query,
search_engine_name=search_engine_name, search_engine_name=search_engine_name,
top_k=top_k, top_k=top_k,
history=history, history=history,
model_name=model_name, model_name=model_name,
prompt_name=prompt_name), prompt_name=prompt_name),
) )
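The `metaphor_search` changes above keep the split-then-rerank step; in isolation it amounts to the following sketch (the chunk sizes and sample document are illustrative):

```python
# Sketch of the split-and-rerank step: long extracts are chunked, then the
# chunks most similar to the query (normalized Levenshtein) are kept.
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from strsimpy.normalized_levenshtein import NormalizedLevenshtein

query = "如何启动 api 服务"
docs = [Document(page_content="一段很长的搜索结果正文……",
                 metadata={"link": "https://example.com", "title": "demo"})]

splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
                                          chunk_size=500, chunk_overlap=50)
chunks = splitter.split_documents(docs)

normal = NormalizedLevenshtein()
chunks.sort(key=lambda d: normal.similarity(query, d.page_content), reverse=True)
top_docs = chunks[:3]
print([d.page_content[:30] for d in top_docs])
```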

View File

@ -83,7 +83,7 @@ def add_file_to_db(session,
kb_file: KnowledgeFile, kb_file: KnowledgeFile,
docs_count: int = 0, docs_count: int = 0,
custom_docs: bool = False, custom_docs: bool = False,
doc_infos: List[str] = [], # 形式:[{"id": str, "metadata": dict}, ...] doc_infos: List[Dict] = [], # 形式:[{"id": str, "metadata": dict}, ...]
): ):
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first() kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
if kb: if kb:

View File

@ -24,7 +24,7 @@ from server.knowledge_base.utils import (
list_kbs_from_folder, list_files_from_folder, list_kbs_from_folder, list_files_from_folder,
) )
from typing import List, Union, Dict, Optional from typing import List, Union, Dict, Optional, Tuple
from server.embeddings_api import embed_texts, aembed_texts, embed_documents from server.embeddings_api import embed_texts, aembed_texts, embed_documents
from server.knowledge_base.model.kb_document_model import DocumentWithVSId from server.knowledge_base.model.kb_document_model import DocumentWithVSId
@ -261,7 +261,7 @@ class KBService(ABC):
query: str, query: str,
top_k: int, top_k: int,
score_threshold: float, score_threshold: float,
) -> List[Document]: ) -> List[Tuple[Document, float]]:
""" """
搜索知识库(子类实现自己的逻辑) 搜索知识库(子类实现自己的逻辑)
""" """

View File

@ -6,6 +6,7 @@ from langchain.schema import Document
from langchain.vectorstores.elasticsearch import ElasticsearchStore from langchain.vectorstores.elasticsearch import ElasticsearchStore
from configs import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, CACHED_VS_NUM from configs import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, CACHED_VS_NUM
from server.knowledge_base.kb_service.base import KBService, SupportedVSType from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from server.knowledge_base.utils import KnowledgeFile
from server.utils import load_local_embeddings from server.utils import load_local_embeddings
from elasticsearch import Elasticsearch,BadRequestError from elasticsearch import Elasticsearch,BadRequestError
from configs import logger from configs import logger
@ -15,7 +16,7 @@ class ESKBService(KBService):
def do_init(self): def do_init(self):
self.kb_path = self.get_kb_path(self.kb_name) self.kb_path = self.get_kb_path(self.kb_name)
self.index_name = self.kb_path.split("/")[-1] self.index_name = os.path.split(self.kb_path)[-1]
self.IP = kbs_config[self.vs_type()]['host'] self.IP = kbs_config[self.vs_type()]['host']
self.PORT = kbs_config[self.vs_type()]['port'] self.PORT = kbs_config[self.vs_type()]['port']
self.user = kbs_config[self.vs_type()].get("user",'') self.user = kbs_config[self.vs_type()].get("user",'')
@ -38,7 +39,16 @@ class ESKBService(KBService):
raise e raise e
try: try:
# 首先尝试通过es_client_python创建 # 首先尝试通过es_client_python创建
self.es_client_python.indices.create(index=self.index_name) mappings = {
"properties": {
"dense_vector": {
"type": "dense_vector",
"dims": self.dims_length,
"index": True
}
}
}
self.es_client_python.indices.create(index=self.index_name, mappings=mappings)
except BadRequestError as e: except BadRequestError as e:
logger.error("创建索引失败,重新") logger.error("创建索引失败,重新")
logger.error(e) logger.error(e)
@ -80,9 +90,9 @@ class ESKBService(KBService):
except Exception as e: except Exception as e:
logger.error("创建索引失败...") logger.error("创建索引失败...")
logger.error(e) logger.error(e)
# raise e # raise e
@staticmethod @staticmethod
def get_kb_path(knowledge_base_name: str): def get_kb_path(knowledge_base_name: str):
@ -220,7 +230,12 @@ class ESKBService(KBService):
shutil.rmtree(self.kb_path) shutil.rmtree(self.kb_path)
if __name__ == '__main__':
esKBService = ESKBService("test")
#esKBService.clear_vs()
#esKBService.create_kb()
esKBService.add_doc(KnowledgeFile(filename="README.md", knowledge_base_name="test"))
print(esKBService.search_docs("如何启动api服务"))
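The patched `do_init` now creates the index with an explicit `dense_vector` mapping instead of an empty one. A hedged standalone sketch — the host and `dims` are assumptions, and `dims` must match the dimension of your embedding model:

```python
# Sketch: creating an ES index with a dense_vector mapping, as the patched code does.
from elasticsearch import Elasticsearch

es = Elasticsearch("http://127.0.0.1:9200")   # assumed local ES instance
mappings = {
    "properties": {
        "dense_vector": {
            "type": "dense_vector",
            "dims": 768,        # assumption: embedding dimension of your model
            "index": True,
        }
    }
}
es.indices.create(index="test", mappings=mappings)
```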

View File

@ -7,7 +7,7 @@ from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafe
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
from server.utils import torch_gc from server.utils import torch_gc
from langchain.docstore.document import Document from langchain.docstore.document import Document
from typing import List, Dict, Optional from typing import List, Dict, Optional, Tuple
class FaissKBService(KBService): class FaissKBService(KBService):
@ -61,7 +61,7 @@ class FaissKBService(KBService):
query: str, query: str,
top_k: int, top_k: int,
score_threshold: float = SCORE_THRESHOLD, score_threshold: float = SCORE_THRESHOLD,
) -> List[Document]: ) -> List[Tuple[Document, float]]:
embed_func = EmbeddingsFunAdapter(self.embed_model) embed_func = EmbeddingsFunAdapter(self.embed_model)
embeddings = embed_func.embed_query(query) embeddings = embed_func.embed_query(query)
with self.load_vector_store().acquire() as vs: with self.load_vector_store().acquire() as vs:

View File

@ -18,13 +18,10 @@ class MilvusKBService(KBService):
from pymilvus import Collection from pymilvus import Collection
return Collection(milvus_name) return Collection(milvus_name)
# def save_vector_store(self):
# if self.milvus.col:
# self.milvus.col.flush()
def get_doc_by_ids(self, ids: List[str]) -> List[Document]: def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
result = [] result = []
if self.milvus.col: if self.milvus.col:
# ids = [int(id) for id in ids] # for milvus if needed #pr 2725
data_list = self.milvus.col.query(expr=f'pk in {ids}', output_fields=["*"]) data_list = self.milvus.col.query(expr=f'pk in {ids}', output_fields=["*"])
for data in data_list: for data in data_list:
text = data.pop("text") text = data.pop("text")
@ -53,7 +50,7 @@ class MilvusKBService(KBService):
self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(self.embed_model), self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(self.embed_model),
collection_name=self.kb_name, collection_name=self.kb_name,
connection_args=kbs_config.get("milvus"), connection_args=kbs_config.get("milvus"),
index_params=kbs_config.ge("milvus_kwargs")["index_params"], index_params=kbs_config.get("milvus_kwargs")["index_params"],
search_params=kbs_config.get("milvus_kwargs")["search_params"] search_params=kbs_config.get("milvus_kwargs")["search_params"]
) )
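The one-character fix above (`kbs_config.ge` → `kbs_config.get`) relies on `kbs_config` carrying a `milvus_kwargs` entry. An illustrative, hypothetical shape of that configuration (the real values live in the project's kb config):

```python
# Illustrative only — actual hosts, ports and index parameters are deployment-specific.
kbs_config = {
    "milvus": {"host": "127.0.0.1", "port": "19530",
               "user": "", "password": "", "secure": False},
    "milvus_kwargs": {
        "index_params": {"metric_type": "L2", "index_type": "HNSW"},
        "search_params": {"metric_type": "L2"},
    },
}
index_params = kbs_config.get("milvus_kwargs")["index_params"]
```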

View File

@ -11,22 +11,26 @@ from server.knowledge_base.kb_service.base import SupportedVSType, KBService, Em
score_threshold_process score_threshold_process
from server.knowledge_base.utils import KnowledgeFile from server.knowledge_base.utils import KnowledgeFile
import shutil import shutil
import sqlalchemy
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session
class PGKBService(KBService): class PGKBService(KBService):
pg_vector: PGVector engine: Engine = sqlalchemy.create_engine(kbs_config.get("pg").get("connection_uri"), pool_size=10)
def _load_pg_vector(self): def _load_pg_vector(self):
self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(self.embed_model), self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(self.embed_model),
collection_name=self.kb_name, collection_name=self.kb_name,
distance_strategy=DistanceStrategy.EUCLIDEAN, distance_strategy=DistanceStrategy.EUCLIDEAN,
connection=PGKBService.engine,
connection_string=kbs_config.get("pg").get("connection_uri")) connection_string=kbs_config.get("pg").get("connection_uri"))
def get_doc_by_ids(self, ids: List[str]) -> List[Document]: def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
with self.pg_vector.connect() as connect: with Session(PGKBService.engine) as session:
stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id in :ids") stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id in :ids")
results = [Document(page_content=row[0], metadata=row[1]) for row in results = [Document(page_content=row[0], metadata=row[1]) for row in
connect.execute(stmt, parameters={'ids': ids}).fetchall()] session.execute(stmt, {'ids': ids}).fetchall()]
return results return results
# TODO: # TODO:
@ -43,8 +47,8 @@ class PGKBService(KBService):
return SupportedVSType.PG return SupportedVSType.PG
def do_drop_kb(self): def do_drop_kb(self):
with self.pg_vector.connect() as connect: with Session(PGKBService.engine) as session:
connect.execute(text(f''' session.execute(text(f'''
-- 删除 langchain_pg_embedding 表中关联到 langchain_pg_collection 表中 的记录 -- 删除 langchain_pg_embedding 表中关联到 langchain_pg_collection 表中 的记录
DELETE FROM langchain_pg_embedding DELETE FROM langchain_pg_embedding
WHERE collection_id IN ( WHERE collection_id IN (
@ -53,11 +57,10 @@ class PGKBService(KBService):
-- 删除 langchain_pg_collection 表中 记录 -- 删除 langchain_pg_collection 表中 记录
DELETE FROM langchain_pg_collection WHERE name = '{self.kb_name}'; DELETE FROM langchain_pg_collection WHERE name = '{self.kb_name}';
''')) '''))
connect.commit() session.commit()
shutil.rmtree(self.kb_path) shutil.rmtree(self.kb_path)
def do_search(self, query: str, top_k: int, score_threshold: float): def do_search(self, query: str, top_k: int, score_threshold: float):
self._load_pg_vector()
embed_func = EmbeddingsFunAdapter(self.embed_model) embed_func = EmbeddingsFunAdapter(self.embed_model)
embeddings = embed_func.embed_query(query) embeddings = embed_func.embed_query(query)
docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k) docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k)
@ -69,13 +72,13 @@ class PGKBService(KBService):
return doc_infos return doc_infos
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
with self.pg_vector.connect() as connect: with Session(PGKBService.engine) as session:
filepath = kb_file.filepath.replace('\\', '\\\\') filepath = kb_file.filepath.replace('\\', '\\\\')
connect.execute( session.execute(
text( text(
''' DELETE FROM langchain_pg_embedding WHERE cmetadata::jsonb @> '{"source": "filepath"}'::jsonb;'''.replace( ''' DELETE FROM langchain_pg_embedding WHERE cmetadata::jsonb @> '{"source": "filepath"}'::jsonb;'''.replace(
"filepath", filepath))) "filepath", filepath)))
connect.commit() session.commit()
def do_clear_vs(self): def do_clear_vs(self):
self.pg_vector.delete_collection() self.pg_vector.delete_collection()
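The PG service now shares one pooled SQLAlchemy engine and opens short-lived sessions per operation instead of going through `pg_vector.connect()`. A minimal sketch of that pattern (the connection URI and query are placeholders; the real code runs SELECT/DELETE statements against `langchain_pg_embedding`):

```python
# Sketch: one pooled engine per process, a short-lived Session per operation.
import sqlalchemy
from sqlalchemy import text
from sqlalchemy.orm import Session

engine = sqlalchemy.create_engine(
    "postgresql+psycopg2://user:password@127.0.0.1:5432/chatchat",  # placeholder URI
    pool_size=10,
)

with Session(engine) as session:
    # placeholder query; the patched service issues its SQL the same way
    print(session.execute(text("SELECT version()")).scalar())
    session.commit()
```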

View File

@ -16,13 +16,10 @@ class ZillizKBService(KBService):
from pymilvus import Collection from pymilvus import Collection
return Collection(zilliz_name) return Collection(zilliz_name)
# def save_vector_store(self):
# if self.zilliz.col:
# self.zilliz.col.flush()
def get_doc_by_ids(self, ids: List[str]) -> List[Document]: def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
result = [] result = []
if self.zilliz.col: if self.zilliz.col:
# ids = [int(id) for id in ids] # for zilliz if needed #pr 2725
data_list = self.zilliz.col.query(expr=f'pk in {ids}', output_fields=["*"]) data_list = self.zilliz.col.query(expr=f'pk in {ids}', output_fields=["*"])
for data in data_list: for data in data_list:
text = data.pop("text") text = data.pop("text")
@ -50,8 +47,7 @@ class ZillizKBService(KBService):
def _load_zilliz(self): def _load_zilliz(self):
zilliz_args = kbs_config.get("zilliz") zilliz_args = kbs_config.get("zilliz")
self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(self.embed_model), self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(self.embed_model),
collection_name=self.kb_name, connection_args=zilliz_args) collection_name=self.kb_name, connection_args=zilliz_args)
def do_init(self): def do_init(self):
self._load_zilliz() self._load_zilliz()
@ -95,9 +91,7 @@ class ZillizKBService(KBService):
if __name__ == '__main__': if __name__ == '__main__':
from server.db.base import Base, engine from server.db.base import Base, engine
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
zillizService = ZillizKBService("test") zillizService = ZillizKBService("test")

View File

@ -84,7 +84,7 @@ def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]:
def folder2db( def folder2db(
kb_names: List[str], kb_names: List[str],
mode: Literal["recreate_vs", "update_in_db", "increament"], mode: Literal["recreate_vs", "update_in_db", "increment"],
vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE, vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL, embed_model: str = EMBEDDING_MODEL,
chunk_size: int = CHUNK_SIZE, chunk_size: int = CHUNK_SIZE,
@ -97,7 +97,7 @@ def folder2db(
recreate_vs: recreate all vector store and fill info to database using existed files in local folder recreate_vs: recreate all vector store and fill info to database using existed files in local folder
fill_info_only(disabled): do not create vector store, fill info to db using existed files only fill_info_only(disabled): do not create vector store, fill info to db using existed files only
update_in_db: update vector store and database info using local files that existed in database only update_in_db: update vector store and database info using local files that existed in database only
increament: create vector store and database info for local files that not existed in database only increment: create vector store and database info for local files that not existed in database only
""" """
def files2vs(kb_name: str, kb_files: List[KnowledgeFile]): def files2vs(kb_name: str, kb_files: List[KnowledgeFile]):
@ -142,7 +142,7 @@ def folder2db(
files2vs(kb_name, kb_files) files2vs(kb_name, kb_files)
kb.save_vector_store() kb.save_vector_store()
# 对比本地目录与数据库中的文件列表,进行增量向量化 # 对比本地目录与数据库中的文件列表,进行增量向量化
elif mode == "increament": elif mode == "increment":
db_files = kb.list_files() db_files = kb.list_files()
folder_files = list_files_from_folder(kb_name) folder_files = list_files_from_folder(kb_name)
files = list(set(folder_files) - set(db_files)) files = list(set(folder_files) - set(db_files))
@ -150,7 +150,7 @@ def folder2db(
files2vs(kb_name, kb_files) files2vs(kb_name, kb_files)
kb.save_vector_store() kb.save_vector_store()
else: else:
print(f"unspported migrate mode: {mode}") print(f"unsupported migrate mode: {mode}")
def prune_db_docs(kb_names: List[str]): def prune_db_docs(kb_names: List[str]):
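With the misspelled mode renamed, an incremental migration run looks roughly like this (`samples` is assumed to be the stock example knowledge base):

```python
# Sketch: vectorize only the local files not yet recorded in the database.
from server.knowledge_base.migrate import folder2db

folder2db(kb_names=["samples"], mode="increment", vs_type="faiss")
```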

View File

@ -91,9 +91,14 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
"JSONLoader": [".json"], "JSONLoader": [".json"],
"JSONLinesLoader": [".jsonl"], "JSONLinesLoader": [".jsonl"],
"CSVLoader": [".csv"], "CSVLoader": [".csv"],
# "FilteredCSVLoader": [".csv"], # 需要自己指定,目前还没有支持 # "FilteredCSVLoader": [".csv"], 如果使用自定义分割csv
"RapidOCRPDFLoader": [".pdf"], "RapidOCRPDFLoader": [".pdf"],
"RapidOCRDocLoader": ['.docx', '.doc'],
"RapidOCRPPTLoader": ['.ppt', '.pptx', ],
"RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'], "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
"UnstructuredFileLoader": ['.eml', '.msg', '.rst',
'.rtf', '.txt', '.xml',
'.epub', '.odt','.tsv'],
"UnstructuredEmailLoader": ['.eml', '.msg'], "UnstructuredEmailLoader": ['.eml', '.msg'],
"UnstructuredEPubLoader": ['.epub'], "UnstructuredEPubLoader": ['.epub'],
"UnstructuredExcelLoader": ['.xlsx', '.xls', '.xlsd'], "UnstructuredExcelLoader": ['.xlsx', '.xls', '.xlsd'],
@ -109,7 +114,6 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
"UnstructuredXMLLoader": ['.xml'], "UnstructuredXMLLoader": ['.xml'],
"UnstructuredPowerPointLoader": ['.ppt', '.pptx'], "UnstructuredPowerPointLoader": ['.ppt', '.pptx'],
"EverNoteLoader": ['.enex'], "EverNoteLoader": ['.enex'],
"UnstructuredFileLoader": ['.txt'],
} }
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist] SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
@ -141,15 +145,14 @@ def get_LoaderClass(file_extension):
if file_extension in extensions: if file_extension in extensions:
return LoaderClass return LoaderClass
# 把一些向量化共用逻辑从KnowledgeFile抽取出来等langchain支持内存文件的时候可以将非磁盘文件向量化
def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None): def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
''' '''
根据loader_name和文件路径或内容返回文档加载器 根据loader_name和文件路径或内容返回文档加载器
''' '''
loader_kwargs = loader_kwargs or {} loader_kwargs = loader_kwargs or {}
try: try:
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader","FilteredCSVLoader"]: if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader", "FilteredCSVLoader",
"RapidOCRDocLoader", "RapidOCRPPTLoader"]:
document_loaders_module = importlib.import_module('document_loaders') document_loaders_module = importlib.import_module('document_loaders')
else: else:
document_loaders_module = importlib.import_module('langchain.document_loaders') document_loaders_module = importlib.import_module('langchain.document_loaders')
@ -258,7 +261,11 @@ def make_text_splitter(
print(e) print(e)
text_splitter_module = importlib.import_module('langchain.text_splitter') text_splitter_module = importlib.import_module('langchain.text_splitter')
TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter") TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
text_splitter = TextSplitter(chunk_size=250, chunk_overlap=50) text_splitter = TextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
# If you use SpacyTextSplitter you can use GPU to do split likes Issue #1287
# text_splitter._tokenizer.max_length = 37016792
# text_splitter._tokenizer.prefer_gpu()
return text_splitter return text_splitter
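With the OCR-based loaders registered and the fallback splitter now honoring `chunk_size`/`chunk_overlap`, the extension-to-loader mapping can be checked like this (a sketch, run from the project root):

```python
# Sketch: inspecting the extension→loader mapping after this change.
from server.knowledge_base.utils import get_LoaderClass, SUPPORTED_EXTS

print(get_LoaderClass(".docx"))   # expected: "RapidOCRDocLoader"
print(".tsv" in SUPPORTED_EXTS)   # True, now handled by UnstructuredFileLoader
```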

View File

@ -0,0 +1,51 @@
from typing import (
TYPE_CHECKING,
Any,
Tuple
)
import sys
import logging
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
import tiktoken
class MinxChatOpenAI:
@staticmethod
def import_tiktoken() -> Any:
try:
import tiktoken
except ImportError:
raise ValueError(
"Could not import tiktoken python package. "
"This is needed in order to calculate get_token_ids. "
"Please install it with `pip install tiktoken`."
)
return tiktoken
@staticmethod
def get_encoding_model(self) -> Tuple[str, "tiktoken.Encoding"]:
tiktoken_ = MinxChatOpenAI.import_tiktoken()
if self.tiktoken_model_name is not None:
model = self.tiktoken_model_name
else:
model = self.model_name
if model == "gpt-3.5-turbo":
# gpt-3.5-turbo may change over time.
# Returning num tokens assuming gpt-3.5-turbo-0301.
model = "gpt-3.5-turbo-0301"
elif model == "gpt-4":
# gpt-4 may change over time.
# Returning num tokens assuming gpt-4-0314.
model = "gpt-4-0314"
# Returns the number of tokens used by a list of messages.
try:
encoding = tiktoken_.encoding_for_model(model)
except Exception as e:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken_.get_encoding(model)
return model, encoding
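This helper is not called directly; as the `server/utils.py` hunk further down shows, it is monkey-patched onto `ChatOpenAI` so that unknown model names fall back to the `cl100k_base` encoding instead of raising:

```python
# Sketch: how the helper is wired in (mirrors the server/utils.py change below).
from langchain.chat_models import ChatOpenAI
from server.minx_chat_openai import MinxChatOpenAI

ChatOpenAI._get_encoding_model = MinxChatOpenAI.get_encoding_model
```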

View File

@ -8,3 +8,4 @@ from .qwen import QwenWorker
from .baichuan import BaiChuanWorker from .baichuan import BaiChuanWorker
from .azure import AzureWorker from .azure import AzureWorker
from .tiangong import TianGongWorker from .tiangong import TianGongWorker
from .gemini import GeminiWorker

View File

@ -0,0 +1,124 @@
import sys
from fastchat.conversation import Conversation
from server.model_workers.base import *
from server.utils import get_httpx_client
from fastchat import conversation as conv
import json,httpx
from typing import List, Dict
from configs import logger, log_verbose
class GeminiWorker(ApiModelWorker):
def __init__(
self,
*,
controller_addr: str = None,
worker_addr: str = None,
model_names: List[str] = ["Gemini-api"],
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 4096)
super().__init__(**kwargs)
def create_gemini_messages(self,messages) -> json:
has_history = any(msg['role'] == 'assistant' for msg in messages)
gemini_msg = []
for msg in messages:
role = msg['role']
content = msg['content']
if role == 'system':
continue
if has_history:
if role == 'assistant':
role = "model"
transformed_msg = {"role": role, "parts": [{"text": content}]}
else:
if role == 'user':
transformed_msg = {"parts": [{"text": content}]}
gemini_msg.append(transformed_msg)
msg = dict(contents=gemini_msg)
return msg
def do_chat(self, params: ApiChatParams) -> Dict:
params.load_config(self.model_names[0])
data = self.create_gemini_messages(messages=params.messages)
generationConfig=dict(
temperature=params.temperature,
topK=1,
topP=1,
maxOutputTokens=4096,
stopSequences=[]
)
data['generationConfig'] = generationConfig
url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent"+ '?key=' + params.api_key
headers = {
'Content-Type': 'application/json',
}
if log_verbose:
logger.info(f'{self.__class__.__name__}:url: {url}')
logger.info(f'{self.__class__.__name__}:headers: {headers}')
logger.info(f'{self.__class__.__name__}:data: {data}')
text = ""
json_string = ""
timeout = httpx.Timeout(60.0)
client=get_httpx_client(timeout=timeout)
with client.stream("POST", url, headers=headers, json=data) as response:
for line in response.iter_lines():
line = line.strip()
if not line or "[DONE]" in line:
continue
json_string += line
try:
resp = json.loads(json_string)
if 'candidates' in resp:
for candidate in resp['candidates']:
content = candidate.get('content', {})
parts = content.get('parts', [])
for part in parts:
if 'text' in part:
text += part['text']
yield {
"error_code": 0,
"text": text
}
print(text)
except json.JSONDecodeError as e:
print("Failed to decode JSON:", e)
print("Invalid JSON string:", json_string)
def get_embeddings(self, params):
# TODO: 支持embeddings
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
return conv.Conversation(
name=self.model_names[0],
system_message="You are a helpful, respectful and honest assistant.",
messages=[],
roles=["user", "assistant"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.base_model_worker import app
worker = GeminiWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:21012",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=21012)

View File

@ -28,7 +28,7 @@ class MiniMaxWorker(ApiModelWorker):
def validate_messages(self, messages: List[Dict]) -> List[Dict]: def validate_messages(self, messages: List[Dict]) -> List[Dict]:
role_maps = { role_maps = {
"user": self.user_role, "USER": self.user_role,
"assistant": self.ai_role, "assistant": self.ai_role,
"system": "system", "system": "system",
} }
@ -73,7 +73,7 @@ class MiniMaxWorker(ApiModelWorker):
with response as r: with response as r:
text = "" text = ""
for e in r.iter_text(): for e in r.iter_text():
if not e.startswith("data: "): # 真是优秀的返回 if not e.startswith("data: "):
data = { data = {
"error_code": 500, "error_code": 500,
"text": f"minimax返回错误的结果{e}", "text": f"minimax返回错误的结果{e}",
@ -140,7 +140,7 @@ class MiniMaxWorker(ApiModelWorker):
self.logger.error(f"请求 MiniMax API 时发生错误:{data}") self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
return data return data
i += batch_size i += batch_size
return {"code": 200, "data": embeddings} return {"code": 200, "data": result}
def get_embeddings(self, params): def get_embeddings(self, params):
# TODO: 支持embeddings # TODO: 支持embeddings

View File

@ -84,30 +84,6 @@ class QianFanWorker(ApiModelWorker):
def do_chat(self, params: ApiChatParams) -> Dict: def do_chat(self, params: ApiChatParams) -> Dict:
params.load_config(self.model_names[0]) params.load_config(self.model_names[0])
# import qianfan
# comp = qianfan.ChatCompletion(model=params.version,
# endpoint=params.version_url,
# ak=params.api_key,
# sk=params.secret_key,)
# text = ""
# for resp in comp.do(messages=params.messages,
# temperature=params.temperature,
# top_p=params.top_p,
# stream=True):
# if resp.code == 200:
# if chunk := resp.body.get("result"):
# text += chunk
# yield {
# "error_code": 0,
# "text": text
# }
# else:
# yield {
# "error_code": resp.code,
# "text": str(resp.body),
# }
BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \ BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
'/{model_version}?access_token={access_token}' '/{model_version}?access_token={access_token}'
@ -190,19 +166,19 @@ class QianFanWorker(ApiModelWorker):
i = 0 i = 0
batch_size = 10 batch_size = 10
while i < len(params.texts): while i < len(params.texts):
texts = params.texts[i:i+batch_size] texts = params.texts[i:i + batch_size]
resp = client.post(url, json={"input": texts}).json() resp = client.post(url, json={"input": texts}).json()
if "error_code" in resp: if "error_code" in resp:
data = { data = {
"code": resp["error_code"], "code": resp["error_code"],
"msg": resp["error_msg"], "msg": resp["error_msg"],
"error": { "error": {
"message": resp["error_msg"], "message": resp["error_msg"],
"type": "invalid_request_error", "type": "invalid_request_error",
"param": None, "param": None,
"code": None, "code": None,
} }
} }
self.logger.error(f"请求千帆 API 时发生错误:{data}") self.logger.error(f"请求千帆 API 时发生错误:{data}")
return data return data
else: else:

View File

@ -11,16 +11,15 @@ from typing import List, Literal, Dict
import requests import requests
class TianGongWorker(ApiModelWorker): class TianGongWorker(ApiModelWorker):
def __init__( def __init__(
self, self,
*, *,
controller_addr: str = None, controller_addr: str = None,
worker_addr: str = None, worker_addr: str = None,
model_names: List[str] = ["tiangong-api"], model_names: List[str] = ["tiangong-api"],
version: Literal["SkyChat-MegaVerse"] = "SkyChat-MegaVerse", version: Literal["SkyChat-MegaVerse"] = "SkyChat-MegaVerse",
**kwargs, **kwargs,
): ):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 32768) kwargs.setdefault("context_len", 32768)
@ -34,18 +33,18 @@ class TianGongWorker(ApiModelWorker):
data = { data = {
"messages": params.messages, "messages": params.messages,
"model": "SkyChat-MegaVerse" "model": "SkyChat-MegaVerse"
} }
timestamp = str(int(time.time())) timestamp = str(int(time.time()))
sign_content = params.api_key + params.secret_key + timestamp sign_content = params.api_key + params.secret_key + timestamp
sign_result = hashlib.md5(sign_content.encode('utf-8')).hexdigest() sign_result = hashlib.md5(sign_content.encode('utf-8')).hexdigest()
headers={ headers = {
"app_key": params.api_key, "app_key": params.api_key,
"timestamp": timestamp, "timestamp": timestamp,
"sign": sign_result, "sign": sign_result,
"Content-Type": "application/json", "Content-Type": "application/json",
"stream": "true" # or change to "false" 不处理流式返回内容 "stream": "true" # or change to "false" 不处理流式返回内容
} }
# 发起请求并获取响应 # 发起请求并获取响应
response = requests.post(url, headers=headers, json=data, stream=True) response = requests.post(url, headers=headers, json=data, stream=True)
@ -56,17 +55,17 @@ class TianGongWorker(ApiModelWorker):
# 处理接收到的数据 # 处理接收到的数据
# print(line.decode('utf-8')) # print(line.decode('utf-8'))
resp = json.loads(line) resp = json.loads(line)
if resp["code"] == 200: if resp["code"] == 200:
text += resp['resp_data']['reply'] text += resp['resp_data']['reply']
yield { yield {
"error_code": 0, "error_code": 0,
"text": text "text": text
} }
else: else:
data = { data = {
"error_code": resp["code"], "error_code": resp["code"],
"text": resp["code_msg"] "text": resp["code_msg"]
} }
self.logger.error(f"请求天工 API 时出错:{data}") self.logger.error(f"请求天工 API 时出错:{data}")
yield data yield data
@ -85,5 +84,3 @@ class TianGongWorker(ApiModelWorker):
sep="\n### ", sep="\n### ",
stop_str="###", stop_str="###",
) )

View File

@ -37,7 +37,7 @@ class XingHuoWorker(ApiModelWorker):
**kwargs, **kwargs,
): ):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 8000) # TODO: V1模型的最大长度为4000需要自行修改 kwargs.setdefault("context_len", 8000)
super().__init__(**kwargs) super().__init__(**kwargs)
self.version = version self.version = version

View File

@ -4,93 +4,86 @@ from fastchat import conversation as conv
import sys import sys
from typing import List, Dict, Iterator, Literal from typing import List, Dict, Iterator, Literal
from configs import logger, log_verbose from configs import logger, log_verbose
import requests
import jwt
import time
import json
def generate_token(apikey: str, exp_seconds: int):
try:
id, secret = apikey.split(".")
except Exception as e:
raise Exception("invalid apikey", e)
payload = {
"api_key": id,
"exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
"timestamp": int(round(time.time() * 1000)),
}
return jwt.encode(
payload,
secret,
algorithm="HS256",
headers={"alg": "HS256", "sign_type": "SIGN"},
)
class ChatGLMWorker(ApiModelWorker): class ChatGLMWorker(ApiModelWorker):
DEFAULT_EMBED_MODEL = "text_embedding"
def __init__( def __init__(
self, self,
*, *,
model_names: List[str] = ["zhipu-api"], model_names: List[str] = ["zhipu-api"],
controller_addr: str = None, controller_addr: str = None,
worker_addr: str = None, worker_addr: str = None,
version: Literal["chatglm_turbo"] = "chatglm_turbo", version: Literal["chatglm_turbo"] = "chatglm_turbo",
**kwargs, **kwargs,
): ):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 32768) kwargs.setdefault("context_len", 4096)
super().__init__(**kwargs) super().__init__(**kwargs)
self.version = version self.version = version
def do_chat(self, params: ApiChatParams) -> Iterator[Dict]: def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
# TODO: 维护request_id
import zhipuai
params.load_config(self.model_names[0]) params.load_config(self.model_names[0])
zhipuai.api_key = params.api_key token = generate_token(params.api_key, 60)
headers = {
if log_verbose: "Content-Type": "application/json",
logger.info(f'{self.__class__.__name__}:params: {params}') "Authorization": f"Bearer {token}"
}
response = zhipuai.model_api.sse_invoke( data = {
model=params.version, "model": params.version,
prompt=params.messages, "messages": params.messages,
temperature=params.temperature, "max_tokens": params.max_tokens,
top_p=params.top_p, "temperature": params.temperature,
incremental=False, "stream": True
) }
for e in response.events(): url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
if e.event == "add": response = requests.post(url, headers=headers, json=data, stream=True)
yield {"error_code": 0, "text": e.data} for chunk in response.iter_lines():
elif e.event in ["error", "interrupted"]: if chunk:
data = { chunk_str = chunk.decode('utf-8')
"error_code": 500, json_start_pos = chunk_str.find('{"id"')
"text": e.data, if json_start_pos != -1:
"error": { json_str = chunk_str[json_start_pos:]
"message": e.data, json_data = json.loads(json_str)
"type": "invalid_request_error", for choice in json_data.get('choices', []):
"param": None, delta = choice.get('delta', {})
"code": None, content = delta.get('content', '')
} yield {"error_code": 0, "text": content}
}
self.logger.error(f"请求智谱 API 时发生错误:{data}")
yield data
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
import zhipuai
params.load_config(self.model_names[0])
zhipuai.api_key = params.api_key
embeddings = []
try:
for t in params.texts:
response = zhipuai.model_api.invoke(model=params.embed_model or self.DEFAULT_EMBED_MODEL, prompt=t)
if response["code"] == 200:
embeddings.append(response["data"]["embedding"])
else:
self.logger.error(f"请求智谱 API 时发生错误:{response}")
return response # dict with code & msg
except Exception as e:
self.logger.error(f"请求智谱 API 时发生错误:{data}")
data = {"code": 500, "msg": f"对文本向量化时出错:{e}"}
return data
return {"code": 200, "data": embeddings}
def get_embeddings(self, params): def get_embeddings(self, params):
# TODO: 支持embeddings # 临时解决方案不支持embedding
print("embedding") print("embedding")
# print(params) print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# 这里的是chatglm api的模板其它API的conv_template需要定制
return conv.Conversation( return conv.Conversation(
name=self.model_names[0], name=self.model_names[0],
system_message="你是一个聪明的助手,请根据用户的提示来完成任务", system_message="你是智谱AI小助手请根据用户的提示来完成任务",
messages=[], messages=[],
roles=["Human", "Assistant", "System"], roles=["user", "assistant", "system"],
sep="\n###", sep="\n###",
stop_str="###", stop_str="###",
) )
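The rewritten worker authenticates against the bigmodel.cn endpoint with a short-lived JWT built from the `id.secret` API key, as in `generate_token` above. A standalone sketch (the key is a fake placeholder):

```python
# Sketch: building the Bearer token the new ZhipuAI endpoint expects.
import time
import jwt  # PyJWT


def generate_token(apikey: str, exp_seconds: int) -> str:
    key_id, secret = apikey.split(".")
    now_ms = int(round(time.time() * 1000))
    payload = {"api_key": key_id,
               "exp": now_ms + exp_seconds * 1000,
               "timestamp": now_ms}
    return jwt.encode(payload, secret, algorithm="HS256",
                      headers={"alg": "HS256", "sign_type": "SIGN"})


token = generate_token("fake-id.fake-secret", 60)
headers = {"Content-Type": "application/json",
           "Authorization": f"Bearer {token}"}
```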

File diff suppressed because one or more lines are too long

View File

@ -10,10 +10,24 @@ from configs import (LLM_MODELS, LLM_DEVICE, EMBEDDING_DEVICE,
import os import os
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI, AzureOpenAI, Anthropic from langchain.llms import OpenAI
import httpx import httpx
from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union, Tuple from typing import (
TYPE_CHECKING,
Literal,
Optional,
Callable,
Generator,
Dict,
Any,
Awaitable,
Union,
Tuple
)
import logging import logging
import torch
from server.minx_chat_openai import MinxChatOpenAI
async def wrap_done(fn: Awaitable, event: asyncio.Event): async def wrap_done(fn: Awaitable, event: asyncio.Event):
@ -43,7 +57,7 @@ def get_ChatOpenAI(
config = get_model_worker_config(model_name) config = get_model_worker_config(model_name)
if model_name == "openai-api": if model_name == "openai-api":
model_name = config.get("model_name") model_name = config.get("model_name")
ChatOpenAI._get_encoding_model = MinxChatOpenAI.get_encoding_model
model = ChatOpenAI( model = ChatOpenAI(
streaming=streaming, streaming=streaming,
verbose=verbose, verbose=verbose,
@ -58,6 +72,7 @@ def get_ChatOpenAI(
) )
return model return model
def get_OpenAI( def get_OpenAI(
model_name: str, model_name: str,
temperature: float, temperature: float,
@ -488,16 +503,12 @@ def set_httpx_config(
no_proxy.append(host) no_proxy.append(host)
os.environ["NO_PROXY"] = ",".join(no_proxy) os.environ["NO_PROXY"] = ",".join(no_proxy)
# TODO: 简单的清除系统代理不是个好的选择影响太多。似乎修改代理服务器的bypass列表更好。
# patch requests to use custom proxies instead of system settings
def _get_proxies(): def _get_proxies():
return proxies return proxies
import urllib.request import urllib.request
urllib.request.getproxies = _get_proxies urllib.request.getproxies = _get_proxies
# 自动检查torch可用的设备。分布式部署时不运行LLM的机器上可以不装torch
def detect_device() -> Literal["cuda", "mps", "cpu"]: def detect_device() -> Literal["cuda", "mps", "cpu"]:
try: try:

View File

@ -6,9 +6,8 @@ import sys
from multiprocessing import Process from multiprocessing import Process
from datetime import datetime from datetime import datetime
from pprint import pprint from pprint import pprint
from langchain_core._api import deprecated
# 设置numexpr最大线程数默认为CPU核心数
try: try:
import numexpr import numexpr
@ -33,15 +32,18 @@ from configs import (
HTTPX_DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT,
) )
from server.utils import (fschat_controller_address, fschat_model_worker_address, from server.utils import (fschat_controller_address, fschat_model_worker_address,
fschat_openai_api_address, set_httpx_config, get_httpx_client, fschat_openai_api_address, get_httpx_client, get_model_worker_config,
get_model_worker_config, get_all_model_worker_configs,
MakeFastAPIOffline, FastAPI, llm_device, embedding_device) MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
from server.knowledge_base.migrate import create_tables from server.knowledge_base.migrate import create_tables
import argparse import argparse
from typing import Tuple, List, Dict from typing import List, Dict
from configs import VERSION from configs import VERSION
@deprecated(
since="0.3.0",
message="模型启动功能将于 Langchain-Chatchat 0.3.x重写,支持更多模式和加速启动0.2.x中相关功能将废弃",
removal="0.3.0")
def create_controller_app( def create_controller_app(
dispatch_method: str, dispatch_method: str,
log_level: str = "INFO", log_level: str = "INFO",
@ -88,7 +90,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
for k, v in kwargs.items(): for k, v in kwargs.items():
setattr(args, k, v) setattr(args, k, v)
if worker_class := kwargs.get("langchain_model"): #models supported by Langchain need no extra handling if worker_class := kwargs.get("langchain_model"): # models supported by Langchain need no extra handling
from fastchat.serve.base_model_worker import app from fastchat.serve.base_model_worker import app
worker = "" worker = ""
# online model APIs # online model APIs
@ -107,12 +109,12 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
import fastchat.serve.vllm_worker import fastchat.serve.vllm_worker
from fastchat.serve.vllm_worker import VLLMWorker, app, worker_id from fastchat.serve.vllm_worker import VLLMWorker, app, worker_id
from vllm import AsyncLLMEngine from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs,EngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
args.tokenizer = args.model_path # if the tokenizer differs from model_path, set it here args.tokenizer = args.model_path
args.tokenizer_mode = 'auto' args.tokenizer_mode = 'auto'
args.trust_remote_code= True args.trust_remote_code = True
args.download_dir= None args.download_dir = None
args.load_format = 'auto' args.load_format = 'auto'
args.dtype = 'auto' args.dtype = 'auto'
args.seed = 0 args.seed = 0
@ -122,13 +124,13 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
args.block_size = 16 args.block_size = 16
args.swap_space = 4 # GiB args.swap_space = 4 # GiB
args.gpu_memory_utilization = 0.90 args.gpu_memory_utilization = 0.90
args.max_num_batched_tokens = None # maximum number of tokens in one batch; depends on your GPU and model settings, too large a value will exhaust VRAM args.max_num_batched_tokens = None # maximum number of tokens in one batch; depends on your GPU and model settings, too large a value will exhaust VRAM
args.max_num_seqs = 256 args.max_num_seqs = 256
args.disable_log_stats = False args.disable_log_stats = False
args.conv_template = None args.conv_template = None
args.limit_worker_concurrency = 5 args.limit_worker_concurrency = 5
args.no_register = False args.no_register = False
args.num_gpus = 1 # the vllm worker is split with tensor parallelism; set this to the number of GPUs args.num_gpus = 1 # the vllm worker is split with tensor parallelism; set this to the number of GPUs
args.engine_use_ray = False args.engine_use_ray = False
args.disable_log_requests = False args.disable_log_requests = False
@ -138,10 +140,10 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
args.quantization = None args.quantization = None
args.max_log_len = None args.max_log_len = None
args.tokenizer_revision = None args.tokenizer_revision = None
# parameters newly required by vllm 0.2.2 # parameters newly required by vllm 0.2.2
args.max_paddings = 256 args.max_paddings = 256
if args.model_path: if args.model_path:
args.model = args.model_path args.model = args.model_path
if args.num_gpus > 1: if args.num_gpus > 1:
@ -154,16 +156,16 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
engine = AsyncLLMEngine.from_engine_args(engine_args) engine = AsyncLLMEngine.from_engine_args(engine_args)
worker = VLLMWorker( worker = VLLMWorker(
controller_addr = args.controller_address, controller_addr=args.controller_address,
worker_addr = args.worker_address, worker_addr=args.worker_address,
worker_id = worker_id, worker_id=worker_id,
model_path = args.model_path, model_path=args.model_path,
model_names = args.model_names, model_names=args.model_names,
limit_worker_concurrency = args.limit_worker_concurrency, limit_worker_concurrency=args.limit_worker_concurrency,
no_register = args.no_register, no_register=args.no_register,
llm_engine = engine, llm_engine=engine,
conv_template = args.conv_template, conv_template=args.conv_template,
) )
sys.modules["fastchat.serve.vllm_worker"].engine = engine sys.modules["fastchat.serve.vllm_worker"].engine = engine
sys.modules["fastchat.serve.vllm_worker"].worker = worker sys.modules["fastchat.serve.vllm_worker"].worker = worker
sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level) sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level)
@ -171,7 +173,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
else: else:
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
args.gpus = "0" # GPU的编号,如果有多个GPU可以设置为"0,1,2,3" args.gpus = "0" # GPU的编号,如果有多个GPU可以设置为"0,1,2,3"
args.max_gpu_memory = "22GiB" args.max_gpu_memory = "22GiB"
args.num_gpus = 1 # the model worker is split with model parallelism; set this to the number of GPUs args.num_gpus = 1 # the model worker is split with model parallelism; set this to the number of GPUs
@ -325,7 +327,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
with get_httpx_client() as client: with get_httpx_client() as client:
r = client.post(worker_address + "/release", r = client.post(worker_address + "/release",
json={"new_model_name": new_model_name, "keep_origin": keep_origin}) json={"new_model_name": new_model_name, "keep_origin": keep_origin})
if r.status_code != 200: if r.status_code != 200:
msg = f"failed to release model: {model_name}" msg = f"failed to release model: {model_name}"
logger.error(msg) logger.error(msg)
@ -393,8 +395,8 @@ def run_model_worker(
# add interface to release and load model # add interface to release and load model
@app.post("/release") @app.post("/release")
def release_model( def release_model(
new_model_name: str = Body(None, description="load this model after releasing the current one"), new_model_name: str = Body(None, description="load this model after releasing the current one"),
keep_origin: bool = Body(False, description="load the new model without releasing the original one") keep_origin: bool = Body(False, description="load the new model without releasing the original one")
) -> Dict: ) -> Dict:
if keep_origin: if keep_origin:
if new_model_name: if new_model_name:
@ -450,13 +452,13 @@ def run_webui(started_event: mp.Event = None, run_mode: str = None):
port = WEBUI_SERVER["port"] port = WEBUI_SERVER["port"]
cmd = ["streamlit", "run", "webui.py", cmd = ["streamlit", "run", "webui.py",
"--server.address", host, "--server.address", host,
"--server.port", str(port), "--server.port", str(port),
"--theme.base", "light", "--theme.base", "light",
"--theme.primaryColor", "#165dff", "--theme.primaryColor", "#165dff",
"--theme.secondaryBackgroundColor", "#f5f5f5", "--theme.secondaryBackgroundColor", "#f5f5f5",
"--theme.textColor", "#000000", "--theme.textColor", "#000000",
] ]
if run_mode == "lite": if run_mode == "lite":
cmd += [ cmd += [
"--", "--",
@ -605,8 +607,10 @@ async def start_main_server():
Python 3.9 has `signal.strsignal(signalnum)` so this closure would not be needed. Python 3.9 has `signal.strsignal(signalnum)` so this closure would not be needed.
Also, 3.8 includes `signal.valid_signals()` that can be used to create a mapping for the same purpose. Also, 3.8 includes `signal.valid_signals()` that can be used to create a mapping for the same purpose.
""" """
def f(signal_received, frame): def f(signal_received, frame):
raise KeyboardInterrupt(f"{signalname} received") raise KeyboardInterrupt(f"{signalname} received")
return f return f
# This will be inherited by the child process if it is forked (not spawned) # This will be inherited by the child process if it is forked (not spawned)
@ -701,8 +705,8 @@ async def start_main_server():
for model_name in args.model_name: for model_name in args.model_name:
config = get_model_worker_config(model_name) config = get_model_worker_config(model_name)
if (config.get("online_api") if (config.get("online_api")
and config.get("worker_class") and config.get("worker_class")
and model_name in FSCHAT_MODEL_WORKERS): and model_name in FSCHAT_MODEL_WORKERS):
e = manager.Event() e = manager.Event()
model_worker_started.append(e) model_worker_started.append(e)
process = Process( process = Process(
@ -742,12 +746,12 @@ async def start_main_server():
else: else:
try: try:
# make sure tasks can exit cleanly after receiving SIGINT # make sure tasks can exit cleanly after receiving SIGINT
if p:= processes.get("controller"): if p := processes.get("controller"):
p.start() p.start()
p.name = f"{p.name} ({p.pid})" p.name = f"{p.name} ({p.pid})"
controller_started.wait() # wait for the controller to finish starting controller_started.wait() # wait for the controller to finish starting
if p:= processes.get("openai_api"): if p := processes.get("openai_api"):
p.start() p.start()
p.name = f"{p.name} ({p.pid})" p.name = f"{p.name} ({p.pid})"
@ -763,24 +767,24 @@ async def start_main_server():
for e in model_worker_started: for e in model_worker_started:
e.wait() e.wait()
if p:= processes.get("api"): if p := processes.get("api"):
p.start() p.start()
p.name = f"{p.name} ({p.pid})" p.name = f"{p.name} ({p.pid})"
api_started.wait() # wait for api.py to finish starting api_started.wait() # wait for api.py to finish starting
if p:= processes.get("webui"): if p := processes.get("webui"):
p.start() p.start()
p.name = f"{p.name} ({p.pid})" p.name = f"{p.name} ({p.pid})"
webui_started.wait() # wait for webui.py to finish starting webui_started.wait() # wait for webui.py to finish starting
dump_server_info(after_start=True, args=args) dump_server_info(after_start=True, args=args)
while True: while True:
cmd = queue.get() # received a model-switch message cmd = queue.get() # received a model-switch message
e = manager.Event() e = manager.Event()
if isinstance(cmd, list): if isinstance(cmd, list):
model_name, cmd, new_model_name = cmd model_name, cmd, new_model_name = cmd
if cmd == "start": # 运行新模型 if cmd == "start": # 运行新模型
logger.info(f"准备启动新模型进程:{new_model_name}") logger.info(f"准备启动新模型进程:{new_model_name}")
process = Process( process = Process(
target=run_model_worker, target=run_model_worker,
@ -831,7 +835,6 @@ async def start_main_server():
else: else:
logger.error(f"未找到模型进程:{model_name}") logger.error(f"未找到模型进程:{model_name}")
# for process in processes.get("model_worker", {}).values(): # for process in processes.get("model_worker", {}).values():
# process.join() # process.join()
# for process in processes.get("online_api", {}).values(): # for process in processes.get("online_api", {}).values():
@ -866,10 +869,9 @@ async def start_main_server():
for p in processes.values(): for p in processes.values():
logger.info("Process status: %s", p) logger.info("Process status: %s", p)
if __name__ == "__main__":
# make sure the database tables are created
create_tables()
if __name__ == "__main__":
create_tables()
if sys.version_info < (3, 10): if sys.version_info < (3, 10):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
else: else:
@ -879,16 +881,15 @@ if __name__ == "__main__":
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
# run the coroutine synchronously
loop.run_until_complete(start_main_server())
loop.run_until_complete(start_main_server())
# Example API call after the services have started: # Example API call after the services have started:
# import openai # import openai
# openai.api_key = "EMPTY" # Not support yet # openai.api_key = "EMPTY" # Not support yet
# openai.api_base = "http://localhost:8888/v1" # openai.api_base = "http://localhost:8888/v1"
# model = "chatglm2-6b" # model = "chatglm3-6b"
# # create a chat completion # # create a chat completion
# completion = openai.ChatCompletion.create( # completion = openai.ChatCompletion.create(

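The commented example at the end of startup.py is truncated by the diff context. A self-contained version using the legacy `openai<1.0` client named in those comments might look like the sketch below; the port and model name follow the comments and depend on your local configuration:

```python
import openai

openai.api_key = "EMPTY"                      # the local endpoint ignores the key
openai.api_base = "http://localhost:8888/v1"  # fastchat OpenAI-compatible API

completion = openai.ChatCompletion.create(
    model="chatglm3-6b",
    messages=[{"role": "user", "content": "Hello"}],
)
print(completion.choices[0].message.content)
```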
BIN
tests/samples/ocr_test.docx Normal file

Binary file not shown.

BIN
tests/samples/ocr_test.pptx Normal file

Binary file not shown.

View File

@ -14,8 +14,7 @@ from server.knowledge_base.migrate import folder2db, prune_db_docs, prune_folder
# setup test knowledge base # setup test knowledge base
kb_name = "test_kb_for_migrate" kb_name = "test_kb_for_migrate"
test_files = { test_files = {
"faq.md": str(root_path / "docs" / "faq.md"), "readme.md": str(root_path / "readme.md"),
"install.md": str(root_path / "docs" / "install.md"),
} }
@ -56,13 +55,13 @@ def test_recreate_vs():
assert doc.metadata["source"] == name assert doc.metadata["source"] == name
def test_increament(): def test_increment():
kb = KBServiceFactory.get_service_by_name(kb_name) kb = KBServiceFactory.get_service_by_name(kb_name)
kb.clear_vs() kb.clear_vs()
assert kb.list_files() == [] assert kb.list_files() == []
assert kb.list_docs() == [] assert kb.list_docs() == []
folder2db([kb_name], "increament") folder2db([kb_name], "increment")
files = kb.list_files() files = kb.list_files()
print(files) print(files)
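Besides switching the test fixture to `readme.md`, this hunk fixes the `increament` spelling to `increment` in both the test name and the migration mode. A short usage sketch of the call the test exercises, with the test's knowledge-base name used for illustration:

```python
from server.knowledge_base.migrate import folder2db

# "increment" mode only vectorizes files present in the knowledge-base folder
# that are not yet recorded in the database.
folder2db(["test_kb_for_migrate"], "increment")
```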

View File

@ -1,7 +1,6 @@
# This file wraps requests to api.py and can be reused by different webuis # This file wraps requests to api.py and can be reused by different webuis
# ApiRequest and AsyncApiRequest provide synchronous / asynchronous calls # ApiRequest and AsyncApiRequest provide synchronous / asynchronous calls
from typing import * from typing import *
from pathlib import Path from pathlib import Path
# The configs imported here are those of the requesting machine (e.g. the WEBUI host), mainly used to set frontend defaults; in distributed deployments they may differ from the server's # The configs imported here are those of the requesting machine (e.g. the WEBUI host), mainly used to set frontend defaults; in distributed deployments they may differ from the server's
@ -27,7 +26,7 @@ from io import BytesIO
from server.utils import set_httpx_config, api_address, get_httpx_client from server.utils import set_httpx_config, api_address, get_httpx_client
from pprint import pprint from pprint import pprint
from langchain_core._api import deprecated
set_httpx_config() set_httpx_config()
@ -36,10 +35,11 @@ class ApiRequest:
''' '''
Wrapper around api.py calls (synchronous mode) that simplifies API usage Wrapper around api.py calls (synchronous mode) that simplifies API usage
''' '''
def __init__( def __init__(
self, self,
base_url: str = api_address(), base_url: str = api_address(),
timeout: float = HTTPX_DEFAULT_TIMEOUT, timeout: float = HTTPX_DEFAULT_TIMEOUT,
): ):
self.base_url = base_url self.base_url = base_url
self.timeout = timeout self.timeout = timeout
@ -55,12 +55,12 @@ class ApiRequest:
return self._client return self._client
def get( def get(
self, self,
url: str, url: str,
params: Union[Dict, List[Tuple], bytes] = None, params: Union[Dict, List[Tuple], bytes] = None,
retry: int = 3, retry: int = 3,
stream: bool = False, stream: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Union[httpx.Response, Iterator[httpx.Response], None]: ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
while retry > 0: while retry > 0:
try: try:
@ -75,13 +75,13 @@ class ApiRequest:
retry -= 1 retry -= 1
def post( def post(
self, self,
url: str, url: str,
data: Dict = None, data: Dict = None,
json: Dict = None, json: Dict = None,
retry: int = 3, retry: int = 3,
stream: bool = False, stream: bool = False,
**kwargs: Any **kwargs: Any
) -> Union[httpx.Response, Iterator[httpx.Response], None]: ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
while retry > 0: while retry > 0:
try: try:
@ -97,13 +97,13 @@ class ApiRequest:
retry -= 1 retry -= 1
def delete( def delete(
self, self,
url: str, url: str,
data: Dict = None, data: Dict = None,
json: Dict = None, json: Dict = None,
retry: int = 3, retry: int = 3,
stream: bool = False, stream: bool = False,
**kwargs: Any **kwargs: Any
) -> Union[httpx.Response, Iterator[httpx.Response], None]: ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
while retry > 0: while retry > 0:
try: try:
@ -118,30 +118,33 @@ class ApiRequest:
retry -= 1 retry -= 1
def _httpx_stream2generator( def _httpx_stream2generator(
self, self,
response: contextlib._GeneratorContextManager, response: contextlib._GeneratorContextManager,
as_json: bool = False, as_json: bool = False,
): ):
''' '''
Convert the GeneratorContextManager returned by httpx.stream into a plain generator Convert the GeneratorContextManager returned by httpx.stream into a plain generator
''' '''
async def ret_async(response, as_json): async def ret_async(response, as_json):
try: try:
async with response as r: async with response as r:
async for chunk in r.aiter_text(None): async for chunk in r.aiter_text(None):
if not chunk: # fastchat api yield empty bytes on start and end if not chunk: # fastchat api yield empty bytes on start and end
continue continue
if as_json: if as_json:
try: try:
if chunk.startswith("data: "): if chunk.startswith("data: "):
data = json.loads(chunk[6:-2]) data = json.loads(chunk[6:-2])
elif chunk.startswith(":"): # skip sse comment line
continue
else: else:
data = json.loads(chunk) data = json.loads(chunk)
yield data yield data
except Exception as e: except Exception as e:
msg = f"接口返回json错误 {chunk}’。错误信息是:{e}" msg = f"接口返回json错误 {chunk}’。错误信息是:{e}"
logger.error(f'{e.__class__.__name__}: {msg}', logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None) exc_info=e if log_verbose else None)
else: else:
# print(chunk, end="", flush=True) # print(chunk, end="", flush=True)
yield chunk yield chunk
@ -156,26 +159,28 @@ class ApiRequest:
except Exception as e: except Exception as e:
msg = f"API通信遇到错误{e}" msg = f"API通信遇到错误{e}"
logger.error(f'{e.__class__.__name__}: {msg}', logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None) exc_info=e if log_verbose else None)
yield {"code": 500, "msg": msg} yield {"code": 500, "msg": msg}
def ret_sync(response, as_json): def ret_sync(response, as_json):
try: try:
with response as r: with response as r:
for chunk in r.iter_text(None): for chunk in r.iter_text(None):
if not chunk: # fastchat api yield empty bytes on start and end if not chunk: # fastchat api yield empty bytes on start and end
continue continue
if as_json: if as_json:
try: try:
if chunk.startswith("data: "): if chunk.startswith("data: "):
data = json.loads(chunk[6:-2]) data = json.loads(chunk[6:-2])
elif chunk.startswith(":"): # skip sse comment line
continue
else: else:
data = json.loads(chunk) data = json.loads(chunk)
yield data yield data
except Exception as e: except Exception as e:
msg = f"接口返回json错误 {chunk}’。错误信息是:{e}" msg = f"接口返回json错误 {chunk}’。错误信息是:{e}"
logger.error(f'{e.__class__.__name__}: {msg}', logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None) exc_info=e if log_verbose else None)
else: else:
# print(chunk, end="", flush=True) # print(chunk, end="", flush=True)
yield chunk yield chunk
@ -190,7 +195,7 @@ class ApiRequest:
except Exception as e: except Exception as e:
msg = f"API通信遇到错误{e}" msg = f"API通信遇到错误{e}"
logger.error(f'{e.__class__.__name__}: {msg}', logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None) exc_info=e if log_verbose else None)
yield {"code": 500, "msg": msg} yield {"code": 500, "msg": msg}
if self._use_async: if self._use_async:
@ -199,16 +204,17 @@ class ApiRequest:
return ret_sync(response, as_json) return ret_sync(response, as_json)
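The new `chunk.startswith(":")` branches above skip SSE comment lines (such as keep-alive pings) instead of failing to parse them as JSON. The per-chunk rule, pulled out as a standalone sketch:

```python
import json
from typing import Optional

def parse_sse_chunk(chunk: str) -> Optional[dict]:
    if not chunk:                       # fastchat yields empty chunks at start and end
        return None
    if chunk.startswith(":"):           # SSE comment line, e.g. ": ping"; skip it
        return None
    if chunk.startswith("data: "):
        return json.loads(chunk[6:-2])  # strip the "data: " prefix and the trailing "\n\n"
    return json.loads(chunk)
```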
def _get_response_value( def _get_response_value(
self, self,
response: httpx.Response, response: httpx.Response,
as_json: bool = False, as_json: bool = False,
value_func: Callable = None, value_func: Callable = None,
): ):
''' '''
Convert the response returned by a synchronous or asynchronous request Convert the response returned by a synchronous or asynchronous request
`as_json`: return JSON `as_json`: return JSON
`value_func`: user-defined return value; the function receives the response or its JSON `value_func`: user-defined return value; the function receives the response or its JSON
''' '''
def to_json(r): def to_json(r):
try: try:
return r.json() return r.json()
@ -216,7 +222,7 @@ class ApiRequest:
msg = "API未能返回正确的JSON。" + str(e) msg = "API未能返回正确的JSON。" + str(e)
if log_verbose: if log_verbose:
logger.error(f'{e.__class__.__name__}: {msg}', logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None) exc_info=e if log_verbose else None)
return {"code": 500, "msg": msg, "data": None} return {"code": 500, "msg": msg, "data": None}
if value_func is None: if value_func is None:
@ -246,10 +252,10 @@ class ApiRequest:
return self._get_response_value(response, as_json=True, value_func=lambda r: r["data"]) return self._get_response_value(response, as_json=True, value_func=lambda r: r["data"])
def get_prompt_template( def get_prompt_template(
self, self,
type: str = "llm_chat", type: str = "llm_chat",
name: str = "default", name: str = "default",
**kwargs, **kwargs,
) -> str: ) -> str:
data = { data = {
"type": type, "type": type,
@ -293,15 +299,19 @@ class ApiRequest:
response = self.post("/chat/chat", json=data, stream=True, **kwargs) response = self.post("/chat/chat", json=data, stream=True, **kwargs)
return self._httpx_stream2generator(response, as_json=True) return self._httpx_stream2generator(response, as_json=True)
@deprecated(
since="0.3.0",
message="自定义Agent问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
removal="0.3.0")
def agent_chat( def agent_chat(
self, self,
query: str, query: str,
history: List[Dict] = [], history: List[Dict] = [],
stream: bool = True, stream: bool = True,
model: str = LLM_MODELS[0], model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE, temperature: float = TEMPERATURE,
max_tokens: int = None, max_tokens: int = None,
prompt_name: str = "default", prompt_name: str = "default",
): ):
''' '''
Corresponds to the api.py /chat/agent_chat endpoint Corresponds to the api.py /chat/agent_chat endpoint
@ -320,20 +330,20 @@ class ApiRequest:
# pprint(data) # pprint(data)
response = self.post("/chat/agent_chat", json=data, stream=True) response = self.post("/chat/agent_chat", json=data, stream=True)
return self._httpx_stream2generator(response) return self._httpx_stream2generator(response, as_json=True)
def knowledge_base_chat( def knowledge_base_chat(
self, self,
query: str, query: str,
knowledge_base_name: str, knowledge_base_name: str,
top_k: int = VECTOR_SEARCH_TOP_K, top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD, score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [], history: List[Dict] = [],
stream: bool = True, stream: bool = True,
model: str = LLM_MODELS[0], model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE, temperature: float = TEMPERATURE,
max_tokens: int = None, max_tokens: int = None,
prompt_name: str = "default", prompt_name: str = "default",
): ):
''' '''
Corresponds to the api.py /chat/knowledge_base_chat endpoint Corresponds to the api.py /chat/knowledge_base_chat endpoint
@ -362,28 +372,29 @@ class ApiRequest:
return self._httpx_stream2generator(response, as_json=True) return self._httpx_stream2generator(response, as_json=True)
def upload_temp_docs( def upload_temp_docs(
self, self,
files: List[Union[str, Path, bytes]], files: List[Union[str, Path, bytes]],
knowledge_id: str = None, knowledge_id: str = None,
chunk_size=CHUNK_SIZE, chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE, chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE, zh_title_enhance=ZH_TITLE_ENHANCE,
): ):
''' '''
Corresponds to the api.py /knowledge_base/upload_temp_docs endpoint Corresponds to the api.py /knowledge_base/upload_temp_docs endpoint
''' '''
def convert_file(file, filename=None): def convert_file(file, filename=None):
if isinstance(file, bytes): # raw bytes if isinstance(file, bytes): # raw bytes
file = BytesIO(file) file = BytesIO(file)
elif hasattr(file, "read"): # a file io like object elif hasattr(file, "read"): # a file io like object
filename = filename or file.name filename = filename or file.name
else: # a local path else: # a local path
file = Path(file).absolute().open("rb") file = Path(file).absolute().open("rb")
filename = filename or os.path.split(file.name)[-1] filename = filename or os.path.split(file.name)[-1]
return filename, file return filename, file
files = [convert_file(file) for file in files] files = [convert_file(file) for file in files]
data={ data = {
"knowledge_id": knowledge_id, "knowledge_id": knowledge_id,
"chunk_size": chunk_size, "chunk_size": chunk_size,
"chunk_overlap": chunk_overlap, "chunk_overlap": chunk_overlap,
@ -398,17 +409,17 @@ class ApiRequest:
return self._get_response_value(response, as_json=True) return self._get_response_value(response, as_json=True)
def file_chat( def file_chat(
self, self,
query: str, query: str,
knowledge_id: str, knowledge_id: str,
top_k: int = VECTOR_SEARCH_TOP_K, top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD, score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [], history: List[Dict] = [],
stream: bool = True, stream: bool = True,
model: str = LLM_MODELS[0], model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE, temperature: float = TEMPERATURE,
max_tokens: int = None, max_tokens: int = None,
prompt_name: str = "default", prompt_name: str = "default",
): ):
''' '''
Corresponds to the api.py /chat/file_chat endpoint Corresponds to the api.py /chat/file_chat endpoint
@ -436,18 +447,23 @@ class ApiRequest:
) )
return self._httpx_stream2generator(response, as_json=True) return self._httpx_stream2generator(response, as_json=True)
@deprecated(
since="0.3.0",
message="搜索引擎问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
removal="0.3.0"
)
def search_engine_chat( def search_engine_chat(
self, self,
query: str, query: str,
search_engine_name: str, search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K, top_k: int = SEARCH_ENGINE_TOP_K,
history: List[Dict] = [], history: List[Dict] = [],
stream: bool = True, stream: bool = True,
model: str = LLM_MODELS[0], model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE, temperature: float = TEMPERATURE,
max_tokens: int = None, max_tokens: int = None,
prompt_name: str = "default", prompt_name: str = "default",
split_result: bool = False, split_result: bool = False,
): ):
''' '''
Corresponds to the api.py /chat/search_engine_chat endpoint Corresponds to the api.py /chat/search_engine_chat endpoint
@ -478,7 +494,7 @@ class ApiRequest:
# knowledge base operations # knowledge base operations
def list_knowledge_bases( def list_knowledge_bases(
self, self,
): ):
''' '''
Corresponds to the api.py /knowledge_base/list_knowledge_bases endpoint Corresponds to the api.py /knowledge_base/list_knowledge_bases endpoint
@ -489,10 +505,10 @@ class ApiRequest:
value_func=lambda r: r.get("data", [])) value_func=lambda r: r.get("data", []))
def create_knowledge_base( def create_knowledge_base(
self, self,
knowledge_base_name: str, knowledge_base_name: str,
vector_store_type: str = DEFAULT_VS_TYPE, vector_store_type: str = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL, embed_model: str = EMBEDDING_MODEL,
): ):
''' '''
Corresponds to the api.py /knowledge_base/create_knowledge_base endpoint Corresponds to the api.py /knowledge_base/create_knowledge_base endpoint
@ -510,8 +526,8 @@ class ApiRequest:
return self._get_response_value(response, as_json=True) return self._get_response_value(response, as_json=True)
def delete_knowledge_base( def delete_knowledge_base(
self, self,
knowledge_base_name: str, knowledge_base_name: str,
): ):
''' '''
Corresponds to the api.py /knowledge_base/delete_knowledge_base endpoint Corresponds to the api.py /knowledge_base/delete_knowledge_base endpoint
@ -523,8 +539,8 @@ class ApiRequest:
return self._get_response_value(response, as_json=True) return self._get_response_value(response, as_json=True)
def list_kb_docs( def list_kb_docs(
self, self,
knowledge_base_name: str, knowledge_base_name: str,
): ):
''' '''
Corresponds to the api.py /knowledge_base/list_files endpoint Corresponds to the api.py /knowledge_base/list_files endpoint
@ -538,13 +554,13 @@ class ApiRequest:
value_func=lambda r: r.get("data", [])) value_func=lambda r: r.get("data", []))
def search_kb_docs( def search_kb_docs(
self, self,
knowledge_base_name: str, knowledge_base_name: str,
query: str = "", query: str = "",
top_k: int = VECTOR_SEARCH_TOP_K, top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: int = SCORE_THRESHOLD, score_threshold: int = SCORE_THRESHOLD,
file_name: str = "", file_name: str = "",
metadata: dict = {}, metadata: dict = {},
) -> List: ) -> List:
''' '''
Corresponds to the api.py /knowledge_base/search_docs endpoint Corresponds to the api.py /knowledge_base/search_docs endpoint
@ -565,9 +581,9 @@ class ApiRequest:
return self._get_response_value(response, as_json=True) return self._get_response_value(response, as_json=True)
def update_docs_by_id( def update_docs_by_id(
self, self,
knowledge_base_name: str, knowledge_base_name: str,
docs: Dict[str, Dict], docs: Dict[str, Dict],
) -> bool: ) -> bool:
''' '''
Corresponds to the api.py /knowledge_base/update_docs_by_id endpoint Corresponds to the api.py /knowledge_base/update_docs_by_id endpoint
@ -583,32 +599,33 @@ class ApiRequest:
return self._get_response_value(response) return self._get_response_value(response)
def upload_kb_docs( def upload_kb_docs(
self, self,
files: List[Union[str, Path, bytes]], files: List[Union[str, Path, bytes]],
knowledge_base_name: str, knowledge_base_name: str,
override: bool = False, override: bool = False,
to_vector_store: bool = True, to_vector_store: bool = True,
chunk_size=CHUNK_SIZE, chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE, chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE, zh_title_enhance=ZH_TITLE_ENHANCE,
docs: Dict = {}, docs: Dict = {},
not_refresh_vs_cache: bool = False, not_refresh_vs_cache: bool = False,
): ):
''' '''
Corresponds to the api.py /knowledge_base/upload_docs endpoint Corresponds to the api.py /knowledge_base/upload_docs endpoint
''' '''
def convert_file(file, filename=None): def convert_file(file, filename=None):
if isinstance(file, bytes): # raw bytes if isinstance(file, bytes): # raw bytes
file = BytesIO(file) file = BytesIO(file)
elif hasattr(file, "read"): # a file io like object elif hasattr(file, "read"): # a file io like object
filename = filename or file.name filename = filename or file.name
else: # a local path else: # a local path
file = Path(file).absolute().open("rb") file = Path(file).absolute().open("rb")
filename = filename or os.path.split(file.name)[-1] filename = filename or os.path.split(file.name)[-1]
return filename, file return filename, file
files = [convert_file(file) for file in files] files = [convert_file(file) for file in files]
data={ data = {
"knowledge_base_name": knowledge_base_name, "knowledge_base_name": knowledge_base_name,
"override": override, "override": override,
"to_vector_store": to_vector_store, "to_vector_store": to_vector_store,
@ -629,11 +646,11 @@ class ApiRequest:
return self._get_response_value(response, as_json=True) return self._get_response_value(response, as_json=True)
def delete_kb_docs( def delete_kb_docs(
self, self,
knowledge_base_name: str, knowledge_base_name: str,
file_names: List[str], file_names: List[str],
delete_content: bool = False, delete_content: bool = False,
not_refresh_vs_cache: bool = False, not_refresh_vs_cache: bool = False,
): ):
''' '''
Corresponds to the api.py /knowledge_base/delete_docs endpoint Corresponds to the api.py /knowledge_base/delete_docs endpoint
@ -651,8 +668,7 @@ class ApiRequest:
) )
return self._get_response_value(response, as_json=True) return self._get_response_value(response, as_json=True)
def update_kb_info(self, knowledge_base_name, kb_info):
def update_kb_info(self,knowledge_base_name,kb_info):
''' '''
Corresponds to the api.py /knowledge_base/update_info endpoint Corresponds to the api.py /knowledge_base/update_info endpoint
''' '''
@ -668,15 +684,15 @@ class ApiRequest:
return self._get_response_value(response, as_json=True) return self._get_response_value(response, as_json=True)
def update_kb_docs( def update_kb_docs(
self, self,
knowledge_base_name: str, knowledge_base_name: str,
file_names: List[str], file_names: List[str],
override_custom_docs: bool = False, override_custom_docs: bool = False,
chunk_size=CHUNK_SIZE, chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE, chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE, zh_title_enhance=ZH_TITLE_ENHANCE,
docs: Dict = {}, docs: Dict = {},
not_refresh_vs_cache: bool = False, not_refresh_vs_cache: bool = False,
): ):
''' '''
Corresponds to the api.py /knowledge_base/update_docs endpoint Corresponds to the api.py /knowledge_base/update_docs endpoint
@ -702,14 +718,14 @@ class ApiRequest:
return self._get_response_value(response, as_json=True) return self._get_response_value(response, as_json=True)
def recreate_vector_store( def recreate_vector_store(
self, self,
knowledge_base_name: str, knowledge_base_name: str,
allow_empty_kb: bool = True, allow_empty_kb: bool = True,
vs_type: str = DEFAULT_VS_TYPE, vs_type: str = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL, embed_model: str = EMBEDDING_MODEL,
chunk_size=CHUNK_SIZE, chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE, chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE, zh_title_enhance=ZH_TITLE_ENHANCE,
): ):
''' '''
Corresponds to the api.py /knowledge_base/recreate_vector_store endpoint Corresponds to the api.py /knowledge_base/recreate_vector_store endpoint
@ -734,8 +750,8 @@ class ApiRequest:
# LLM model operations # LLM model operations
def list_running_models( def list_running_models(
self, self,
controller_address: str = None, controller_address: str = None,
): ):
''' '''
Get the list of models currently running in Fastchat Get the list of models currently running in Fastchat
@ -751,8 +767,7 @@ class ApiRequest:
"/llm_model/list_running_models", "/llm_model/list_running_models",
json=data, json=data,
) )
return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", [])) return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", []))
def get_default_llm_model(self, local_first: bool = True) -> Tuple[str, bool]: def get_default_llm_model(self, local_first: bool = True) -> Tuple[str, bool]:
''' '''
@ -760,6 +775,7 @@ class ApiRequest:
When local_first=True, prefer a running local model; otherwise return models in the order configured in LLM_MODELS When local_first=True, prefer a running local model; otherwise return models in the order configured in LLM_MODELS
The return value is (model_name, is_local_model) The return value is (model_name, is_local_model)
''' '''
def ret_sync(): def ret_sync():
running_models = self.list_running_models() running_models = self.list_running_models()
if not running_models: if not running_models:
@ -776,7 +792,7 @@ class ApiRequest:
model = m model = m
break break
if not model: # none of the models configured in LLM_MODELS is in running_models if not model: # none of the models configured in LLM_MODELS is in running_models
model = list(running_models)[0] model = list(running_models)[0]
is_local = not running_models[model].get("online_api") is_local = not running_models[model].get("online_api")
return model, is_local return model, is_local
@ -797,7 +813,7 @@ class ApiRequest:
model = m model = m
break break
if not model: # none of the models configured in LLM_MODELS is in running_models if not model: # none of the models configured in LLM_MODELS is in running_models
model = list(running_models)[0] model = list(running_models)[0]
is_local = not running_models[model].get("online_api") is_local = not running_models[model].get("online_api")
return model, is_local return model, is_local
@ -808,8 +824,8 @@ class ApiRequest:
return ret_sync() return ret_sync()
def list_config_models( def list_config_models(
self, self,
types: List[str] = ["local", "online"], types: List[str] = ["local", "online"],
) -> Dict[str, Dict]: ) -> Dict[str, Dict]:
''' '''
Get the models configured in the server's configs, returned in the form {"type": {model_name: config}, ...} Get the models configured in the server's configs, returned in the form {"type": {model_name: config}, ...}
@ -821,23 +837,23 @@ class ApiRequest:
"/llm_model/list_config_models", "/llm_model/list_config_models",
json=data, json=data,
) )
return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {})) return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))
def get_model_config( def get_model_config(
self, self,
model_name: str = None, model_name: str = None,
) -> Dict: ) -> Dict:
''' '''
Get a model's configuration from the server Get a model's configuration from the server
''' '''
data={ data = {
"model_name": model_name, "model_name": model_name,
} }
response = self.post( response = self.post(
"/llm_model/get_model_config", "/llm_model/get_model_config",
json=data, json=data,
) )
return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {})) return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))
def list_search_engines(self) -> List[str]: def list_search_engines(self) -> List[str]:
''' '''
@ -846,12 +862,12 @@ class ApiRequest:
response = self.post( response = self.post(
"/server/list_search_engines", "/server/list_search_engines",
) )
return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {})) return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))
def stop_llm_model( def stop_llm_model(
self, self,
model_name: str, model_name: str,
controller_address: str = None, controller_address: str = None,
): ):
''' '''
Stop an LLM model Stop an LLM model
@ -869,10 +885,10 @@ class ApiRequest:
return self._get_response_value(response, as_json=True) return self._get_response_value(response, as_json=True)
def change_llm_model( def change_llm_model(
self, self,
model_name: str, model_name: str,
new_model_name: str, new_model_name: str,
controller_address: str = None, controller_address: str = None,
): ):
''' '''
Ask the fastchat controller to switch the LLM model Ask the fastchat controller to switch the LLM model
@ -955,10 +971,10 @@ class ApiRequest:
return ret_sync() return ret_sync()
def embed_texts( def embed_texts(
self, self,
texts: List[str], texts: List[str],
embed_model: str = EMBEDDING_MODEL, embed_model: str = EMBEDDING_MODEL,
to_query: bool = False, to_query: bool = False,
) -> List[List[float]]: ) -> List[List[float]]:
''' '''
Vectorize texts; available models include local embed_models and online models that support embeddings Vectorize texts; available models include local embed_models and online models that support embeddings
@ -975,10 +991,10 @@ class ApiRequest:
return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data")) return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data"))
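A small usage sketch of the embedding wrapper above, assuming the module path webui_pages/utils.py used by this repository; the sample texts are illustrative and the embedding model falls back to the server's EMBEDDING_MODEL default:

```python
from webui_pages.utils import ApiRequest

api = ApiRequest()                    # defaults to api_address() from the local configs
vectors = api.embed_texts(["hello", "world"], to_query=False)
print(len(vectors), len(vectors[0]))  # number of texts, embedding dimension
```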
def chat_feedback( def chat_feedback(
self, self,
message_id: str, message_id: str,
score: int, score: int,
reason: str = "", reason: str = "",
) -> int: ) -> int:
''' '''
Submit feedback (a rating) for a chat message Submit feedback (a rating) for a chat message
@ -1015,9 +1031,9 @@ def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
return the success message if the API request succeeded return the success message if the API request succeeded
''' '''
if (isinstance(data, dict) if (isinstance(data, dict)
and key in data and key in data
and "code" in data and "code" in data
and data["code"] == 200): and data["code"] == 200):
return data[key] return data[key]
return "" return ""