Merge pull request #2748 from chatchat-space/dev

0.2.10(0.2.x Final dev release)
This commit is contained in:
zR 2024-01-22 11:57:20 +08:00 committed by GitHub
commit 09bcec542a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
58 changed files with 1106 additions and 1291 deletions

View File

@ -42,7 +42,7 @@
🚩 本项目未涉及微调、训练过程,但可利用微调或训练对本项目效果进行优化。
🌐 [AutoDL 镜像](https://www.codewithgpu.com/i/chatchat-space/Langchain-Chatchat/Langchain-Chatchat) 中 `v11` 版本所使用代码已更新至本项目 `v0.2.7` 版本。
🌐 [AutoDL 镜像](https://www.codewithgpu.com/i/chatchat-space/Langchain-Chatchat/Langchain-Chatchat) 中 `v13` 版本所使用代码已更新至本项目 `v0.2.9` 版本。
🐳 [Docker 镜像](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.6) 已经更新到 ```0.2.7``` 版本。
@ -67,10 +67,10 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
### 1. 环境配置
+ 首先,确保你的机器安装了 Python 3.8 - 3.10
+ 首先,确保你的机器安装了 Python 3.8 - 3.11
```
$ python --version
Python 3.10.12
Python 3.11.7
```
接着,创建一个虚拟环境,并在虚拟环境内安装项目的依赖
```shell
@ -88,6 +88,7 @@ $ pip install -r requirements_webui.txt
# 默认依赖包括基本运行环境FAISS向量库。如果要使用 milvus/pg_vector 等向量库,请将 requirements.txt 中相应依赖取消注释再安装。
```
请注意LangChain-Chatchat `0.2.x` 系列是针对 Langchain `0.0.x` 系列版本的,如果你使用的是 Langchain `0.1.x` 系列版本,需要降级。
### 2. 模型下载
如需在本地或离线环境下运行本项目,需要首先将项目所需的模型下载至本地,通常开源 LLM 与 Embedding 模型可以从 [HuggingFace](https://huggingface.co/models) 下载。
@ -141,14 +142,20 @@ $ python startup.py -a
---
## 项目里程碑
+ `2023年4月`: `Langchain-ChatGLM 0.1.0` 发布,支持基于 ChatGLM-6B 模型的本地知识库问答。
+ `2023年8月`: `Langchain-ChatGLM` 改名为 `Langchain-Chatchat``0.2.0` 发布,使用 `fastchat` 作为模型加载方案,支持更多的模型和数据库。
+ `2023年10月`: `Langchain-Chatchat 0.2.5` 发布,推出 Agent 内容,开源项目在`Founder Park & Zhipu AI & Zilliz` 举办的黑客马拉松获得三等奖。
+ `2023年12月`: `Langchain-Chatchat` 开源项目获得超过 **20K** stars.
+ `2024年1月`: `LangChain 0.1.x` 推出,`Langchain-Chatchat 0.2.x` 停止更新和技术支持,全力研发具有更强应用性的 `Langchain-Chatchat 0.3.x`
+ 🔥 让我们一起期待未来 Chatchat 的故事 ···
---
## 联系我们
### Telegram
[![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white "langchain-chatglm")](https://t.me/+RjliQ3jnJ1YyN2E9)
### 项目交流群
<img src="img/qr_code_82.jpg" alt="二维码" width="300" />
<img src="img/qr_code_85.jpg" alt="二维码" width="300" />
🎉 Langchain-Chatchat 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
@ -156,4 +163,4 @@ $ python startup.py -a
<img src="img/official_wechat_mp_account.png" alt="二维码" width="300" />
🎉 Langchain-Chatchat 项目官方公众号,欢迎扫码关注。

View File

@ -55,10 +55,10 @@ The main process analysis from the aspect of document process:
🚩 Training and fine-tuning are not part of this project, but you can still use either of them to improve its performance.
🌐 [AutoDL image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5) is supported; in `v9` the code has been updated to `v0.2.5`.
🌐 [AutoDL image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5) is supported; in `v13` the code has been updated to `v0.2.9`.
🐳 [Docker image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5)
🐳 [Docker image](registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.7)
## Pain Points Addressed
@ -98,6 +98,7 @@ $ pip install -r requirements_webui.txt
# 默认依赖包括基本运行环境FAISS向量库。如果要使用 milvus/pg_vector 等向量库,请将 requirements.txt 中相应依赖取消注释再安装。
```
Please note that the Langchain-Chatchat `0.2.x` series targets the Langchain `0.0.x` series. If you are using a Langchain `0.1.x` release, you need to downgrade.
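As a quick check before starting the services (a minimal sketch, not part of the project), you can verify that the installed LangChain still belongs to the `0.0.x` series:

```python
# Minimal sanity check: Langchain-Chatchat 0.2.x expects LangChain 0.0.x (e.g. 0.0.354).
import langchain

major, minor, *_ = langchain.__version__.split(".")
if (major, minor) != ("0", "0"):
    raise SystemExit(
        f"LangChain {langchain.__version__} detected; downgrade to a 0.0.x release, "
        "e.g. pip install 'langchain==0.0.354'."
    )
print(f"LangChain {langchain.__version__} should work with Langchain-Chatchat 0.2.x")
```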
### Model Download
@ -155,6 +156,16 @@ $ python startup.py -a
The above instructions are provided for a quick start. If you need more features or want to customize the launch method,
please refer to the [Wiki](https://github.com/chatchat-space/Langchain-Chatchat/wiki/).
---
## Project Milestones
+ `April 2023`: `Langchain-ChatGLM 0.1.0` released, supporting local knowledge base question and answer based on the ChatGLM-6B model.
+ `August 2023`: `Langchain-ChatGLM` was renamed to `Langchain-Chatchat`, and `0.2.0` was released, using `fastchat` for model loading and supporting more models and databases.
+ `October 2023`: `Langchain-Chatchat 0.2.5` was released, introducing Agent support; the open-source project won third prize at the hackathon held by `Founder Park & Zhipu AI & Zilliz`.
+ `December 2023`: The `Langchain-Chatchat` open-source project passed **20K** stars.
+ `January 2024`: `LangChain 0.1.x` was released; `Langchain-Chatchat 0.2.x` stops receiving updates and technical support, and all efforts move to developing the more capable `Langchain-Chatchat 0.3.x`.
+ 🔥 Let's look forward to the future of Chatchat together ···
---
## Contact Us
@ -163,9 +174,9 @@ please refer to the [Wiki](https://github.com/chatchat-space/Langchain-Chatchat/
[![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white "langchain-chatglm")](https://t.me/+RjliQ3jnJ1YyN2E9)
### WeChat Group
<img src="img/qr_code_67.jpg" alt="二维码" width="300" height="300" />
<img src="img/qr_code_85.jpg" alt="二维码" width="300" height="300" />
### WeChat Official Account

View File

@ -5,4 +5,4 @@ from .server_config import *
from .prompt_config import *
VERSION = "v0.2.9-preview"
VERSION = "v0.2.10"

View File

@ -55,6 +55,9 @@ METAPHOR_API_KEY = ""
# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
ZH_TITLE_ENHANCE = False
# PDF OCR 控制:只对宽高超过页面一定比例(图片宽/页面宽,图片高/页面高)的图片进行 OCR。
# 这样可以避免 PDF 中一些小图片的干扰,提高非扫描版 PDF 处理速度
PDF_OCR_THRESHOLD = (0.6, 0.6)
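# Illustrative example (not project code): with the default (0.6, 0.6), OCR runs only when an
# embedded image covers at least 60% of the page in BOTH dimensions, e.g.
#   page_w, page_h = 612, 792      # a US-Letter page in points (hypothetical numbers)
#   img_w, img_h = 450, 620        # an embedded full-page scan
#   do_ocr = (img_w / page_w >= 0.6) and (img_h / page_h >= 0.6)   # True -> OCR this image
# A small 100x80 logo on the same page would be skipped, keeping non-scanned PDFs fast.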
# 每个知识库的初始化介绍用于在初始化知识库时显示和Agent调用没写则没有介绍不会被Agent调用。
KB_INFO = {
@ -102,6 +105,10 @@ kbs_config = {
"index_name": "test_index",
"user": "",
"password": ""
},
"milvus_kwargs":{
"search_params":{"metric_type": "L2"}, #在此处增加search_params
"index_params":{"metric_type": "L2","index_type": "HNSW"} # 在此处增加index_params
}
}
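# Illustrative sketch (assumption, not part of this file): milvus_kwargs is what a Milvus-backed
# knowledge base would forward to pymilvus when building the index and searching; the collection
# and field names below are hypothetical.
#
#   from pymilvus import Collection
#   col = Collection("kb_samples")
#   col.create_index("embeddings", kbs_config["milvus_kwargs"]["index_params"])
#   # query_vector: the embedded query produced by the configured embedding model
#   hits = col.search(data=[query_vector], anns_field="embeddings",
#                     param=kbs_config["milvus_kwargs"]["search_params"], limit=3)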

View File

@ -6,9 +6,9 @@ import os
MODEL_ROOT_PATH = ""
# 选用的 Embedding 名称
EMBEDDING_MODEL = "bge-large-zh"
EMBEDDING_MODEL = "bge-large-zh-v1.5"
# Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
# Embedding 模型运行设备。设为"auto"会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。
EMBEDDING_DEVICE = "auto"
# 选用的reranker模型
@ -26,44 +26,33 @@ EMBEDDING_MODEL_OUTPUT_PATH = "output"
# 在这里我们使用目前主流的两个离线模型其中chatglm3-6b 为默认加载模型。
# 如果你的显存不足,可使用 Qwen-1_8B-Chat, 该模型 FP16 仅需 3.8G显存。
# chatglm3-6b输出角色标签<|user|>及自问自答的问题详见项目wiki->常见问题->Q20.
LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"] # "Qwen-1_8B-Chat",
# AgentLM模型的名称 (可以不指定指定之后就锁定进入Agent之后的Chain的模型不指定就是LLM_MODELS[0])
LLM_MODELS = ["zhipu-api"]
Agent_MODEL = None
# LLM 运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
LLM_DEVICE = "auto"
# LLM 模型运行设备。设为"auto"会自动检测(会有警告),也可手动设定为 "cuda","mps","cpu","xpu" 其中之一。
LLM_DEVICE = "cuda"
# 历史对话轮数
HISTORY_LEN = 3
# 大模型最长支持的长度,如果不填写,则使用模型默认的最大长度,如果填写,则为用户设定的最大长度
MAX_TOKENS = None
MAX_TOKENS = 2048
# LLM通用对话参数
TEMPERATURE = 0.7
# TOP_P = 0.95 # ChatOpenAI暂不支持该参数
ONLINE_LLM_MODEL = {
# 线上模型。请在server_config中为每个在线API设置不同的端口
"openai-api": {
"model_name": "gpt-3.5-turbo",
"model_name": "gpt-4",
"api_base_url": "https://api.openai.com/v1",
"api_key": "",
"openai_proxy": "",
},
# 具体注册及api key获取请前往 http://open.bigmodel.cn
# 智谱AI API,具体注册及api key获取请前往 http://open.bigmodel.cn
"zhipu-api": {
"api_key": "",
"version": "chatglm_turbo", # 可选包括 "chatglm_turbo"
"version": "glm-4",
"provider": "ChatGLMWorker",
},
# 具体注册及api key获取请前往 https://api.minimax.chat/
"minimax-api": {
"group_id": "",
@ -72,7 +61,6 @@ ONLINE_LLM_MODEL = {
"provider": "MiniMaxWorker",
},
# 具体注册及api key获取请前往 https://xinghuo.xfyun.cn/
"xinghuo-api": {
"APPID": "",
@ -93,8 +81,8 @@ ONLINE_LLM_MODEL = {
# 火山方舟 API文档参考 https://www.volcengine.com/docs/82379
"fangzhou-api": {
"version": "chatglm-6b-model", # 当前支持 "chatglm-6b-model" 更多的见文档模型支持列表中方舟部分。
"version_url": "", # 可以不填写version直接填写在方舟申请模型发布的API地址
"version": "chatglm-6b-model",
"version_url": "",
"api_key": "",
"secret_key": "",
"provider": "FangZhouWorker",
@ -102,15 +90,15 @@ ONLINE_LLM_MODEL = {
# 阿里云通义千问 API文档参考 https://help.aliyun.com/zh/dashscope/developer-reference/api-details
"qwen-api": {
"version": "qwen-turbo", # 可选包括 "qwen-turbo", "qwen-plus"
"api_key": "", # 请在阿里云控制台模型服务灵积API-KEY管理页面创建
"version": "qwen-max",
"api_key": "",
"provider": "QwenWorker",
"embed_model": "text-embedding-v1" # embedding 模型名称
"embed_model": "text-embedding-v1" # embedding 模型名称
},
# 百川 API申请方式请参考 https://www.baichuan-ai.com/home#api-enter
"baichuan-api": {
"version": "Baichuan2-53B", # 当前支持 "Baichuan2-53B" 见官方文档。
"version": "Baichuan2-53B",
"api_key": "",
"secret_key": "",
"provider": "BaiChuanWorker",
@ -132,6 +120,11 @@ ONLINE_LLM_MODEL = {
"secret_key": "",
"provider": "TianGongWorker",
},
# Gemini API (开发组未测试由社群提供只支持prohttps://makersuite.google.com/或者google cloud使用前先确认网络正常使用代理请在项目启动python startup.py -a)环境内设置https_proxy环境变量
"gemini-api": {
"api_key": "",
"provider": "GeminiWorker",
}
}
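# Illustrative note (assumption, not project code): per the gemini-api comment above, the proxy
# must be visible to the process that runs `python startup.py -a`, e.g. exported in the shell or
# set early in Python before any request is made (the proxy address below is hypothetical):
#
#   import os
#   os.environ.setdefault("https_proxy", "http://127.0.0.1:7890")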
@ -143,6 +136,7 @@ ONLINE_LLM_MODEL = {
# - GanymedeNil/text2vec-large-chinese
# - text2vec-large-chinese
# 2.2 如果以上本地路径不存在则使用huggingface模型
MODEL_PATH = {
"embed_model": {
"ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
@ -161,7 +155,7 @@ MODEL_PATH = {
"bge-large-zh": "BAAI/bge-large-zh",
"bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
"bge-base-zh-v1.5": "BAAI/bge-base-zh-v1.5",
"bge-large-zh-v1.5": "BAAI/bge-large-zh-v1.5",
"bge-large-zh-v1.5": "/share/home/zyx/Models/bge-large-zh-v1.5",
"piccolo-base-zh": "sensenova/piccolo-base-zh",
"piccolo-large-zh": "sensenova/piccolo-large-zh",
"nlp_gte_sentence-embedding_chinese-large": "damo/nlp_gte_sentence-embedding_chinese-large",
@ -169,55 +163,55 @@ MODEL_PATH = {
},
"llm_model": {
# 以下部分模型并未完全测试仅根据fastchat和vllm模型的模型列表推定支持
"chatglm2-6b": "THUDM/chatglm2-6b",
"chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
"chatglm3-6b": "THUDM/chatglm3-6b",
"chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
"chatglm3-6b-base": "THUDM/chatglm3-6b-base",
"Qwen-1_8B": "Qwen/Qwen-1_8B",
"Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
"Qwen-1_8B-Chat-Int8": "Qwen/Qwen-1_8B-Chat-Int8",
"Qwen-1_8B-Chat-Int4": "Qwen/Qwen-1_8B-Chat-Int4",
"Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
"Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
"Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",
"Qwen-7B": "Qwen/Qwen-7B",
"Qwen-1_8B-Chat": "/media/checkpoint/Qwen-1_8B-Chat",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B": "Qwen/Qwen-14B",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
"Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8",
# 在新版的transformers下需要手动修改模型的config.json文件在quantization_config字典中
# 增加`disable_exllama:true` 字段才能启动qwen的量化模型
"Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4",
"Qwen-72B": "Qwen/Qwen-72B",
"Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
"Qwen-72B-Chat-Int8": "Qwen/Qwen-72B-Chat-Int8",
"Qwen-72B-Chat-Int4": "Qwen/Qwen-72B-Chat-Int4",
"baichuan2-13b": "baichuan-inc/Baichuan2-13B-Chat",
"baichuan2-7b": "baichuan-inc/Baichuan2-7B-Chat",
"baichuan-7b": "baichuan-inc/Baichuan-7B",
"baichuan-13b": "baichuan-inc/Baichuan-13B",
"baichuan-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
"baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
"aquila-7b": "BAAI/Aquila-7B",
"aquilachat-7b": "BAAI/AquilaChat-7B",
"baichuan2-7b-chat": "baichuan-inc/Baichuan2-7B-Chat",
"baichuan2-13b-chat": "baichuan-inc/Baichuan2-13B-Chat",
"internlm-7b": "internlm/internlm-7b",
"internlm-chat-7b": "internlm/internlm-chat-7b",
"internlm2-chat-7b": "internlm/internlm2-chat-7b",
"internlm2-chat-20b": "internlm/internlm2-chat-20b",
"BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat",
"BlueLM-7B-Chat-32k": "vivo-ai/BlueLM-7B-Chat-32k",
"Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat",
"agentlm-7b": "THUDM/agentlm-7b",
"agentlm-13b": "THUDM/agentlm-13b",
"agentlm-70b": "THUDM/agentlm-70b",
"falcon-7b": "tiiuae/falcon-7b",
"falcon-40b": "tiiuae/falcon-40b",
"falcon-rw-7b": "tiiuae/falcon-rw-7b",
"aquila-7b": "BAAI/Aquila-7B",
"aquilachat-7b": "BAAI/AquilaChat-7B",
"open_llama_13b": "openlm-research/open_llama_13b",
"vicuna-13b-v1.5": "lmsys/vicuna-13b-v1.5",
"koala": "young-geng/koala",
"mpt-7b": "mosaicml/mpt-7b",
"mpt-7b-storywriter": "mosaicml/mpt-7b-storywriter",
"mpt-30b": "mosaicml/mpt-30b",
"opt-66b": "facebook/opt-66b",
"opt-iml-max-30b": "facebook/opt-iml-max-30b",
"gpt2": "gpt2",
"gpt2-xl": "gpt2-xl",
"gpt-j-6b": "EleutherAI/gpt-j-6b",
"gpt4all-j": "nomic-ai/gpt4all-j",
"gpt-neox-20b": "EleutherAI/gpt-neox-20b",
@ -225,63 +219,50 @@ MODEL_PATH = {
"oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
"dolly-v2-12b": "databricks/dolly-v2-12b",
"stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b",
"Llama-2-13b-hf": "meta-llama/Llama-2-13b-hf",
"Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
"open_llama_13b": "openlm-research/open_llama_13b",
"vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3",
"koala": "young-geng/koala",
"mpt-7b": "mosaicml/mpt-7b",
"mpt-7b-storywriter": "mosaicml/mpt-7b-storywriter",
"mpt-30b": "mosaicml/mpt-30b",
"opt-66b": "facebook/opt-66b",
"opt-iml-max-30b": "facebook/opt-iml-max-30b",
"agentlm-7b": "THUDM/agentlm-7b",
"agentlm-13b": "THUDM/agentlm-13b",
"agentlm-70b": "THUDM/agentlm-70b",
"Yi-34B-Chat": "01-ai/Yi-34B-Chat",
},
"reranker":{
"bge-reranker-large":"BAAI/bge-reranker-large",
"bge-reranker-base":"BAAI/bge-reranker-base",
#TODO 增加在线reranker如cohere
"reranker": {
"bge-reranker-large": "BAAI/bge-reranker-large",
"bge-reranker-base": "BAAI/bge-reranker-base",
}
}
# 通常情况下不需要更改以下内容
# nltk 模型存储路径
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
# 使用VLLM可能导致模型推理能力下降无法完成Agent任务
VLLM_MODEL_DICT = {
"aquila-7b": "BAAI/Aquila-7B",
"aquilachat-7b": "BAAI/AquilaChat-7B",
"baichuan-7b": "baichuan-inc/Baichuan-7B",
"baichuan-13b": "baichuan-inc/Baichuan-13B",
"baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
"chatglm2-6b": "THUDM/chatglm2-6b",
"chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
"chatglm3-6b": "THUDM/chatglm3-6b",
"chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
"Llama-2-7b-chat-hf": "meta-llama/Llama-2-7b-chat-hf",
"Llama-2-13b-chat-hf": "meta-llama/Llama-2-13b-chat-hf",
"Llama-2-70b-chat-hf": "meta-llama/Llama-2-70b-chat-hf",
"Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
"Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
"baichuan-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
"baichuan-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
"baichuan2-7b-chat": "baichuan-inc/Baichuan-7B-Chat",
"baichuan2-13b-chat": "baichuan-inc/Baichuan-13B-Chat",
"BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat",
"BlueLM-7B-Chat-32k": "vivo-ai/BlueLM-7B-Chat-32k",
# 注意bloom系列的tokenizer与model是分离的因此虽然vllm支持但与fschat框架不兼容
# "bloom": "bigscience/bloom",
# "bloomz": "bigscience/bloomz",
# "bloomz-560m": "bigscience/bloomz-560m",
# "bloomz-7b1": "bigscience/bloomz-7b1",
# "bloomz-1b7": "bigscience/bloomz-1b7",
"internlm-7b": "internlm/internlm-7b",
"internlm-chat-7b": "internlm/internlm-chat-7b",
"internlm2-chat-7b": "internlm/Models/internlm2-chat-7b",
"internlm2-chat-20b": "internlm/Models/internlm2-chat-20b",
"aquila-7b": "BAAI/Aquila-7B",
"aquilachat-7b": "BAAI/AquilaChat-7B",
"falcon-7b": "tiiuae/falcon-7b",
"falcon-40b": "tiiuae/falcon-40b",
"falcon-rw-7b": "tiiuae/falcon-rw-7b",
@ -294,8 +275,6 @@ VLLM_MODEL_DICT = {
"oasst-sft-4-pythia-12b-epoch-3.5": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
"dolly-v2-12b": "databricks/dolly-v2-12b",
"stablelm-tuned-alpha-7b": "stabilityai/stablelm-tuned-alpha-7b",
"Llama-2-13b-hf": "meta-llama/Llama-2-13b-hf",
"Llama-2-70b-hf": "meta-llama/Llama-2-70b-hf",
"open_llama_13b": "openlm-research/open_llama_13b",
"vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3",
"koala": "young-geng/koala",
@ -305,37 +284,12 @@ VLLM_MODEL_DICT = {
"opt-66b": "facebook/opt-66b",
"opt-iml-max-30b": "facebook/opt-iml-max-30b",
"Qwen-1_8B": "Qwen/Qwen-1_8B",
"Qwen-1_8B-Chat": "Qwen/Qwen-1_8B-Chat",
"Qwen-1_8B-Chat-Int8": "Qwen/Qwen-1_8B-Chat-Int8",
"Qwen-1_8B-Chat-Int4": "Qwen/Qwen-1_8B-Chat-Int4",
"Qwen-7B": "Qwen/Qwen-7B",
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
"Qwen-14B": "Qwen/Qwen-14B",
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
"Qwen-14B-Chat-Int8": "Qwen/Qwen-14B-Chat-Int8",
"Qwen-14B-Chat-Int4": "Qwen/Qwen-14B-Chat-Int4",
"Qwen-72B": "Qwen/Qwen-72B",
"Qwen-72B-Chat": "Qwen/Qwen-72B-Chat",
"Qwen-72B-Chat-Int8": "Qwen/Qwen-72B-Chat-Int8",
"Qwen-72B-Chat-Int4": "Qwen/Qwen-72B-Chat-Int4",
"agentlm-7b": "THUDM/agentlm-7b",
"agentlm-13b": "THUDM/agentlm-13b",
"agentlm-70b": "THUDM/agentlm-70b",
}
# 你认为支持Agent能力的模型可以在这里添加添加后不会出现可视化界面的警告
# 经过我们测试原生支持Agent的模型仅有以下几个
SUPPORT_AGENT_MODEL = [
"azure-api",
"openai-api",
"qwen-api",
"Qwen",
"chatglm3",
"xinghuo-api",
]
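# Illustrative sketch (assumption, not part of model_config.py): how these settings are typically
# consumed. A model name is resolved against MODEL_PATH (falling back to the HuggingFace repo id),
# and the generation defaults are passed to the get_ChatOpenAI helper imported from server.utils
# by the chat endpoints elsewhere in this commit; resolve_model_path is a hypothetical helper and
# the max_tokens keyword is assumed from the endpoint parameters shown below.
#
#   def resolve_model_path(name: str) -> str:
#       return MODEL_PATH["llm_model"].get(name, name)
#
#   model = get_ChatOpenAI(model_name=LLM_MODELS[0],
#                          temperature=TEMPERATURE,
#                          max_tokens=MAX_TOKENS)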

View File

@ -128,6 +128,9 @@ FSCHAT_MODEL_WORKERS = {
"tiangong-api": {
"port": 21009,
},
"gemini-api": {
"port": 21012,
},
}
# fastchat multi model worker server

View File

@ -1,2 +1,4 @@
from .mypdfloader import RapidOCRPDFLoader
from .myimgloader import RapidOCRLoader
from .mydocloader import RapidOCRDocLoader
from .mypptloader import RapidOCRPPTLoader

View File

@ -0,0 +1,71 @@
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from typing import List
import tqdm
class RapidOCRDocLoader(UnstructuredFileLoader):
def _get_elements(self) -> List:
def doc2text(filepath):
from docx.table import _Cell, Table
from docx.oxml.table import CT_Tbl
from docx.oxml.text.paragraph import CT_P
from docx.text.paragraph import Paragraph
from docx import Document, ImagePart
from PIL import Image
from io import BytesIO
import numpy as np
from rapidocr_onnxruntime import RapidOCR
ocr = RapidOCR()
doc = Document(filepath)
resp = ""
def iter_block_items(parent):
from docx.document import Document
if isinstance(parent, Document):
parent_elm = parent.element.body
elif isinstance(parent, _Cell):
parent_elm = parent._tc
else:
raise ValueError("RapidOCRDocLoader parse fail")
for child in parent_elm.iterchildren():
if isinstance(child, CT_P):
yield Paragraph(child, parent)
elif isinstance(child, CT_Tbl):
yield Table(child, parent)
b_unit = tqdm.tqdm(total=len(doc.paragraphs)+len(doc.tables),
desc="RapidOCRDocLoader block index: 0")
for i, block in enumerate(iter_block_items(doc)):
b_unit.set_description(
"RapidOCRDocLoader block index: {}".format(i))
b_unit.refresh()
if isinstance(block, Paragraph):
resp += block.text.strip() + "\n"
images = block._element.xpath('.//pic:pic') # 获取所有图片
for image in images:
for img_id in image.xpath('.//a:blip/@r:embed'): # 获取图片id
part = doc.part.related_parts[img_id] # 根据图片id获取对应的图片
if isinstance(part, ImagePart):
image = Image.open(BytesIO(part._blob))
result, _ = ocr(np.array(image))
if result:
ocr_result = [line[1] for line in result]
resp += "\n".join(ocr_result)
elif isinstance(block, Table):
for row in block.rows:
for cell in row.cells:
for paragraph in cell.paragraphs:
resp += paragraph.text.strip() + "\n"
b_unit.update(1)
return resp
text = doc2text(self.file_path)
from unstructured.partition.text import partition_text
return partition_text(text=text, **self.unstructured_kwargs)
if __name__ == '__main__':
loader = RapidOCRDocLoader(file_path="../tests/samples/ocr_test.docx")
docs = loader.load()
print(docs)

View File

@ -1,5 +1,6 @@
from typing import List
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from configs import PDF_OCR_THRESHOLD
from document_loaders.ocr import get_ocr
import tqdm
@ -15,7 +16,6 @@ class RapidOCRPDFLoader(UnstructuredFileLoader):
b_unit = tqdm.tqdm(total=doc.page_count, desc="RapidOCRPDFLoader context page index: 0")
for i, page in enumerate(doc):
# 更新描述
b_unit.set_description("RapidOCRPDFLoader context page index: {}".format(i))
# 立即显示进度条更新结果
@ -24,14 +24,20 @@ class RapidOCRPDFLoader(UnstructuredFileLoader):
text = page.get_text("")
resp += text + "\n"
img_list = page.get_images()
img_list = page.get_image_info(xrefs=True)
for img in img_list:
pix = fitz.Pixmap(doc, img[0])
img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1)
result, _ = ocr(img_array)
if result:
ocr_result = [line[1] for line in result]
resp += "\n".join(ocr_result)
if xref := img.get("xref"):
bbox = img["bbox"]
# 检查图片尺寸是否超过设定的阈值
if ((bbox[2] - bbox[0]) / (page.rect.width) < PDF_OCR_THRESHOLD[0]
or (bbox[3] - bbox[1]) / (page.rect.height) < PDF_OCR_THRESHOLD[1]):
continue
pix = fitz.Pixmap(doc, xref)
img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, -1)
result, _ = ocr(img_array)
if result:
ocr_result = [line[1] for line in result]
resp += "\n".join(ocr_result)
# 更新进度
b_unit.update(1)
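# Usage sketch (mirrors the __main__ examples of the other loaders in this commit; the sample
# path is hypothetical):
#
#   loader = RapidOCRPDFLoader(file_path="../tests/samples/ocr_test.pdf")
#   docs = loader.load()
#   print(docs)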

View File

@ -0,0 +1,59 @@
from langchain.document_loaders.unstructured import UnstructuredFileLoader
from typing import List
import tqdm
class RapidOCRPPTLoader(UnstructuredFileLoader):
def _get_elements(self) -> List:
def ppt2text(filepath):
from pptx import Presentation
from PIL import Image
import numpy as np
from io import BytesIO
from rapidocr_onnxruntime import RapidOCR
ocr = RapidOCR()
prs = Presentation(filepath)
resp = ""
def extract_text(shape):
nonlocal resp
if shape.has_text_frame:
resp += shape.text.strip() + "\n"
if shape.has_table:
for row in shape.table.rows:
for cell in row.cells:
for paragraph in cell.text_frame.paragraphs:
resp += paragraph.text.strip() + "\n"
if shape.shape_type == 13: # 13 表示图片
image = Image.open(BytesIO(shape.image.blob))
result, _ = ocr(np.array(image))
if result:
ocr_result = [line[1] for line in result]
resp += "\n".join(ocr_result)
elif shape.shape_type == 6: # 6 表示组合
for child_shape in shape.shapes:
extract_text(child_shape)
b_unit = tqdm.tqdm(total=len(prs.slides),
desc="RapidOCRPPTLoader slide index: 1")
# 遍历所有幻灯片
for slide_number, slide in enumerate(prs.slides, start=1):
b_unit.set_description(
"RapidOCRPPTLoader slide index: {}".format(slide_number))
b_unit.refresh()
sorted_shapes = sorted(slide.shapes,
key=lambda x: (x.top, x.left)) # 从上到下、从左到右遍历
for shape in sorted_shapes:
extract_text(shape)
b_unit.update(1)
return resp
text = ppt2text(self.file_path)
from unstructured.partition.text import partition_text
return partition_text(text=text, **self.unstructured_kwargs)
if __name__ == '__main__':
loader = RapidOCRPPTLoader(file_path="../tests/samples/ocr_test.pptx")
docs = loader.load()
print(docs)

View File

@ -7,31 +7,35 @@
保存的模型的位置位于原本嵌入模型的目录下模型的名称为原模型名称+Merge_Keywords_时间戳
'''
import sys
sys.path.append("..")
import os
import torch
from datetime import datetime
from configs import (
MODEL_PATH,
EMBEDDING_MODEL,
EMBEDDING_KEYWORD_FILE,
)
import os
import torch
from safetensors.torch import save_model
from sentence_transformers import SentenceTransformer
from langchain_core._api import deprecated
@deprecated(
since="0.3.0",
message="自定义关键词 Langchain-Chatchat 0.3.x 重写, 0.2.x中相关功能将废弃",
removal="0.3.0"
)
def get_keyword_embedding(bert_model, tokenizer, key_words):
tokenizer_output = tokenizer(key_words, return_tensors="pt", padding=True, truncation=True)
# No need to manually convert to tensor as we've set return_tensors="pt"
input_ids = tokenizer_output['input_ids']
# Remove the first and last token for each sequence in the batch
input_ids = input_ids[:, 1:-1]
keyword_embedding = bert_model.embeddings.word_embeddings(input_ids)
keyword_embedding = torch.mean(keyword_embedding, 1)
return keyword_embedding
@ -47,14 +51,11 @@ def add_keyword_to_model(model_name=EMBEDDING_MODEL, keyword_file: str = "", out
bert_model = word_embedding_model.auto_model
tokenizer = word_embedding_model.tokenizer
key_words_embedding = get_keyword_embedding(bert_model, tokenizer, key_words)
# key_words_embedding = st_model.encode(key_words)
embedding_weight = bert_model.embeddings.word_embeddings.weight
embedding_weight_len = len(embedding_weight)
tokenizer.add_tokens(key_words)
bert_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32)
# key_words_embedding_tensor = torch.from_numpy(key_words_embedding)
embedding_weight = bert_model.embeddings.word_embeddings.weight
with torch.no_grad():
embedding_weight[embedding_weight_len:embedding_weight_len + key_words_len, :] = key_words_embedding
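# Worked example (illustrative): if the tokenizer originally holds 30_000 tokens and key_words
# contains 3 new keywords, resize_token_embeddings grows the embedding matrix (padded up to a
# multiple of 32) and rows 30_000..30_002 are overwritten here with the mean-pooled embeddings
# of each keyword's original sub-tokens computed by get_keyword_embedding above.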
@ -76,46 +77,3 @@ def add_keyword_to_embedding_model(path: str = EMBEDDING_KEYWORD_FILE):
output_model_name = "{}_Merge_Keywords_{}".format(EMBEDDING_MODEL, current_time)
output_model_path = os.path.join(model_parent_directory, output_model_name)
add_keyword_to_model(model_name, keyword_file, output_model_path)
if __name__ == '__main__':
add_keyword_to_embedding_model(EMBEDDING_KEYWORD_FILE)
# input_model_name = ""
# output_model_path = ""
# # 以下为加入关键字前后tokenizer的测试用例对比
# def print_token_ids(output, tokenizer, sentences):
# for idx, ids in enumerate(output['input_ids']):
# print(f'sentence={sentences[idx]}')
# print(f'ids={ids}')
# for id in ids:
# decoded_id = tokenizer.decode(id)
# print(f' {decoded_id}->{id}')
#
# sentences = [
# '数据科学与大数据技术',
# 'Langchain-Chatchat'
# ]
#
# st_no_keywords = SentenceTransformer(input_model_name)
# tokenizer_without_keywords = st_no_keywords.tokenizer
# print("===== tokenizer with no keywords added =====")
# output = tokenizer_without_keywords(sentences)
# print_token_ids(output, tokenizer_without_keywords, sentences)
# print(f'-------- embedding with no keywords added -----')
# embeddings = st_no_keywords.encode(sentences)
# print(embeddings)
#
# print("--------------------------------------------")
# print("--------------------------------------------")
# print("--------------------------------------------")
#
# st_with_keywords = SentenceTransformer(output_model_path)
# tokenizer_with_keywords = st_with_keywords.tokenizer
# print("===== tokenizer with keyword added =====")
# output = tokenizer_with_keywords(sentences)
# print_token_ids(output, tokenizer_with_keywords, sentences)
#
# print(f'-------- embedding with keywords added -----')
# embeddings = st_with_keywords.encode(sentences)
# print(embeddings)

Binary image files changed (not shown): one QR-code image removed (Before: 266 KiB); img/qr_code_85.jpg added (After: 272 KiB).
View File

@ -6,7 +6,6 @@ from configs.model_config import NLTK_DATA_PATH, EMBEDDING_MODEL
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
from datetime import datetime
import sys
if __name__ == "__main__":
@ -50,11 +49,11 @@ if __name__ == "__main__":
)
parser.add_argument(
"-i",
"--increament",
"--increment",
action="store_true",
help=('''
update vector store for files exist in local folder and not exist in database.
use this option if you want to create vectors increamentally.
use this option if you want to create vectors incrementally.
'''
)
)
@ -100,7 +99,7 @@ if __name__ == "__main__":
if args.clear_tables:
reset_tables()
print("database talbes reseted")
print("database tables reset")
if args.recreate_vs:
create_tables()
@ -110,8 +109,8 @@ if __name__ == "__main__":
import_from_db(args.import_db)
elif args.update_in_db:
folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model)
elif args.increament:
folder2db(kb_names=args.kb_name, mode="increament", embed_model=args.embed_model)
elif args.increment:
folder2db(kb_names=args.kb_name, mode="increment", embed_model=args.embed_model)
elif args.prune_db:
prune_db_docs(args.kb_name)
elif args.prune_folder:

View File

@ -1,6 +1,5 @@
# API requirements
# On Windows system, install the cuda version manually from https://pytorch.org/
torch~=2.1.2
torchvision~=0.16.2
torchaudio~=2.1.2
@ -8,30 +7,30 @@ xformers==0.0.23.post1
transformers==4.36.2
sentence_transformers==2.2.2
langchain==0.0.352
langchain==0.0.354
langchain-experimental==0.0.47
pydantic==1.10.13
fschat==0.2.34
openai~=1.6.0
fastapi>=0.105
sse_starlette
fschat==0.2.35
openai~=1.7.1
fastapi~=0.108.0
sse_starlette==1.8.2
nltk>=3.8.1
uvicorn>=0.24.0.post1
starlette~=0.27.0
starlette~=0.32.0
unstructured[all-docs]==0.11.0
python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19
faiss-cpu~=1.7.4 # `conda install faiss-gpu -c conda-forge` if you want to accelerate with gpus
accelerate==0.24.1
faiss-cpu~=1.7.4 # `conda install faiss-gpu -c conda-forge` if you want to accelerate with gpus
accelerate~=0.24.1
spacy~=3.7.2
PyMuPDF~=1.23.8
rapidocr_onnxruntime==1.3.8
requests>=2.31.0
pathlib>=1.0.1
pytest>=7.4.3
numexpr>=2.8.6 # max version for py38
strsimpy>=0.2.1
markdownify>=0.11.6
requests~=2.31.0
pathlib~=1.0.1
pytest~=7.4.3
numexpr~=2.8.6 # max version for py38
strsimpy~=0.2.1
markdownify~=0.11.6
tiktoken~=0.5.2
tqdm>=4.66.1
websockets>=12.0
@ -39,22 +38,18 @@ numpy~=1.24.4
pandas~=2.0.3
einops>=0.7.0
transformers_stream_generator==0.0.4
vllm==0.2.6; sys_platform == "linux"
httpx[brotli,http2,socks]==0.25.2
llama-index
vllm==0.2.7; sys_platform == "linux"
# optional document loaders
# rapidocr_paddle[gpu]>=1.3.0.post5 # gpu acceleration for ocr of pdf and image files
jq>=1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows
# html2text # for .enex files
#rapidocr_paddle[gpu]>=1.3.0.post5 # gpu acceleration for ocr of pdf and image files
jq==1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows
beautifulsoup4~=4.12.2 # for .mhtml files
pysrt~=1.1.2
# Online api libs dependencies
zhipuai>=1.0.7, <=2.0.0 # zhipu
dashscope>=1.13.6 # qwen
# zhipuAI sdk is not supported on our platform, so use http instead
dashscope==1.13.6 # qwen
# volcengine>=1.0.119 # fangzhou
# uncomment libs if you want to use corresponding vector store
@ -64,16 +59,18 @@ dashscope>=1.13.6 # qwen
# Agent and Search Tools
arxiv>=2.0.0
youtube-search>=2.1.2
duckduckgo-search>=3.9.9
metaphor-python>=0.1.23
arxiv~=2.1.0
youtube-search~=2.1.2
duckduckgo-search~=3.9.9
metaphor-python~=0.1.23
# WebUI requirements
streamlit~=1.29.0 # do remember to add streamlit to environment variables if you use windows
streamlit-option-menu>=0.3.6
streamlit==1.30.0
streamlit-option-menu==0.3.6
streamlit-antd-components==0.3.1
streamlit-chatbox==1.1.11
streamlit-modal>=0.1.0
streamlit-aggrid>=0.3.4.post3
watchdog>=3.0.0
streamlit-modal==0.1.0
streamlit-aggrid==0.3.4.post3
httpx==0.26.0
watchdog==3.0.0

View File

@ -1,6 +1,3 @@
# API requirements
# On Windows system, install the cuda version manually from https://pytorch.org/
torch~=2.1.2
torchvision~=0.16.2
torchaudio~=2.1.2
@ -8,52 +5,52 @@ xformers==0.0.23.post1
transformers==4.36.2
sentence_transformers==2.2.2
langchain==0.0.352
langchain==0.0.354
langchain-experimental==0.0.47
pydantic==1.10.13
fschat==0.2.34
openai~=1.6.0
fastapi>=0.105
sse_starlette
fschat==0.2.35
openai~=1.7.1
fastapi~=0.108.0
sse_starlette==1.8.2
nltk>=3.8.1
uvicorn>=0.24.0.post1
starlette~=0.27.0
starlette~=0.32.0
unstructured[all-docs]==0.11.0
python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19
faiss-cpu~=1.7.4 # `conda install faiss-gpu -c conda-forge` if you want to accelerate with gpus
accelerate==0.24.1
faiss-cpu~=1.7.4 # `conda install faiss-gpu -c conda-forge` if you want to accelerate with gpus
accelerate~=0.24.1
spacy~=3.7.2
PyMuPDF~=1.23.8
rapidocr_onnxruntime~=1.3.8
requests>=2.31.0
pathlib>=1.0.1
pytest>=7.4.3
numexpr>=2.8.6 # max version for py38
strsimpy>=0.2.1
markdownify>=0.11.6
rapidocr_onnxruntime==1.3.8
requests~=2.31.0
pathlib~=1.0.1
pytest~=7.4.3
numexpr~=2.8.6 # max version for py38
strsimpy~=0.2.1
markdownify~=0.11.6
tiktoken~=0.5.2
tqdm>=4.66.1
websockets>=12.0
numpy~=1.26.2
pandas~=2.1.4
numpy~=1.24.4
pandas~=2.0.3
einops>=0.7.0
transformers_stream_generator==0.0.4
vllm==0.2.6; sys_platform == "linux"
httpx[brotli,http2,socks]~=0.25.2
vllm==0.2.7; sys_platform == "linux"
httpx==0.26.0
llama-index
# optional document loaders
rapidocr_paddle[gpu]>=1.3.0.post5 # gpu acceleration for ocr of pdf and image files
jq>=1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows
# html2text # for .enex files
# rapidocr_paddle[gpu]>=1.3.0.post5 # gpu acceleration for ocr of pdf and image files
jq==1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows
beautifulsoup4~=4.12.2 # for .mhtml files
pysrt~=1.1.2
# Online api libs dependencies
zhipuai>=1.0.7<=2.0.0 # zhipu
dashscope>=1.13.6 # qwen
# zhipuAI sdk is not supported on our platform, so use http instead
dashscope==1.13.6 # qwen
# volcengine>=1.0.119 # fangzhou
# uncomment libs if you want to use corresponding vector store
@ -63,7 +60,7 @@ dashscope>=1.13.6 # qwen
# Agent and Search Tools
arxiv>=2.0.0
youtube-search>=2.1.2
duckduckgo-search>=3.9.9
metaphor-python>=0.1.23
arxiv~=2.1.0
youtube-search~=2.1.2
duckduckgo-search~=3.9.9
metaphor-python~=0.1.23

View File

@ -1,60 +1,44 @@
# API requirements
# On Windows system, install the cuda version manually from https://pytorch.org/
# torch~=2.1.2
# torchvision~=0.16.2
# torchaudio~=2.1.2
# xformers==0.0.23.post1
# transformers==4.36.2
# sentence_transformers==2.2.2
langchain==0.0.352
langchain==0.0.354
langchain-experimental==0.0.47
pydantic==1.10.13
fschat==0.2.34
openai~=1.6.0
fastapi>=0.105
sse_starlette
fschat==0.2.35
openai~=1.7.1
fastapi~=0.108.0
sse_starlette==1.8.2
nltk>=3.8.1
uvicorn>=0.24.0.post1
starlette~=0.27.0
unstructured[docx,csv]==0.11.0 # add pdf if need
starlette~=0.32.0
unstructured[all-docs]==0.11.0
python-magic-bin; sys_platform == 'win32'
SQLAlchemy==2.0.19
faiss-cpu~=1.7.4 # `conda install faiss-gpu -c conda-forge` if you want to accelerate with gpus
# accelerate==0.24.1
# spacy~=3.7.2
# PyMuPDF~=1.23.8
# rapidocr_onnxruntime~=1.3.8
requests>=2.31.0
pathlib>=1.0.1
pytest>=7.4.3
numexpr>=2.8.6 # max version for py38
strsimpy>=0.2.1
markdownify>=0.11.6
# tiktoken~=0.5.2
faiss-cpu~=1.7.4
requests~=2.31.0
pathlib~=1.0.1
pytest~=7.4.3
numexpr~=2.8.6 # max version for py38
strsimpy~=0.2.1
markdownify~=0.11.6
tiktoken~=0.5.2
tqdm>=4.66.1
websockets>=12.0
numpy~=1.26.2
pandas~=2.1.4
# einops>=0.7.0
# transformers_stream_generator==0.0.4
# vllm==0.2.6; sys_platform == "linux"
httpx[brotli,http2,socks]~=0.25.2
numpy~=1.24.4
pandas~=2.0.3
einops>=0.7.0
transformers_stream_generator==0.0.4
vllm==0.2.7; sys_platform == "linux"
httpx[brotli,http2,socks]==0.25.2
requests
pathlib
pytest
# optional document loaders
rapidocr_paddle[gpu]>=1.3.0.post5 # gpu acceleration for ocr of pdf and image files
jq>=1.6.0 # for .json and .jsonl files. suggest `conda install jq` on windows
# html2text # for .enex files
beautifulsoup4~=4.12.2 # for .mhtml files
pysrt~=1.1.2
# Online api libs dependencies
zhipuai>=1.0.7<=2.0.0 # zhipu
dashscope>=1.13.6 # qwen
# volcengine>=1.0.119 # fangzhou
# zhipuAI sdk is not supported on our platform, so use http instead
dashscope==1.13.6
# volcengine>=1.0.119
# uncomment libs if you want to use corresponding vector store
# pymilvus>=2.3.4
@ -63,17 +47,18 @@ dashscope>=1.13.6 # qwen
# Agent and Search Tools
arxiv>=2.0.0
youtube-search>=2.1.2
duckduckgo-search>=3.9.9
metaphor-python>=0.1.23
arxiv~=2.1.0
youtube-search~=2.1.2
duckduckgo-search~=3.9.9
metaphor-python~=0.1.23
# WebUI requirements
streamlit~=1.29.0 # do remember to add streamlit to environment variables if you use windows
streamlit-option-menu>=0.3.6
streamlit==1.30.0
streamlit-option-menu==0.3.6
streamlit-antd-components==0.3.1
streamlit-chatbox==1.1.11
streamlit-modal>=0.1.0
streamlit-aggrid>=0.3.4.post3
httpx[brotli,http2,socks]>=0.25.2
watchdog>=3.0.0
streamlit-modal==0.1.0
streamlit-aggrid==0.3.4.post3
httpx==0.26.0
watchdog==3.0.0

View File

@ -1,9 +1,10 @@
# WebUI requirements
streamlit~=1.29.0 # do remember to add streamlit to environment variables if you use windows
streamlit-option-menu>=0.3.6
streamlit==1.30.0
streamlit-option-menu==0.3.6
streamlit-antd-components==0.3.1
streamlit-chatbox==1.1.11
streamlit-modal>=0.1.0
streamlit-aggrid>=0.3.4.post3
httpx[brotli,http2,socks]>=0.25.2
watchdog>=3.0.0
streamlit-modal==0.1.0
streamlit-aggrid==0.3.4.post3
httpx==0.26.0
watchdog==3.0.0

View File

@ -1,22 +1,19 @@
"""
This file is a modified version for ChatGLM3-6B of the original ChatGLM3Agent.py file from the langchain repo.
This file is a modified version for ChatGLM3-6B of the original glm3_agent.py file from the langchain repo.
"""
from __future__ import annotations
import yaml
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.memory import ConversationBufferWindowMemory
from typing import Any, List, Sequence, Tuple, Optional, Union
import os
from langchain.agents.agent import Agent
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate, MessagesPlaceholder,
)
import json
import logging
from typing import Any, List, Sequence, Tuple, Optional, Union
from pydantic.schema import model_schema
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.memory import ConversationBufferWindowMemory
from langchain.agents.agent import Agent
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
from langchain.agents.agent import AgentOutputParser
from langchain.output_parsers import OutputFixingParser
from langchain.pydantic_v1 import Field
@ -43,12 +40,18 @@ class StructuredChatOutputParserWithRetries(AgentOutputParser):
first_index = min([text.find(token) if token in text else len(text) for token in special_tokens])
text = text[:first_index]
if "tool_call" in text:
tool_name_end = text.find("```")
tool_name = text[:tool_name_end].strip()
input_para = text.split("='")[-1].split("'")[0]
action_end = text.find("```")
action = text[:action_end].strip()
params_str_start = text.find("(") + 1
params_str_end = text.rfind(")")
params_str = text[params_str_start:params_str_end]
params_pairs = [param.split("=") for param in params_str.split(",") if "=" in param]
params = {pair[0].strip(): pair[1].strip().strip("'\"") for pair in params_pairs}
action_json = {
"action": tool_name,
"action_input": input_para
"action": action,
"action_input": params
}
else:
action_json = {
@ -109,10 +112,6 @@ class StructuredGLM3ChatAgent(Agent):
else:
return agent_scratchpad
@classmethod
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
pass
@classmethod
def _get_default_output_parser(
cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
@ -121,7 +120,7 @@ class StructuredGLM3ChatAgent(Agent):
@property
def _stop(self) -> List[str]:
return ["```<observation>"]
return ["<|observation|>"]
@classmethod
def create_prompt(
@ -131,44 +130,25 @@ class StructuredGLM3ChatAgent(Agent):
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
def tool_config_from_file(tool_name, directory="server/agent/tools/"):
"""search tool yaml and return simplified json format"""
file_path = os.path.join(directory, f"{tool_name.lower()}.yaml")
try:
with open(file_path, 'r', encoding='utf-8') as file:
tool_config = yaml.safe_load(file)
# Simplify the structure if needed
simplified_config = {
"name": tool_config.get("name", ""),
"description": tool_config.get("description", ""),
"parameters": tool_config.get("parameters", {})
}
return simplified_config
except FileNotFoundError:
logger.error(f"File not found: {file_path}")
return None
except Exception as e:
logger.error(f"An error occurred while reading {file_path}: {e}")
return None
tools_json = []
tool_names = []
for tool in tools:
tool_config = tool_config_from_file(tool.name)
if tool_config:
tools_json.append(tool_config)
tool_names.append(tool.name)
# Format the tools for output
tool_schema = model_schema(tool.args_schema) if tool.args_schema else {}
simplified_config_langchain = {
"name": tool.name,
"description": tool.description,
"parameters": tool_schema.get("properties", {})
}
tools_json.append(simplified_config_langchain)
tool_names.append(tool.name)
formatted_tools = "\n".join([
f"{tool['name']}: {tool['description']}, args: {tool['parameters']}"
for tool in tools_json
])
formatted_tools = formatted_tools.replace("'", "\\'").replace("{", "{{").replace("}", "}}")
template = prompt.format(tool_names=tool_names,
tools=formatted_tools,
history="{history}",
history="None",
input="{input}",
agent_scratchpad="{agent_scratchpad}")
@ -225,7 +205,6 @@ def initialize_glm3_agent(
tools: Sequence[BaseTool],
llm: BaseLanguageModel,
prompt: str = None,
callback_manager: Optional[BaseCallbackManager] = None,
memory: Optional[ConversationBufferWindowMemory] = None,
agent_kwargs: Optional[dict] = None,
*,
@ -238,14 +217,12 @@ def initialize_glm3_agent(
llm=llm,
tools=tools,
prompt=prompt,
callback_manager=callback_manager, **agent_kwargs
**agent_kwargs
)
return AgentExecutor.from_agent_and_tools(
agent=agent_obj,
tools=tools,
callback_manager=callback_manager,
memory=memory,
tags=tags_,
**kwargs,
)
)
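# Worked example (illustrative) of the tool-call parsing added above: for a ChatGLM3 reply such as
#
#   weather_check
#   ```python
#   tool_call(location='上海')
#   ```
#
# everything before the first ``` becomes the action ("weather_check"), the text between the
# outermost parentheses is split on "," and "=" into keyword parameters, and the parser yields
# {"action": "weather_check", "action_input": {"location": "上海"}}.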

View File

@ -1,5 +1,3 @@
## 由于工具类无法传参,所以使用全局变量来传递模型和对应的知识库介绍
class ModelContainer:
def __init__(self):
self.MODEL = None

View File

@ -3,7 +3,7 @@ from .search_knowledgebase_simple import search_knowledgebase_simple
from .search_knowledgebase_once import search_knowledgebase_once, KnowledgeSearchInput
from .search_knowledgebase_complex import search_knowledgebase_complex, KnowledgeSearchInput
from .calculate import calculate, CalculatorInput
from .weather_check import weathercheck, WhetherSchema
from .weather_check import weathercheck, WeatherInput
from .shell import shell, ShellInput
from .search_internet import search_internet, SearchInternetInput
from .wolfram import wolfram, WolframInput

View File

@ -1,10 +0,0 @@
name: arxiv
description: A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.
parameters:
type: object
properties:
query:
type: string
description: The search query title
required:
- query

View File

@ -1,10 +0,0 @@
name: calculate
description: Useful for when you need to answer questions about simple calculations
parameters:
type: object
properties:
query:
type: string
description: The formula to be calculated
required:
- query

View File

@ -1,10 +0,0 @@
name: search_internet
description: Use this tool to surf internet and get information
parameters:
type: object
properties:
query:
type: string
description: Query for Internet search
required:
- query

View File

@ -1,10 +0,0 @@
name: search_knowledgebase_complex
description: Use this tool to search local knowledgebase and get information
parameters:
type: object
properties:
query:
type: string
description: The query to be searched
required:
- query

View File

@ -1,10 +0,0 @@
name: search_youtube
description: Use this tools to search youtube videos
parameters:
type: object
properties:
query:
type: string
description: Query for Videos search
required:
- query

View File

@ -1,10 +0,0 @@
name: shell
description: Use Linux Shell to execute Linux commands
parameters:
type: object
properties:
query:
type: string
description: The command to execute
required:
- query

View File

@ -1,338 +1,25 @@
from __future__ import annotations
## 单独运行的时候需要添加
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
import re
import warnings
from typing import Dict
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.pydantic_v1 import Extra, root_validator
from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
import requests
from typing import List, Any, Optional
from datetime import datetime
from langchain.prompts import PromptTemplate
from server.agent import model_container
from pydantic import BaseModel, Field
## 使用和风天气API查询天气
KEY = "ac880e5a877042809ac7ffdd19d95b0d"
# key长这样这里提供了示例的key这个key没法使用你需要自己去注册和风天气的账号然后在这里填入你的key
_PROMPT_TEMPLATE = """
用户会提出一个关于天气的问题你的目标是拆分出用户问题中的区 并按照我提供的工具回答
例如 用户提出的问题是: 上海浦东未来1小时天气情况
提取的市和区是: 上海 浦东
如果用户提出的问题是: 上海未来1小时天气情况
提取的市和区是: 上海 None
请注意以下内容:
1. 如果你没有找到区的内容,则一定要使用 None 替代否则程序无法运行
2. 如果用户没有指定市 则直接返回缺少信息
问题: ${{用户的问题}}
你的回答格式应该按照下面的内容请注意格式内的```text 等标记都必须输出这是我用来提取答案的标记
```text
${{拆分的市和区中间用空格隔开}}
```
... weathercheck( )...
```output
${{提取后的答案}}
```
答案: ${{答案}}
这是一个例子
问题: 上海浦东未来1小时天气情况
```text
上海 浦东
```
...weathercheck(上海 浦东)...
```output
预报时间: 1小时后
具体时间: 今天 18:00
温度: 24°C
天气: 多云
风向: 西南风
风速: 7
湿度: 88%
降水概率: 16%
Answer: 上海浦东一小时后的天气是多云
现在这是我的问题
问题: {question}
"""
PROMPT = PromptTemplate(
input_variables=["question"],
template=_PROMPT_TEMPLATE,
)
更简单的单参数输入工具实现用于查询现在天气的情况
"""
from pydantic import BaseModel, Field
import requests
def weather(location: str, api_key: str):
url = f"https://api.seniverse.com/v3/weather/now.json?key={api_key}&location={location}&language=zh-Hans&unit=c"
response = requests.get(url)
if response.status_code == 200:
data = response.json()
weather = {
"temperature": data["results"][0]["now"]["temperature"],
"description": data["results"][0]["now"]["text"],
}
return weather
else:
raise Exception(
f"Failed to retrieve weather: {response.status_code}")
def get_city_info(location, adm, key):
base_url = 'https://geoapi.qweather.com/v2/city/lookup?'
params = {'location': location, 'adm': adm, 'key': key}
response = requests.get(base_url, params=params)
data = response.json()
return data
def format_weather_data(data, place):
hourly_forecast = data['hourly']
formatted_data = f"\n 这是查询到的关于{place}未来24小时的天气信息: \n"
for forecast in hourly_forecast:
# 将预报时间转换为datetime对象
forecast_time = datetime.strptime(forecast['fxTime'], '%Y-%m-%dT%H:%M%z')
# 获取预报时间的时区
forecast_tz = forecast_time.tzinfo
# 获取当前时间(使用预报时间的时区)
now = datetime.now(forecast_tz)
# 计算预报日期与当前日期的差值
days_diff = (forecast_time.date() - now.date()).days
if days_diff == 0:
forecast_date_str = '今天'
elif days_diff == 1:
forecast_date_str = '明天'
elif days_diff == 2:
forecast_date_str = '后天'
else:
forecast_date_str = str(days_diff) + '天后'
forecast_time_str = forecast_date_str + ' ' + forecast_time.strftime('%H:%M')
# 计算预报时间与当前时间的差值
time_diff = forecast_time - now
# 将差值转换为小时
hours_diff = time_diff.total_seconds() // 3600
if hours_diff < 1:
hours_diff_str = '1小时后'
elif hours_diff >= 24:
# 如果超过24小时转换为天数
days_diff = hours_diff // 24
hours_diff_str = str(int(days_diff)) + ''
else:
hours_diff_str = str(int(hours_diff)) + '小时'
# 将预报时间和当前时间的差值添加到输出中
formatted_data += '预报时间: ' + forecast_time_str + ' 距离现在有: ' + hours_diff_str + '\n'
formatted_data += '温度: ' + forecast['temp'] + '°C\n'
formatted_data += '天气: ' + forecast['text'] + '\n'
formatted_data += '风向: ' + forecast['windDir'] + '\n'
formatted_data += '风速: ' + forecast['windSpeed'] + '\n'
formatted_data += '湿度: ' + forecast['humidity'] + '%\n'
formatted_data += '降水概率: ' + forecast['pop'] + '%\n'
# formatted_data += '降水量: ' + forecast['precip'] + 'mm\n'
formatted_data += '\n'
return formatted_data
def get_weather(key, location_id, place):
url = "https://devapi.qweather.com/v7/weather/24h?"
params = {
'location': location_id,
'key': key,
}
response = requests.get(url, params=params)
data = response.json()
return format_weather_data(data, place)
def split_query(query):
parts = query.split()
adm = parts[0]
if len(parts) == 1:
return adm, adm
location = parts[1] if parts[1] != 'None' else adm
return location, adm
def weather(query):
location, adm = split_query(query)
key = KEY
if key == "":
return "请先在代码中填入和风天气API Key"
try:
city_info = get_city_info(location=location, adm=adm, key=key)
location_id = city_info['location'][0]['id']
place = adm + "" + location + ""
weather_data = get_weather(key=key, location_id=location_id, place=place)
return weather_data + "以上是查询到的天气信息,请你查收\n"
except KeyError:
try:
city_info = get_city_info(location=adm, adm=adm, key=key)
location_id = city_info['location'][0]['id']
place = adm + ""
weather_data = get_weather(key=key, location_id=location_id, place=place)
return weather_data + "重要提醒:用户提供的市和区中,区的信息不存在,或者出现错别字,因此该信息是关于市的天气,请你查收\n"
except KeyError:
return "输入的地区不存在,无法提供天气预报"
class LLMWeatherChain(Chain):
llm_chain: LLMChain
llm: Optional[BaseLanguageModel] = None
"""[Deprecated] LLM wrapper to use."""
prompt: BasePromptTemplate = PROMPT
"""[Deprecated] Prompt to use to translate to python if necessary."""
input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
if "llm" in values:
warnings.warn(
"Directly instantiating an LLMWeatherChain with an llm is deprecated. "
"Please instantiate with llm_chain argument or using the from_llm "
"class method."
)
if "llm_chain" not in values and values["llm"] is not None:
prompt = values.get("prompt", PROMPT)
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
return values
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Expect output key.
:meta private:
"""
return [self.output_key]
def _evaluate_expression(self, expression: str) -> str:
try:
output = weather(expression)
except Exception as e:
output = "输入的信息有误,请再次尝试"
return output
def _process_llm_result(
self, llm_output: str, run_manager: CallbackManagerForChainRun
) -> Dict[str, str]:
run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
if text_match:
expression = text_match.group(1)
output = self._evaluate_expression(expression)
run_manager.on_text("\nAnswer: ", verbose=self.verbose)
run_manager.on_text(output, color="yellow", verbose=self.verbose)
answer = "Answer: " + output
elif llm_output.startswith("Answer:"):
answer = llm_output
elif "Answer:" in llm_output:
answer = "Answer: " + llm_output.split("Answer:")[-1]
else:
return {self.output_key: f"输入的格式不对: {llm_output},应该输入 (市 区)的组合"}
return {self.output_key: answer}
async def _aprocess_llm_result(
self,
llm_output: str,
run_manager: AsyncCallbackManagerForChainRun,
) -> Dict[str, str]:
await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
if text_match:
expression = text_match.group(1)
output = self._evaluate_expression(expression)
await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
await run_manager.on_text(output, color="yellow", verbose=self.verbose)
answer = "Answer: " + output
elif llm_output.startswith("Answer:"):
answer = llm_output
elif "Answer:" in llm_output:
answer = "Answer: " + llm_output.split("Answer:")[-1]
else:
raise ValueError(f"unknown format from LLM: {llm_output}")
return {self.output_key: answer}
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
_run_manager.on_text(inputs[self.input_key])
llm_output = self.llm_chain.predict(
question=inputs[self.input_key],
stop=["```output"],
callbacks=_run_manager.get_child(),
)
return self._process_llm_result(llm_output, _run_manager)
async def _acall(
self,
inputs: Dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
await _run_manager.on_text(inputs[self.input_key])
llm_output = await self.llm_chain.apredict(
question=inputs[self.input_key],
stop=["```output"],
callbacks=_run_manager.get_child(),
)
return await self._aprocess_llm_result(llm_output, _run_manager)
@property
def _chain_type(self) -> str:
return "llm_weather_chain"
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
prompt: BasePromptTemplate = PROMPT,
**kwargs: Any,
) -> LLMWeatherChain:
llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(llm_chain=llm_chain, **kwargs)
def weathercheck(query: str):
model = model_container.MODEL
llm_weather = LLMWeatherChain.from_llm(model, verbose=True, prompt=PROMPT)
ans = llm_weather.run(query)
return ans
class WhetherSchema(BaseModel):
location: str = Field(description="应该是一个地区的名称,用空格隔开,例如:上海 浦东,如果没有区的信息,可以只输入上海")
if __name__ == '__main__':
result = weathercheck("苏州姑苏区今晚热不热?")
def weathercheck(location: str):
return weather(location, "your keys")
class WeatherInput(BaseModel):
location: str = Field(description="City name,include city and county")

View File

@ -1,10 +0,0 @@
name: weather_check
description: Use Weather API to get weather information
parameters:
type: object
properties:
query:
type: string
description: City name,include city and county,like "厦门市思明区"
required:
- query

View File

@ -1,10 +0,0 @@
name: wolfram
description: Useful for when you need to calculate difficult math formulas
parameters:
type: object
properties:
query:
type: string
description: The formula to be calculated
required:
- query

View File

@ -1,8 +1,6 @@
from langchain.tools import Tool
from server.agent.tools import *
## 请注意如果你是为了使用AgentLM在这里你应该使用英文版本。
tools = [
Tool.from_function(
func=calculate,
@ -20,7 +18,7 @@ tools = [
func=weathercheck,
name="weather_check",
description="",
args_schema=WhetherSchema,
args_schema=WeatherInput,
),
Tool.from_function(
func=shell,

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Union, Optional
from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult

View File

@ -1,23 +1,23 @@
from langchain.memory import ConversationBufferWindowMemory
import json
import asyncio
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
from server.agent.tools_select import tools, tool_names
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
from langchain.agents import LLMSingleActionAgent, AgentExecutor
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
from fastapi import Body
from sse_starlette.sse import EventSourceResponse
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from langchain.chains import LLMChain
from typing import AsyncIterable, Optional
import asyncio
from typing import List
from server.chat.utils import History
import json
from server.agent import model_container
from server.knowledge_base.kb_service.base import get_kb_details
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.agents import LLMSingleActionAgent, AgentExecutor
from typing import AsyncIterable, Optional, List
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from server.knowledge_base.kb_service.base import get_kb_details
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
from server.agent.tools_select import tools, tool_names
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
from server.chat.utils import History
from server.agent import model_container
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
history: List[History] = Body([],
@ -33,7 +33,6 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default",
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
):
history = [History.from_data(h) for h in history]
@ -55,12 +54,10 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
callbacks=[callback],
)
## 传入全局变量来实现agent调用
kb_list = {x["kb_name"]: x for x in get_kb_details()}
model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()}
if Agent_MODEL:
## 如果有指定使用Agent模型来完成任务
model_agent = get_ChatOpenAI(
model_name=Agent_MODEL,
temperature=temperature,
@ -79,15 +76,11 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
)
output_parser = CustomOutputParser()
llm_chain = LLMChain(llm=model, prompt=prompt_template_agent)
# 把history转成agent的memory
memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
for message in history:
# 检查消息的角色
if message.role == 'user':
# 添加用户消息
memory.chat_memory.add_user_message(message.content)
else:
# 添加AI消息
memory.chat_memory.add_ai_message(message.content)
if "chatglm3" in model_container.MODEL.model_name:
@ -95,7 +88,6 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
llm=model,
tools=tools,
callback_manager=None,
# Langchain Prompt is not constructed directly here, it is constructed inside the GLM3 agent.
prompt=prompt_template,
input_variables=["input", "intermediate_steps", "history"],
memory=memory,
@ -155,7 +147,6 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
answer = ""
final_answer = ""
async for chunk in callback.aiter():
# Use server-sent-events to stream the response
data = json.loads(chunk)
if data["status"] == Status.start or data["status"] == Status.complete:
continue
@ -181,7 +172,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
await task
return EventSourceResponse(agent_chat_iterator(query=query,
history=history,
model_name=model_name,
prompt_name=prompt_name),
)
history=history,
model_name=model_name,
prompt_name=prompt_name),
)
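
A minimal sketch of the history-to-memory conversion used above in `agent_chat_iterator`, pulled out as a standalone helper (the function name and the `History` shape with `.role`/`.content` are assumptions for illustration):

```python
# Sketch: replay prior turns into a window memory so the agent sees the last
# HISTORY_LEN exchanges, mirroring the loop inside agent_chat.
from langchain.memory import ConversationBufferWindowMemory

def history_to_memory(history, window_len: int = 3) -> ConversationBufferWindowMemory:
    # k counts messages, so reserve 2 slots (user + assistant) per exchange
    memory = ConversationBufferWindowMemory(k=window_len * 2)
    for message in history:
        if message.role == "user":
            memory.chat_memory.add_user_message(message.content)
        else:
            memory.chat_memory.add_ai_message(message.content)
    return memory
```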

View File

@ -1,23 +1,23 @@
from langchain.utilities.bing_search import BingSearchAPIWrapper
from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE,
TEXT_SPLITTER_NAME, OVERLAP_SIZE)
from fastapi import Body
from sse_starlette import EventSourceResponse
from fastapi.concurrency import run_in_threadpool
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE, OVERLAP_SIZE)
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from typing import List, Optional, Dict
from server.chat.utils import History
from langchain.docstore.document import Document
from fastapi import Body
from fastapi.concurrency import run_in_threadpool
from sse_starlette import EventSourceResponse
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from server.chat.utils import History
from typing import AsyncIterable
import asyncio
import json
from typing import List, Optional, Dict
from strsimpy.normalized_levenshtein import NormalizedLevenshtein
from markdownify import markdownify
@ -38,11 +38,11 @@ def duckduckgo_search(text, result_len=SEARCH_ENGINE_TOP_K, **kwargs):
def metaphor_search(
text: str,
result_len: int = SEARCH_ENGINE_TOP_K,
split_result: bool = False,
chunk_size: int = 500,
chunk_overlap: int = OVERLAP_SIZE,
text: str,
result_len: int = SEARCH_ENGINE_TOP_K,
split_result: bool = False,
chunk_size: int = 500,
chunk_overlap: int = OVERLAP_SIZE,
) -> List[Dict]:
from metaphor_python import Metaphor
@ -58,13 +58,13 @@ def metaphor_search(
# metaphor 返回的内容都是长文本,需要分词再检索
if split_result:
docs = [Document(page_content=x.extract,
metadata={"link": x.url, "title": x.title})
metadata={"link": x.url, "title": x.title})
for x in contents]
text_splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
chunk_size=chunk_size,
chunk_overlap=chunk_overlap)
splitted_docs = text_splitter.split_documents(docs)
# 将切分好的文档放入临时向量库重新筛选出TOP_K个文档
if len(splitted_docs) > result_len:
normal = NormalizedLevenshtein()
@ -74,13 +74,13 @@ def metaphor_search(
splitted_docs = splitted_docs[:result_len]
docs = [{"snippet": x.page_content,
"link": x.metadata["link"],
"title": x.metadata["title"]}
"link": x.metadata["link"],
"title": x.metadata["title"]}
for x in splitted_docs]
else:
docs = [{"snippet": x.extract,
"link": x.url,
"title": x.title}
"link": x.url,
"title": x.title}
for x in contents]
return docs
@ -113,25 +113,27 @@ async def lookup_search_engine(
docs = search_result2docs(results)
return docs
async def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
history: List[History] = Body([],
description="历史对话",
examples=[[
{"role": "user",
search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
history: List[History] = Body([],
description="历史对话",
examples=[[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant",
{"role": "assistant",
"content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
split_result: bool = Body(False, description="是否对搜索结果进行拆分主要用于metaphor搜索引擎")
):
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
max_tokens: Optional[int] = Body(None,
description="限制LLM生成Token数量默认None代表模型最大值"),
prompt_name: str = Body("default",
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
split_result: bool = Body(False,
description="是否对搜索结果进行拆分主要用于metaphor搜索引擎")
):
if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@ -198,9 +200,9 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
await task
return EventSourceResponse(search_engine_chat_iterator(query=query,
search_engine_name=search_engine_name,
top_k=top_k,
history=history,
model_name=model_name,
prompt_name=prompt_name),
)
search_engine_name=search_engine_name,
top_k=top_k,
history=history,
model_name=model_name,
prompt_name=prompt_name),
)
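
A hedged sketch of the re-ranking step `metaphor_search` performs when `split_result=True`: long pages are split, each chunk is scored against the query with normalized Levenshtein similarity, and only the top `result_len` chunks are kept. Function and variable names here are illustrative, not part of the module:

```python
from typing import List
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from strsimpy.normalized_levenshtein import NormalizedLevenshtein

def rerank_chunks(query: str, pages: List[str], result_len: int = 3) -> List[Document]:
    docs = [Document(page_content=p) for p in pages]
    splitter = RecursiveCharacterTextSplitter(["\n\n", "\n", ".", " "],
                                              chunk_size=500, chunk_overlap=50)
    chunks = splitter.split_documents(docs)
    scorer = NormalizedLevenshtein()
    # higher similarity to the query first, then truncate to result_len
    chunks.sort(key=lambda d: scorer.similarity(query, d.page_content), reverse=True)
    return chunks[:result_len]
```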

View File

@ -83,7 +83,7 @@ def add_file_to_db(session,
kb_file: KnowledgeFile,
docs_count: int = 0,
custom_docs: bool = False,
doc_infos: List[str] = [], # 形式:[{"id": str, "metadata": dict}, ...]
doc_infos: List[Dict] = [], # 形式:[{"id": str, "metadata": dict}, ...]
):
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
if kb:

View File

@ -24,7 +24,7 @@ from server.knowledge_base.utils import (
list_kbs_from_folder, list_files_from_folder,
)
from typing import List, Union, Dict, Optional
from typing import List, Union, Dict, Optional, Tuple
from server.embeddings_api import embed_texts, aembed_texts, embed_documents
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
@ -261,7 +261,7 @@ class KBService(ABC):
query: str,
top_k: int,
score_threshold: float,
) -> List[Document]:
) -> List[Tuple[Document, float]]:
"""
搜索知识库,子类实现自己的检索逻辑
"""

View File

@ -6,6 +6,7 @@ from langchain.schema import Document
from langchain.vectorstores.elasticsearch import ElasticsearchStore
from configs import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, CACHED_VS_NUM
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from server.knowledge_base.utils import KnowledgeFile
from server.utils import load_local_embeddings
from elasticsearch import Elasticsearch,BadRequestError
from configs import logger
@ -15,7 +16,7 @@ class ESKBService(KBService):
def do_init(self):
self.kb_path = self.get_kb_path(self.kb_name)
self.index_name = self.kb_path.split("/")[-1]
self.index_name = os.path.split(self.kb_path)[-1]
self.IP = kbs_config[self.vs_type()]['host']
self.PORT = kbs_config[self.vs_type()]['port']
self.user = kbs_config[self.vs_type()].get("user",'')
@ -38,7 +39,16 @@ class ESKBService(KBService):
raise e
try:
# 首先尝试通过es_client_python创建
self.es_client_python.indices.create(index=self.index_name)
mappings = {
"properties": {
"dense_vector": {
"type": "dense_vector",
"dims": self.dims_length,
"index": True
}
}
}
self.es_client_python.indices.create(index=self.index_name, mappings=mappings)
except BadRequestError as e:
logger.error("创建索引失败,重新")
logger.error(e)
@ -80,9 +90,9 @@ class ESKBService(KBService):
except Exception as e:
logger.error("创建索引失败...")
logger.error(e)
# raise e
# raise e
@staticmethod
def get_kb_path(knowledge_base_name: str):
@ -220,7 +230,12 @@ class ESKBService(KBService):
shutil.rmtree(self.kb_path)
if __name__ == '__main__':
esKBService = ESKBService("test")
#esKBService.clear_vs()
#esKBService.create_kb()
esKBService.add_doc(KnowledgeFile(filename="README.md", knowledge_base_name="test"))
print(esKBService.search_docs("如何启动api服务"))
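
A sketch of the index layout the new `do_init` code creates: an explicit `dense_vector` field sized to the embedding dimension, instead of relying on a dynamic mapping. The helper name is an assumption; the mapping mirrors the diff above:

```python
from elasticsearch import Elasticsearch

def create_kb_index(es: Elasticsearch, index_name: str, dims: int) -> None:
    mappings = {
        "properties": {
            "dense_vector": {
                "type": "dense_vector",
                "dims": dims,     # must match the embedding model's output size
                "index": True,    # enable approximate kNN on this field
            }
        }
    }
    if not es.indices.exists(index=index_name):
        es.indices.create(index=index_name, mappings=mappings)
```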

View File

@ -7,7 +7,7 @@ from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafe
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
from server.utils import torch_gc
from langchain.docstore.document import Document
from typing import List, Dict, Optional
from typing import List, Dict, Optional, Tuple
class FaissKBService(KBService):
@ -61,7 +61,7 @@ class FaissKBService(KBService):
query: str,
top_k: int,
score_threshold: float = SCORE_THRESHOLD,
) -> List[Document]:
) -> List[Tuple[Document, float]]:
embed_func = EmbeddingsFunAdapter(self.embed_model)
embeddings = embed_func.embed_query(query)
with self.load_vector_store().acquire() as vs:
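
Since `do_search` now returns `(Document, score)` pairs, callers apply the score threshold themselves. An illustrative consumer (names are assumptions; FAISS scores here are L2 distances, so smaller is closer):

```python
from typing import List, Tuple
from langchain.docstore.document import Document

def filter_by_score(hits: List[Tuple[Document, float]], threshold: float) -> List[Document]:
    # keep only documents whose distance is within the threshold
    return [doc for doc, score in hits if score <= threshold]
```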

View File

@ -18,13 +18,10 @@ class MilvusKBService(KBService):
from pymilvus import Collection
return Collection(milvus_name)
# def save_vector_store(self):
# if self.milvus.col:
# self.milvus.col.flush()
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
result = []
if self.milvus.col:
# ids = [int(id) for id in ids] # for milvus if needed #pr 2725
data_list = self.milvus.col.query(expr=f'pk in {ids}', output_fields=["*"])
for data in data_list:
text = data.pop("text")
@ -53,7 +50,7 @@ class MilvusKBService(KBService):
self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(self.embed_model),
collection_name=self.kb_name,
connection_args=kbs_config.get("milvus"),
index_params=kbs_config.ge("milvus_kwargs")["index_params"],
index_params=kbs_config.get("milvus_kwargs")["index_params"],
search_params=kbs_config.get("milvus_kwargs")["search_params"]
)
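
The `.ge` → `.get` fix reads index and search parameters from `kbs_config`. A hedged sketch of the config shape this line assumes (the exact keys and values live in the project's configs module and may differ in your deployment):

```python
kbs_config = {
    "milvus": {"host": "127.0.0.1", "port": "19530", "user": "", "password": "", "secure": False},
    "milvus_kwargs": {
        "index_params": {"metric_type": "L2", "index_type": "HNSW"},
        "search_params": {"metric_type": "L2"},
    },
}

index_params = kbs_config.get("milvus_kwargs")["index_params"]    # was .ge(...) before the fix
search_params = kbs_config.get("milvus_kwargs")["search_params"]
```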

View File

@ -11,22 +11,26 @@ from server.knowledge_base.kb_service.base import SupportedVSType, KBService, Em
score_threshold_process
from server.knowledge_base.utils import KnowledgeFile
import shutil
import sqlalchemy
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session
class PGKBService(KBService):
pg_vector: PGVector
engine: Engine = sqlalchemy.create_engine(kbs_config.get("pg").get("connection_uri"), pool_size=10)
def _load_pg_vector(self):
self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(self.embed_model),
collection_name=self.kb_name,
distance_strategy=DistanceStrategy.EUCLIDEAN,
connection=PGKBService.engine,
connection_string=kbs_config.get("pg").get("connection_uri"))
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
with self.pg_vector.connect() as connect:
with Session(PGKBService.engine) as session:
stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id in :ids")
results = [Document(page_content=row[0], metadata=row[1]) for row in
connect.execute(stmt, parameters={'ids': ids}).fetchall()]
session.execute(stmt, {'ids': ids}).fetchall()]
return results
# TODO:
@ -43,8 +47,8 @@ class PGKBService(KBService):
return SupportedVSType.PG
def do_drop_kb(self):
with self.pg_vector.connect() as connect:
connect.execute(text(f'''
with Session(PGKBService.engine) as session:
session.execute(text(f'''
-- 删除 langchain_pg_embedding 表中关联到 langchain_pg_collection 表中 的记录
DELETE FROM langchain_pg_embedding
WHERE collection_id IN (
@ -53,11 +57,10 @@ class PGKBService(KBService):
-- 删除 langchain_pg_collection 表中 记录
DELETE FROM langchain_pg_collection WHERE name = '{self.kb_name}';
'''))
connect.commit()
session.commit()
shutil.rmtree(self.kb_path)
def do_search(self, query: str, top_k: int, score_threshold: float):
self._load_pg_vector()
embed_func = EmbeddingsFunAdapter(self.embed_model)
embeddings = embed_func.embed_query(query)
docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k)
@ -69,13 +72,13 @@ class PGKBService(KBService):
return doc_infos
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
with self.pg_vector.connect() as connect:
with Session(PGKBService.engine) as session:
filepath = kb_file.filepath.replace('\\', '\\\\')
connect.execute(
session.execute(
text(
''' DELETE FROM langchain_pg_embedding WHERE cmetadata::jsonb @> '{"source": "filepath"}'::jsonb;'''.replace(
"filepath", filepath)))
connect.commit()
session.commit()
def do_clear_vs(self):
self.pg_vector.delete_collection()
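
A minimal sketch of the new access pattern in `PGKBService`: one shared SQLAlchemy engine, short-lived `Session` objects for raw SQL instead of `pg_vector.connect()`. The connection URI, helper name, and the `uuid` column are assumptions matching the usual PGVector schema:

```python
import sqlalchemy
from sqlalchemy import text
from sqlalchemy.orm import Session

engine = sqlalchemy.create_engine(
    "postgresql+psycopg2://user:pass@localhost:5432/langchain", pool_size=10)

def drop_kb_rows(kb_name: str) -> None:
    with Session(engine) as session:
        # remove embeddings belonging to the collection, then the collection itself
        session.execute(text(
            "DELETE FROM langchain_pg_embedding WHERE collection_id IN "
            "(SELECT uuid FROM langchain_pg_collection WHERE name = :name)"), {"name": kb_name})
        session.execute(text(
            "DELETE FROM langchain_pg_collection WHERE name = :name"), {"name": kb_name})
        session.commit()
```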

View File

@ -16,13 +16,10 @@ class ZillizKBService(KBService):
from pymilvus import Collection
return Collection(zilliz_name)
# def save_vector_store(self):
# if self.zilliz.col:
# self.zilliz.col.flush()
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
result = []
if self.zilliz.col:
# ids = [int(id) for id in ids] # for zilliz if needed #pr 2725
data_list = self.zilliz.col.query(expr=f'pk in {ids}', output_fields=["*"])
for data in data_list:
text = data.pop("text")
@ -50,8 +47,7 @@ class ZillizKBService(KBService):
def _load_zilliz(self):
zilliz_args = kbs_config.get("zilliz")
self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(self.embed_model),
collection_name=self.kb_name, connection_args=zilliz_args)
collection_name=self.kb_name, connection_args=zilliz_args)
def do_init(self):
self._load_zilliz()
@ -95,9 +91,7 @@ class ZillizKBService(KBService):
if __name__ == '__main__':
from server.db.base import Base, engine
Base.metadata.create_all(bind=engine)
zillizService = ZillizKBService("test")

View File

@ -84,7 +84,7 @@ def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]:
def folder2db(
kb_names: List[str],
mode: Literal["recreate_vs", "update_in_db", "increament"],
mode: Literal["recreate_vs", "update_in_db", "increment"],
vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
chunk_size: int = CHUNK_SIZE,
@ -97,7 +97,7 @@ def folder2db(
recreate_vs: recreate all vector store and fill info to database using existed files in local folder
fill_info_only(disabled): do not create vector store, fill info to db using existed files only
update_in_db: update vector store and database info using local files that existed in database only
increament: create vector store and database info for local files that not existed in database only
increment: create vector store and database info for local files that not existed in database only
"""
def files2vs(kb_name: str, kb_files: List[KnowledgeFile]):
@ -142,7 +142,7 @@ def folder2db(
files2vs(kb_name, kb_files)
kb.save_vector_store()
# 对比本地目录与数据库中的文件列表,进行增量向量化
elif mode == "increament":
elif mode == "increment":
db_files = kb.list_files()
folder_files = list_files_from_folder(kb_name)
files = list(set(folder_files) - set(db_files))
@ -150,7 +150,7 @@ def folder2db(
files2vs(kb_name, kb_files)
kb.save_vector_store()
else:
print(f"unspported migrate mode: {mode}")
print(f"unsupported migrate mode: {mode}")
def prune_db_docs(kb_names: List[str]):

View File

@ -91,9 +91,14 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
"JSONLoader": [".json"],
"JSONLinesLoader": [".jsonl"],
"CSVLoader": [".csv"],
# "FilteredCSVLoader": [".csv"], # 需要自己指定,目前还没有支持
# "FilteredCSVLoader": [".csv"], 如果使用自定义分割csv
"RapidOCRPDFLoader": [".pdf"],
"RapidOCRDocLoader": ['.docx', '.doc'],
"RapidOCRPPTLoader": ['.ppt', '.pptx', ],
"RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
"UnstructuredFileLoader": ['.eml', '.msg', '.rst',
'.rtf', '.txt', '.xml',
'.epub', '.odt','.tsv'],
"UnstructuredEmailLoader": ['.eml', '.msg'],
"UnstructuredEPubLoader": ['.epub'],
"UnstructuredExcelLoader": ['.xlsx', '.xls', '.xlsd'],
@ -109,7 +114,6 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
"UnstructuredXMLLoader": ['.xml'],
"UnstructuredPowerPointLoader": ['.ppt', '.pptx'],
"EverNoteLoader": ['.enex'],
"UnstructuredFileLoader": ['.txt'],
}
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
@ -141,15 +145,14 @@ def get_LoaderClass(file_extension):
if file_extension in extensions:
return LoaderClass
# 把一些向量化共用逻辑从KnowledgeFile抽取出来等langchain支持内存文件的时候可以将非磁盘文件向量化
def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
'''
根据loader_name和文件路径或内容返回文档加载器
'''
loader_kwargs = loader_kwargs or {}
try:
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader","FilteredCSVLoader"]:
if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader", "FilteredCSVLoader",
"RapidOCRDocLoader", "RapidOCRPPTLoader"]:
document_loaders_module = importlib.import_module('document_loaders')
else:
document_loaders_module = importlib.import_module('langchain.document_loaders')
@ -258,7 +261,11 @@ def make_text_splitter(
print(e)
text_splitter_module = importlib.import_module('langchain.text_splitter')
TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
text_splitter = TextSplitter(chunk_size=250, chunk_overlap=50)
text_splitter = TextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
# If you use SpacyTextSplitter, you can run the split on GPU, as discussed in Issue #1287
# text_splitter._tokenizer.max_length = 37016792
# text_splitter._tokenizer.prefer_gpu()
return text_splitter
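
A hedged sketch of the loader dispatch rule shown above: the OCR-based loaders (now including the new doc/ppt variants) are imported from the project-local `document_loaders` package, everything else from `langchain.document_loaders`. Names here are illustrative:

```python
import importlib

CUSTOM_LOADERS = {"RapidOCRPDFLoader", "RapidOCRLoader", "FilteredCSVLoader",
                  "RapidOCRDocLoader", "RapidOCRPPTLoader"}

def resolve_loader(loader_name: str):
    module = importlib.import_module(
        "document_loaders" if loader_name in CUSTOM_LOADERS else "langchain.document_loaders"
    )
    return getattr(module, loader_name)
```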

View File

@ -0,0 +1,51 @@
from typing import (
TYPE_CHECKING,
Any,
Tuple
)
import sys
import logging
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
import tiktoken
class MinxChatOpenAI:
@staticmethod
def import_tiktoken() -> Any:
try:
import tiktoken
except ImportError:
raise ValueError(
"Could not import tiktoken python package. "
"This is needed in order to calculate get_token_ids. "
"Please install it with `pip install tiktoken`."
)
return tiktoken
@staticmethod
def get_encoding_model(self) -> Tuple[str, "tiktoken.Encoding"]:
tiktoken_ = MinxChatOpenAI.import_tiktoken()
if self.tiktoken_model_name is not None:
model = self.tiktoken_model_name
else:
model = self.model_name
if model == "gpt-3.5-turbo":
# gpt-3.5-turbo may change over time.
# Returning num tokens assuming gpt-3.5-turbo-0301.
model = "gpt-3.5-turbo-0301"
elif model == "gpt-4":
# gpt-4 may change over time.
# Returning num tokens assuming gpt-4-0314.
model = "gpt-4-0314"
# Returns the number of tokens used by a list of messages.
try:
encoding = tiktoken_.encoding_for_model(model)
except Exception as e:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken_.get_encoding(model)
return model, encoding
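
How this helper is wired in (see the `server/utils.py` hunk further below): the static method is patched onto `ChatOpenAI` so token counting falls back to `cl100k_base` for model names that tiktoken does not recognize.

```python
from langchain.chat_models import ChatOpenAI
from server.minx_chat_openai import MinxChatOpenAI

# monkey-patch: unknown model names no longer raise inside _get_encoding_model
ChatOpenAI._get_encoding_model = MinxChatOpenAI.get_encoding_model
```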

View File

@ -8,3 +8,4 @@ from .qwen import QwenWorker
from .baichuan import BaiChuanWorker
from .azure import AzureWorker
from .tiangong import TianGongWorker
from .gemini import GeminiWorker

View File

@ -0,0 +1,124 @@
import sys
from fastchat.conversation import Conversation
from server.model_workers.base import *
from server.utils import get_httpx_client
from fastchat import conversation as conv
import json,httpx
from typing import List, Dict
from configs import logger, log_verbose
class GeminiWorker(ApiModelWorker):
def __init__(
self,
*,
controller_addr: str = None,
worker_addr: str = None,
model_names: List[str] = ["Gemini-api"],
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 4096)
super().__init__(**kwargs)
def create_gemini_messages(self, messages) -> Dict:
has_history = any(msg['role'] == 'assistant' for msg in messages)
gemini_msg = []
for msg in messages:
role = msg['role']
content = msg['content']
if role == 'system':
continue
if has_history:
if role == 'assistant':
role = "model"
transformed_msg = {"role": role, "parts": [{"text": content}]}
else:
if role == 'user':
transformed_msg = {"parts": [{"text": content}]}
gemini_msg.append(transformed_msg)
msg = dict(contents=gemini_msg)
return msg
def do_chat(self, params: ApiChatParams) -> Dict:
params.load_config(self.model_names[0])
data = self.create_gemini_messages(messages=params.messages)
generationConfig=dict(
temperature=params.temperature,
topK=1,
topP=1,
maxOutputTokens=4096,
stopSequences=[]
)
data['generationConfig'] = generationConfig
url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent"+ '?key=' + params.api_key
headers = {
'Content-Type': 'application/json',
}
if log_verbose:
logger.info(f'{self.__class__.__name__}:url: {url}')
logger.info(f'{self.__class__.__name__}:headers: {headers}')
logger.info(f'{self.__class__.__name__}:data: {data}')
text = ""
json_string = ""
timeout = httpx.Timeout(60.0)
client=get_httpx_client(timeout=timeout)
with client.stream("POST", url, headers=headers, json=data) as response:
for line in response.iter_lines():
line = line.strip()
if not line or "[DONE]" in line:
continue
json_string += line
try:
resp = json.loads(json_string)
if 'candidates' in resp:
for candidate in resp['candidates']:
content = candidate.get('content', {})
parts = content.get('parts', [])
for part in parts:
if 'text' in part:
text += part['text']
yield {
"error_code": 0,
"text": text
}
print(text)
except json.JSONDecodeError as e:
print("Failed to decode JSON:", e)
print("Invalid JSON string:", json_string)
def get_embeddings(self, params):
# TODO: 支持embeddings
print("embedding")
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
return conv.Conversation(
name=self.model_names[0],
system_message="You are a helpful, respectful and honest assistant.",
messages=[],
roles=["user", "assistant"],
sep="\n### ",
stop_str="###",
)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.base_model_worker import app
worker = GeminiWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:21012",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=21012)
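
Illustrative input/output for `create_gemini_messages` (the values are made up): system turns are dropped, `assistant` maps to `model` once a history exists, and each turn is wrapped in Gemini's `parts`/`text` envelope before `generationConfig` is attached.

```python
messages = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "你好"},
    {"role": "assistant", "content": "你好,有什么可以帮你?"},
    {"role": "user", "content": "介绍一下 Langchain-Chatchat"},
]
# expected payload (roughly):
expected = {"contents": [
    {"role": "user", "parts": [{"text": "你好"}]},
    {"role": "model", "parts": [{"text": "你好,有什么可以帮你?"}]},
    {"role": "user", "parts": [{"text": "介绍一下 Langchain-Chatchat"}]},
]}
```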

View File

@ -28,7 +28,7 @@ class MiniMaxWorker(ApiModelWorker):
def validate_messages(self, messages: List[Dict]) -> List[Dict]:
role_maps = {
"user": self.user_role,
"USER": self.user_role,
"assistant": self.ai_role,
"system": "system",
}
@ -73,7 +73,7 @@ class MiniMaxWorker(ApiModelWorker):
with response as r:
text = ""
for e in r.iter_text():
if not e.startswith("data: "): # 真是优秀的返回
if not e.startswith("data: "):
data = {
"error_code": 500,
"text": f"minimax返回错误的结果{e}",
@ -140,7 +140,7 @@ class MiniMaxWorker(ApiModelWorker):
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
return data
i += batch_size
return {"code": 200, "data": embeddings}
return {"code": 200, "data": result}
def get_embeddings(self, params):
# TODO: 支持embeddings

View File

@ -84,30 +84,6 @@ class QianFanWorker(ApiModelWorker):
def do_chat(self, params: ApiChatParams) -> Dict:
params.load_config(self.model_names[0])
# import qianfan
# comp = qianfan.ChatCompletion(model=params.version,
# endpoint=params.version_url,
# ak=params.api_key,
# sk=params.secret_key,)
# text = ""
# for resp in comp.do(messages=params.messages,
# temperature=params.temperature,
# top_p=params.top_p,
# stream=True):
# if resp.code == 200:
# if chunk := resp.body.get("result"):
# text += chunk
# yield {
# "error_code": 0,
# "text": text
# }
# else:
# yield {
# "error_code": resp.code,
# "text": str(resp.body),
# }
BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \
'/{model_version}?access_token={access_token}'
@ -190,19 +166,19 @@ class QianFanWorker(ApiModelWorker):
i = 0
batch_size = 10
while i < len(params.texts):
texts = params.texts[i:i+batch_size]
texts = params.texts[i:i + batch_size]
resp = client.post(url, json={"input": texts}).json()
if "error_code" in resp:
data = {
"code": resp["error_code"],
"msg": resp["error_msg"],
"error": {
"message": resp["error_msg"],
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
"code": resp["error_code"],
"msg": resp["error_msg"],
"error": {
"message": resp["error_msg"],
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求千帆 API 时发生错误:{data}")
return data
else:

View File

@ -11,16 +11,15 @@ from typing import List, Literal, Dict
import requests
class TianGongWorker(ApiModelWorker):
def __init__(
self,
*,
controller_addr: str = None,
worker_addr: str = None,
model_names: List[str] = ["tiangong-api"],
version: Literal["SkyChat-MegaVerse"] = "SkyChat-MegaVerse",
**kwargs,
self,
*,
controller_addr: str = None,
worker_addr: str = None,
model_names: List[str] = ["tiangong-api"],
version: Literal["SkyChat-MegaVerse"] = "SkyChat-MegaVerse",
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 32768)
@ -34,18 +33,18 @@ class TianGongWorker(ApiModelWorker):
data = {
"messages": params.messages,
"model": "SkyChat-MegaVerse"
}
timestamp = str(int(time.time()))
sign_content = params.api_key + params.secret_key + timestamp
sign_result = hashlib.md5(sign_content.encode('utf-8')).hexdigest()
headers={
}
timestamp = str(int(time.time()))
sign_content = params.api_key + params.secret_key + timestamp
sign_result = hashlib.md5(sign_content.encode('utf-8')).hexdigest()
headers = {
"app_key": params.api_key,
"timestamp": timestamp,
"sign": sign_result,
"Content-Type": "application/json",
"stream": "true" # or change to "false" 不处理流式返回内容
"stream": "true" # or change to "false" 不处理流式返回内容
}
# 发起请求并获取响应
response = requests.post(url, headers=headers, json=data, stream=True)
@ -56,17 +55,17 @@ class TianGongWorker(ApiModelWorker):
# 处理接收到的数据
# print(line.decode('utf-8'))
resp = json.loads(line)
if resp["code"] == 200:
if resp["code"] == 200:
text += resp['resp_data']['reply']
yield {
"error_code": 0,
"text": text
}
}
else:
data = {
"error_code": resp["code"],
"text": resp["code_msg"]
}
}
self.logger.error(f"请求天工 API 时出错:{data}")
yield data
@ -85,5 +84,3 @@ class TianGongWorker(ApiModelWorker):
sep="\n### ",
stop_str="###",
)
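
A sketch of the request signing shown above for the TianGong worker: an MD5 digest over `app_key + secret_key + timestamp`, sent alongside the timestamp in the headers. The helper name is an assumption:

```python
import hashlib
import time

def sign_headers(api_key: str, secret_key: str) -> dict:
    timestamp = str(int(time.time()))
    sign = hashlib.md5((api_key + secret_key + timestamp).encode("utf-8")).hexdigest()
    return {"app_key": api_key, "timestamp": timestamp, "sign": sign,
            "Content-Type": "application/json", "stream": "true"}
```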

View File

@ -37,7 +37,7 @@ class XingHuoWorker(ApiModelWorker):
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 8000) # TODO: V1模型的最大长度为4000需要自行修改
kwargs.setdefault("context_len", 8000)
super().__init__(**kwargs)
self.version = version

View File

@ -4,93 +4,86 @@ from fastchat import conversation as conv
import sys
from typing import List, Dict, Iterator, Literal
from configs import logger, log_verbose
import requests
import jwt
import time
import json
def generate_token(apikey: str, exp_seconds: int):
try:
id, secret = apikey.split(".")
except Exception as e:
raise Exception("invalid apikey", e)
payload = {
"api_key": id,
"exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
"timestamp": int(round(time.time() * 1000)),
}
return jwt.encode(
payload,
secret,
algorithm="HS256",
headers={"alg": "HS256", "sign_type": "SIGN"},
)
class ChatGLMWorker(ApiModelWorker):
DEFAULT_EMBED_MODEL = "text_embedding"
def __init__(
self,
*,
model_names: List[str] = ["zhipu-api"],
controller_addr: str = None,
worker_addr: str = None,
version: Literal["chatglm_turbo"] = "chatglm_turbo",
**kwargs,
self,
*,
model_names: List[str] = ["zhipu-api"],
controller_addr: str = None,
worker_addr: str = None,
version: Literal["chatglm_turbo"] = "chatglm_turbo",
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 32768)
kwargs.setdefault("context_len", 4096)
super().__init__(**kwargs)
self.version = version
def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
# TODO: 维护request_id
import zhipuai
params.load_config(self.model_names[0])
zhipuai.api_key = params.api_key
if log_verbose:
logger.info(f'{self.__class__.__name__}:params: {params}')
response = zhipuai.model_api.sse_invoke(
model=params.version,
prompt=params.messages,
temperature=params.temperature,
top_p=params.top_p,
incremental=False,
)
for e in response.events():
if e.event == "add":
yield {"error_code": 0, "text": e.data}
elif e.event in ["error", "interrupted"]:
data = {
"error_code": 500,
"text": e.data,
"error": {
"message": e.data,
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求智谱 API 时发生错误:{data}")
yield data
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
import zhipuai
params.load_config(self.model_names[0])
zhipuai.api_key = params.api_key
embeddings = []
try:
for t in params.texts:
response = zhipuai.model_api.invoke(model=params.embed_model or self.DEFAULT_EMBED_MODEL, prompt=t)
if response["code"] == 200:
embeddings.append(response["data"]["embedding"])
else:
self.logger.error(f"请求智谱 API 时发生错误:{response}")
return response # dict with code & msg
except Exception as e:
self.logger.error(f"请求智谱 API 时发生错误:{data}")
data = {"code": 500, "msg": f"对文本向量化时出错:{e}"}
return data
return {"code": 200, "data": embeddings}
token = generate_token(params.api_key, 60)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {token}"
}
data = {
"model": params.version,
"messages": params.messages,
"max_tokens": params.max_tokens,
"temperature": params.temperature,
"stream": True
}
url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
response = requests.post(url, headers=headers, json=data, stream=True)
for chunk in response.iter_lines():
if chunk:
chunk_str = chunk.decode('utf-8')
json_start_pos = chunk_str.find('{"id"')
if json_start_pos != -1:
json_str = chunk_str[json_start_pos:]
json_data = json.loads(json_str)
for choice in json_data.get('choices', []):
delta = choice.get('delta', {})
content = delta.get('content', '')
yield {"error_code": 0, "text": content}
def get_embeddings(self, params):
# TODO: 支持embeddings
# 临时解决方案不支持embedding
print("embedding")
# print(params)
print(params)
def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
# 这里的是chatglm api的模板其它API的conv_template需要定制
return conv.Conversation(
name=self.model_names[0],
system_message="你是一个聪明的助手,请根据用户的提示来完成任务",
system_message="你是智谱AI小助手请根据用户的提示来完成任务",
messages=[],
roles=["Human", "Assistant", "System"],
roles=["user", "assistant", "system"],
sep="\n###",
stop_str="###",
)
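
A hedged sketch of the JWT authentication the rewritten Zhipu worker uses: the API key is split into `id.secret`, and an HS256 token with millisecond timestamps is signed and sent as a Bearer header. This mirrors `generate_token` above; only the standalone layout is illustrative:

```python
import time
import jwt  # PyJWT

def make_zhipu_token(apikey: str, exp_seconds: int = 60) -> str:
    kid, secret = apikey.split(".")
    now_ms = int(round(time.time() * 1000))
    payload = {"api_key": kid, "exp": now_ms + exp_seconds * 1000, "timestamp": now_ms}
    return jwt.encode(payload, secret, algorithm="HS256",
                      headers={"alg": "HS256", "sign_type": "SIGN"})

headers = {"Content-Type": "application/json",
           "Authorization": f"Bearer {make_zhipu_token('id.secret')}"}
```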

File diff suppressed because one or more lines are too long

View File

@ -10,10 +10,24 @@ from configs import (LLM_MODELS, LLM_DEVICE, EMBEDDING_DEVICE,
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI, AzureOpenAI, Anthropic
from langchain.llms import OpenAI
import httpx
from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union, Tuple
from typing import (
TYPE_CHECKING,
Literal,
Optional,
Callable,
Generator,
Dict,
Any,
Awaitable,
Union,
Tuple
)
import logging
import torch
from server.minx_chat_openai import MinxChatOpenAI
async def wrap_done(fn: Awaitable, event: asyncio.Event):
@ -43,7 +57,7 @@ def get_ChatOpenAI(
config = get_model_worker_config(model_name)
if model_name == "openai-api":
model_name = config.get("model_name")
ChatOpenAI._get_encoding_model = MinxChatOpenAI.get_encoding_model
model = ChatOpenAI(
streaming=streaming,
verbose=verbose,
@ -58,6 +72,7 @@ def get_ChatOpenAI(
)
return model
def get_OpenAI(
model_name: str,
temperature: float,
@ -488,16 +503,12 @@ def set_httpx_config(
no_proxy.append(host)
os.environ["NO_PROXY"] = ",".join(no_proxy)
# TODO: 简单的清除系统代理不是个好的选择影响太多。似乎修改代理服务器的bypass列表更好。
# patch requests to use custom proxies instead of system settings
def _get_proxies():
return proxies
import urllib.request
urllib.request.getproxies = _get_proxies
# 自动检查torch可用的设备。分布式部署时不运行LLM的机器上可以不装torch
def detect_device() -> Literal["cuda", "mps", "cpu"]:
try:

View File

@ -6,9 +6,8 @@ import sys
from multiprocessing import Process
from datetime import datetime
from pprint import pprint
from langchain_core._api import deprecated
# 设置numexpr最大线程数默认为CPU核心数
try:
import numexpr
@ -33,15 +32,18 @@ from configs import (
HTTPX_DEFAULT_TIMEOUT,
)
from server.utils import (fschat_controller_address, fschat_model_worker_address,
fschat_openai_api_address, set_httpx_config, get_httpx_client,
get_model_worker_config, get_all_model_worker_configs,
fschat_openai_api_address, get_httpx_client, get_model_worker_config,
MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
from server.knowledge_base.migrate import create_tables
import argparse
from typing import Tuple, List, Dict
from typing import List, Dict
from configs import VERSION
@deprecated(
since="0.3.0",
message="模型启动功能将于 Langchain-Chatchat 0.3.x重写,支持更多模式和加速启动0.2.x中相关功能将废弃",
removal="0.3.0")
def create_controller_app(
dispatch_method: str,
log_level: str = "INFO",
@ -88,7 +90,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
for k, v in kwargs.items():
setattr(args, k, v)
if worker_class := kwargs.get("langchain_model"): #Langchain支持的模型不用做操作
if worker_class := kwargs.get("langchain_model"):  # Langchain支持的模型不用做操作
from fastchat.serve.base_model_worker import app
worker = ""
# 在线模型API
@ -107,12 +109,12 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
import fastchat.serve.vllm_worker
from fastchat.serve.vllm_worker import VLLMWorker, app, worker_id
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs,EngineArgs
from vllm.engine.arg_utils import AsyncEngineArgs
args.tokenizer = args.model_path # 如果tokenizer与model_path不一致在此处添加
args.tokenizer = args.model_path
args.tokenizer_mode = 'auto'
args.trust_remote_code= True
args.download_dir= None
args.trust_remote_code = True
args.download_dir = None
args.load_format = 'auto'
args.dtype = 'auto'
args.seed = 0
@ -122,13 +124,13 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
args.block_size = 16
args.swap_space = 4 # GiB
args.gpu_memory_utilization = 0.90
args.max_num_batched_tokens = None # 一个批次中的最大令牌tokens数量这个取决于你的显卡和大模型设置设置太大显存会不够
args.max_num_batched_tokens = None # 一个批次中的最大令牌tokens数量这个取决于你的显卡和大模型设置设置太大显存会不够
args.max_num_seqs = 256
args.disable_log_stats = False
args.conv_template = None
args.limit_worker_concurrency = 5
args.no_register = False
args.num_gpus = 1 # vllm worker的切分是tensor并行这里填写显卡的数量
args.num_gpus = 1 # vllm worker的切分是tensor并行这里填写显卡的数量
args.engine_use_ray = False
args.disable_log_requests = False
@ -138,10 +140,10 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
args.quantization = None
args.max_log_len = None
args.tokenizer_revision = None
# 0.2.2 vllm需要新加的参数
args.max_paddings = 256
if args.model_path:
args.model = args.model_path
if args.num_gpus > 1:
@ -154,16 +156,16 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
engine = AsyncLLMEngine.from_engine_args(engine_args)
worker = VLLMWorker(
controller_addr = args.controller_address,
worker_addr = args.worker_address,
worker_id = worker_id,
model_path = args.model_path,
model_names = args.model_names,
limit_worker_concurrency = args.limit_worker_concurrency,
no_register = args.no_register,
llm_engine = engine,
conv_template = args.conv_template,
)
controller_addr=args.controller_address,
worker_addr=args.worker_address,
worker_id=worker_id,
model_path=args.model_path,
model_names=args.model_names,
limit_worker_concurrency=args.limit_worker_concurrency,
no_register=args.no_register,
llm_engine=engine,
conv_template=args.conv_template,
)
sys.modules["fastchat.serve.vllm_worker"].engine = engine
sys.modules["fastchat.serve.vllm_worker"].worker = worker
sys.modules["fastchat.serve.vllm_worker"].logger.setLevel(log_level)
@ -171,7 +173,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
else:
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
args.gpus = "0" # GPU的编号,如果有多个GPU可以设置为"0,1,2,3"
args.gpus = "0" # GPU的编号,如果有多个GPU可以设置为"0,1,2,3"
args.max_gpu_memory = "22GiB"
args.num_gpus = 1 # model worker的切分是model并行这里填写显卡的数量
@ -325,7 +327,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
with get_httpx_client() as client:
r = client.post(worker_address + "/release",
json={"new_model_name": new_model_name, "keep_origin": keep_origin})
json={"new_model_name": new_model_name, "keep_origin": keep_origin})
if r.status_code != 200:
msg = f"failed to release model: {model_name}"
logger.error(msg)
@ -393,8 +395,8 @@ def run_model_worker(
# add interface to release and load model
@app.post("/release")
def release_model(
new_model_name: str = Body(None, description="释放后加载该模型"),
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
new_model_name: str = Body(None, description="释放后加载该模型"),
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
) -> Dict:
if keep_origin:
if new_model_name:
@ -450,13 +452,13 @@ def run_webui(started_event: mp.Event = None, run_mode: str = None):
port = WEBUI_SERVER["port"]
cmd = ["streamlit", "run", "webui.py",
"--server.address", host,
"--server.port", str(port),
"--theme.base", "light",
"--theme.primaryColor", "#165dff",
"--theme.secondaryBackgroundColor", "#f5f5f5",
"--theme.textColor", "#000000",
]
"--server.address", host,
"--server.port", str(port),
"--theme.base", "light",
"--theme.primaryColor", "#165dff",
"--theme.secondaryBackgroundColor", "#f5f5f5",
"--theme.textColor", "#000000",
]
if run_mode == "lite":
cmd += [
"--",
@ -605,8 +607,10 @@ async def start_main_server():
Python 3.9 has `signal.strsignal(signalnum)` so this closure would not be needed.
Also, 3.8 includes `signal.valid_signals()` that can be used to create a mapping for the same purpose.
"""
def f(signal_received, frame):
raise KeyboardInterrupt(f"{signalname} received")
return f
# This will be inherited by the child process if it is forked (not spawned)
@ -701,8 +705,8 @@ async def start_main_server():
for model_name in args.model_name:
config = get_model_worker_config(model_name)
if (config.get("online_api")
and config.get("worker_class")
and model_name in FSCHAT_MODEL_WORKERS):
and config.get("worker_class")
and model_name in FSCHAT_MODEL_WORKERS):
e = manager.Event()
model_worker_started.append(e)
process = Process(
@ -742,12 +746,12 @@ async def start_main_server():
else:
try:
# 保证任务收到SIGINT后能够正常退出
if p:= processes.get("controller"):
if p := processes.get("controller"):
p.start()
p.name = f"{p.name} ({p.pid})"
controller_started.wait() # 等待controller启动完成
controller_started.wait() # 等待controller启动完成
if p:= processes.get("openai_api"):
if p := processes.get("openai_api"):
p.start()
p.name = f"{p.name} ({p.pid})"
@ -763,24 +767,24 @@ async def start_main_server():
for e in model_worker_started:
e.wait()
if p:= processes.get("api"):
if p := processes.get("api"):
p.start()
p.name = f"{p.name} ({p.pid})"
api_started.wait() # 等待api.py启动完成
api_started.wait() # 等待api.py启动完成
if p:= processes.get("webui"):
if p := processes.get("webui"):
p.start()
p.name = f"{p.name} ({p.pid})"
webui_started.wait() # 等待webui.py启动完成
webui_started.wait() # 等待webui.py启动完成
dump_server_info(after_start=True, args=args)
while True:
cmd = queue.get() # 收到切换模型的消息
cmd = queue.get() # 收到切换模型的消息
e = manager.Event()
if isinstance(cmd, list):
model_name, cmd, new_model_name = cmd
if cmd == "start": # 运行新模型
if cmd == "start": # 运行新模型
logger.info(f"准备启动新模型进程:{new_model_name}")
process = Process(
target=run_model_worker,
@ -831,7 +835,6 @@ async def start_main_server():
else:
logger.error(f"未找到模型进程:{model_name}")
# for process in processes.get("model_worker", {}).values():
# process.join()
# for process in processes.get("online_api", {}).values():
@ -866,10 +869,9 @@ async def start_main_server():
for p in processes.values():
logger.info("Process status: %s", p)
if __name__ == "__main__":
# 确保数据库表被创建
create_tables()
if __name__ == "__main__":
create_tables()
if sys.version_info < (3, 10):
loop = asyncio.get_event_loop()
else:
@ -879,16 +881,15 @@ if __name__ == "__main__":
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 同步调用协程代码
loop.run_until_complete(start_main_server())
loop.run_until_complete(start_main_server())
# 服务启动后接口调用示例:
# import openai
# openai.api_key = "EMPTY" # Not support yet
# openai.api_base = "http://localhost:8888/v1"
# model = "chatglm2-6b"
# model = "chatglm3-6b"
# # create a chat completion
# completion = openai.ChatCompletion.create(
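
A runnable version of the commented example above, assuming the pre-1.0 `openai` SDK that the snippet targets; the model name must match one started by `startup.py`:

```python
import openai

openai.api_key = "EMPTY"  # the local OpenAI-compatible API does not check keys yet
openai.api_base = "http://localhost:8888/v1"

completion = openai.ChatCompletion.create(
    model="chatglm3-6b",
    messages=[{"role": "user", "content": "你好"}],
)
print(completion.choices[0].message.content)
```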

BIN
tests/samples/ocr_test.docx Normal file

Binary file not shown.

BIN
tests/samples/ocr_test.pptx Normal file

Binary file not shown.

View File

@ -14,8 +14,7 @@ from server.knowledge_base.migrate import folder2db, prune_db_docs, prune_folder
# setup test knowledge base
kb_name = "test_kb_for_migrate"
test_files = {
"faq.md": str(root_path / "docs" / "faq.md"),
"install.md": str(root_path / "docs" / "install.md"),
"readme.md": str(root_path / "readme.md"),
}
@ -56,13 +55,13 @@ def test_recreate_vs():
assert doc.metadata["source"] == name
def test_increament():
def test_increment():
kb = KBServiceFactory.get_service_by_name(kb_name)
kb.clear_vs()
assert kb.list_files() == []
assert kb.list_docs() == []
folder2db([kb_name], "increament")
folder2db([kb_name], "increment")
files = kb.list_files()
print(files)

View File

@ -1,7 +1,6 @@
# 该文件封装了对api.py的请求,可以被不同的webui使用
# 通过ApiRequest和AsyncApiRequest支持同步/异步调用
from typing import *
from pathlib import Path
# 此处导入的配置为发起请求如WEBUI机器上的配置主要用于为前端设置默认值。分布式部署时可以与服务器上的不同
@ -27,7 +26,7 @@ from io import BytesIO
from server.utils import set_httpx_config, api_address, get_httpx_client
from pprint import pprint
from langchain_core._api import deprecated
set_httpx_config()
@ -36,10 +35,11 @@ class ApiRequest:
'''
api.py调用的封装同步模式,简化api调用方式
'''
def __init__(
self,
base_url: str = api_address(),
timeout: float = HTTPX_DEFAULT_TIMEOUT,
self,
base_url: str = api_address(),
timeout: float = HTTPX_DEFAULT_TIMEOUT,
):
self.base_url = base_url
self.timeout = timeout
@ -55,12 +55,12 @@ class ApiRequest:
return self._client
def get(
self,
url: str,
params: Union[Dict, List[Tuple], bytes] = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any,
self,
url: str,
params: Union[Dict, List[Tuple], bytes] = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any,
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
while retry > 0:
try:
@ -75,13 +75,13 @@ class ApiRequest:
retry -= 1
def post(
self,
url: str,
data: Dict = None,
json: Dict = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any
self,
url: str,
data: Dict = None,
json: Dict = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
while retry > 0:
try:
@ -97,13 +97,13 @@ class ApiRequest:
retry -= 1
def delete(
self,
url: str,
data: Dict = None,
json: Dict = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any
self,
url: str,
data: Dict = None,
json: Dict = None,
retry: int = 3,
stream: bool = False,
**kwargs: Any
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
while retry > 0:
try:
@ -118,30 +118,33 @@ class ApiRequest:
retry -= 1
def _httpx_stream2generator(
self,
response: contextlib._GeneratorContextManager,
as_json: bool = False,
self,
response: contextlib._GeneratorContextManager,
as_json: bool = False,
):
'''
将httpx.stream返回的GeneratorContextManager转化为普通生成器
'''
async def ret_async(response, as_json):
try:
async with response as r:
async for chunk in r.aiter_text(None):
if not chunk: # fastchat api yield empty bytes on start and end
if not chunk: # fastchat api yield empty bytes on start and end
continue
if as_json:
try:
if chunk.startswith("data: "):
data = json.loads(chunk[6:-2])
elif chunk.startswith(":"): # skip sse comment line
continue
else:
data = json.loads(chunk)
yield data
except Exception as e:
msg = f"接口返回json错误 {chunk}’。错误信息是:{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
exc_info=e if log_verbose else None)
else:
# print(chunk, end="", flush=True)
yield chunk
@ -156,26 +159,28 @@ class ApiRequest:
except Exception as e:
msg = f"API通信遇到错误{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
exc_info=e if log_verbose else None)
yield {"code": 500, "msg": msg}
def ret_sync(response, as_json):
try:
with response as r:
for chunk in r.iter_text(None):
if not chunk: # fastchat api yield empty bytes on start and end
if not chunk: # fastchat api yield empty bytes on start and end
continue
if as_json:
try:
if chunk.startswith("data: "):
data = json.loads(chunk[6:-2])
elif chunk.startswith(":"): # skip sse comment line
continue
else:
data = json.loads(chunk)
yield data
except Exception as e:
msg = f"接口返回json错误 {chunk}’。错误信息是:{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
exc_info=e if log_verbose else None)
else:
# print(chunk, end="", flush=True)
yield chunk
@ -190,7 +195,7 @@ class ApiRequest:
except Exception as e:
msg = f"API通信遇到错误{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
exc_info=e if log_verbose else None)
yield {"code": 500, "msg": msg}
if self._use_async:
@ -199,16 +204,17 @@ class ApiRequest:
return ret_sync(response, as_json)
def _get_response_value(
self,
response: httpx.Response,
as_json: bool = False,
value_func: Callable = None,
self,
response: httpx.Response,
as_json: bool = False,
value_func: Callable = None,
):
'''
转换同步或异步请求返回的响应
`as_json`: 返回json
`value_func`: 用户可以自定义返回值该函数接受response或json
'''
def to_json(r):
try:
return r.json()
@ -216,7 +222,7 @@ class ApiRequest:
msg = "API未能返回正确的JSON。" + str(e)
if log_verbose:
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
exc_info=e if log_verbose else None)
return {"code": 500, "msg": msg, "data": None}
if value_func is None:
@ -246,10 +252,10 @@ class ApiRequest:
return self._get_response_value(response, as_json=True, value_func=lambda r: r["data"])
def get_prompt_template(
self,
type: str = "llm_chat",
name: str = "default",
**kwargs,
self,
type: str = "llm_chat",
name: str = "default",
**kwargs,
) -> str:
data = {
"type": type,
@ -293,15 +299,19 @@ class ApiRequest:
response = self.post("/chat/chat", json=data, stream=True, **kwargs)
return self._httpx_stream2generator(response, as_json=True)
@deprecated(
since="0.3.0",
message="自定义Agent问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
removal="0.3.0")
def agent_chat(
self,
query: str,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
self,
query: str,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
):
'''
对应api.py/chat/agent_chat 接口
@ -320,20 +330,20 @@ class ApiRequest:
# pprint(data)
response = self.post("/chat/agent_chat", json=data, stream=True)
return self._httpx_stream2generator(response)
return self._httpx_stream2generator(response, as_json=True)
def knowledge_base_chat(
self,
query: str,
knowledge_base_name: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
self,
query: str,
knowledge_base_name: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
):
'''
对应api.py/chat/knowledge_base_chat接口
@ -362,28 +372,29 @@ class ApiRequest:
return self._httpx_stream2generator(response, as_json=True)
def upload_temp_docs(
self,
files: List[Union[str, Path, bytes]],
knowledge_id: str = None,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
self,
files: List[Union[str, Path, bytes]],
knowledge_id: str = None,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
):
'''
对应api.py/knowledge_base/upload_temp_docs接口
'''
def convert_file(file, filename=None):
if isinstance(file, bytes): # raw bytes
if isinstance(file, bytes): # raw bytes
file = BytesIO(file)
elif hasattr(file, "read"): # a file io like object
elif hasattr(file, "read"): # a file io like object
filename = filename or file.name
else: # a local path
else: # a local path
file = Path(file).absolute().open("rb")
filename = filename or os.path.split(file.name)[-1]
return filename, file
files = [convert_file(file) for file in files]
data={
data = {
"knowledge_id": knowledge_id,
"chunk_size": chunk_size,
"chunk_overlap": chunk_overlap,
@ -398,17 +409,17 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)
def file_chat(
self,
query: str,
knowledge_id: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
self,
query: str,
knowledge_id: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
):
'''
对应api.py/chat/file_chat接口
@ -436,18 +447,23 @@ class ApiRequest:
)
return self._httpx_stream2generator(response, as_json=True)
@deprecated(
since="0.3.0",
message="搜索引擎问答将于 Langchain-Chatchat 0.3.x重写, 0.2.x中相关功能将废弃",
removal="0.3.0"
)
def search_engine_chat(
self,
query: str,
search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
split_result: bool = False,
self,
query: str,
search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
max_tokens: int = None,
prompt_name: str = "default",
split_result: bool = False,
):
'''
对应api.py/chat/search_engine_chat接口
@ -478,7 +494,7 @@ class ApiRequest:
# 知识库相关操作
def list_knowledge_bases(
self,
self,
):
'''
对应api.py/knowledge_base/list_knowledge_bases接口
@ -489,10 +505,10 @@ class ApiRequest:
value_func=lambda r: r.get("data", []))
def create_knowledge_base(
self,
knowledge_base_name: str,
vector_store_type: str = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
self,
knowledge_base_name: str,
vector_store_type: str = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
):
'''
对应api.py/knowledge_base/create_knowledge_base接口
@ -510,8 +526,8 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)
def delete_knowledge_base(
self,
knowledge_base_name: str,
self,
knowledge_base_name: str,
):
'''
对应api.py/knowledge_base/delete_knowledge_base接口
@ -523,8 +539,8 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)
def list_kb_docs(
self,
knowledge_base_name: str,
self,
knowledge_base_name: str,
):
'''
对应api.py/knowledge_base/list_files接口
@ -538,13 +554,13 @@ class ApiRequest:
value_func=lambda r: r.get("data", []))
def search_kb_docs(
self,
knowledge_base_name: str,
query: str = "",
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: int = SCORE_THRESHOLD,
file_name: str = "",
metadata: dict = {},
self,
knowledge_base_name: str,
query: str = "",
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: int = SCORE_THRESHOLD,
file_name: str = "",
metadata: dict = {},
) -> List:
'''
对应api.py/knowledge_base/search_docs接口
@ -565,9 +581,9 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)
def update_docs_by_id(
self,
knowledge_base_name: str,
docs: Dict[str, Dict],
self,
knowledge_base_name: str,
docs: Dict[str, Dict],
) -> bool:
'''
对应api.py/knowledge_base/update_docs_by_id接口
@ -583,32 +599,33 @@ class ApiRequest:
return self._get_response_value(response)
def upload_kb_docs(
self,
files: List[Union[str, Path, bytes]],
knowledge_base_name: str,
override: bool = False,
to_vector_store: bool = True,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
docs: Dict = {},
not_refresh_vs_cache: bool = False,
self,
files: List[Union[str, Path, bytes]],
knowledge_base_name: str,
override: bool = False,
to_vector_store: bool = True,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
docs: Dict = {},
not_refresh_vs_cache: bool = False,
):
'''
对应api.py/knowledge_base/upload_docs接口
'''
def convert_file(file, filename=None):
if isinstance(file, bytes): # raw bytes
if isinstance(file, bytes): # raw bytes
file = BytesIO(file)
elif hasattr(file, "read"): # a file io like object
elif hasattr(file, "read"): # a file io like object
filename = filename or file.name
else: # a local path
else: # a local path
file = Path(file).absolute().open("rb")
filename = filename or os.path.split(file.name)[-1]
return filename, file
files = [convert_file(file) for file in files]
data={
data = {
"knowledge_base_name": knowledge_base_name,
"override": override,
"to_vector_store": to_vector_store,
@ -629,11 +646,11 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)
def delete_kb_docs(
self,
knowledge_base_name: str,
file_names: List[str],
delete_content: bool = False,
not_refresh_vs_cache: bool = False,
self,
knowledge_base_name: str,
file_names: List[str],
delete_content: bool = False,
not_refresh_vs_cache: bool = False,
):
'''
对应api.py/knowledge_base/delete_docs接口
@ -651,8 +668,7 @@ class ApiRequest:
)
return self._get_response_value(response, as_json=True)
def update_kb_info(self,knowledge_base_name,kb_info):
def update_kb_info(self, knowledge_base_name, kb_info):
'''
对应api.py/knowledge_base/update_info接口
'''
@ -668,15 +684,15 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)
def update_kb_docs(
self,
knowledge_base_name: str,
file_names: List[str],
override_custom_docs: bool = False,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
docs: Dict = {},
not_refresh_vs_cache: bool = False,
self,
knowledge_base_name: str,
file_names: List[str],
override_custom_docs: bool = False,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
docs: Dict = {},
not_refresh_vs_cache: bool = False,
):
'''
对应api.py/knowledge_base/update_docs接口
@ -702,14 +718,14 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)
def recreate_vector_store(
self,
knowledge_base_name: str,
allow_empty_kb: bool = True,
vs_type: str = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
self,
knowledge_base_name: str,
allow_empty_kb: bool = True,
vs_type: str = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
chunk_size=CHUNK_SIZE,
chunk_overlap=OVERLAP_SIZE,
zh_title_enhance=ZH_TITLE_ENHANCE,
):
'''
对应api.py/knowledge_base/recreate_vector_store接口
@ -734,8 +750,8 @@ class ApiRequest:
# LLM模型相关操作
def list_running_models(
self,
controller_address: str = None,
self,
controller_address: str = None,
):
'''
获取Fastchat中正运行的模型列表
@ -751,8 +767,7 @@ class ApiRequest:
"/llm_model/list_running_models",
json=data,
)
return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", []))
return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", []))
def get_default_llm_model(self, local_first: bool = True) -> Tuple[str, bool]:
'''
@ -760,6 +775,7 @@ class ApiRequest:
local_first=True 优先返回运行中的本地模型否则优先按LLM_MODELS配置顺序返回
返回类型为model_name, is_local_model
'''
def ret_sync():
running_models = self.list_running_models()
if not running_models:
@ -776,7 +792,7 @@ class ApiRequest:
model = m
break
if not model: # LLM_MODELS中配置的模型都不在running_models里
if not model: # LLM_MODELS中配置的模型都不在running_models里
model = list(running_models)[0]
is_local = not running_models[model].get("online_api")
return model, is_local
@ -797,7 +813,7 @@ class ApiRequest:
model = m
break
if not model: # LLM_MODELS中配置的模型都不在running_models里
if not model: # LLM_MODELS中配置的模型都不在running_models里
model = list(running_models)[0]
is_local = not running_models[model].get("online_api")
return model, is_local
@ -808,8 +824,8 @@ class ApiRequest:
return ret_sync()
def list_config_models(
self,
types: List[str] = ["local", "online"],
self,
types: List[str] = ["local", "online"],
) -> Dict[str, Dict]:
'''
获取服务器configs中配置的模型列表返回形式为{"type": {model_name: config}, ...}
@ -821,23 +837,23 @@ class ApiRequest:
"/llm_model/list_config_models",
json=data,
)
return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))
return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))
def get_model_config(
self,
model_name: str = None,
self,
model_name: str = None,
) -> Dict:
'''
获取服务器上模型配置
'''
data={
data = {
"model_name": model_name,
}
response = self.post(
"/llm_model/get_model_config",
json=data,
)
return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))
return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))
def list_search_engines(self) -> List[str]:
'''
@ -846,12 +862,12 @@ class ApiRequest:
response = self.post(
"/server/list_search_engines",
)
return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {}))
return self._get_response_value(response, as_json=True, value_func=lambda r: r.get("data", {}))
def stop_llm_model(
self,
model_name: str,
controller_address: str = None,
self,
model_name: str,
controller_address: str = None,
):
'''
停止某个LLM模型
@ -869,10 +885,10 @@ class ApiRequest:
return self._get_response_value(response, as_json=True)
def change_llm_model(
self,
model_name: str,
new_model_name: str,
controller_address: str = None,
self,
model_name: str,
new_model_name: str,
controller_address: str = None,
):
'''
向fastchat controller请求切换LLM模型
@ -955,10 +971,10 @@ class ApiRequest:
return ret_sync()
def embed_texts(
self,
texts: List[str],
embed_model: str = EMBEDDING_MODEL,
to_query: bool = False,
self,
texts: List[str],
embed_model: str = EMBEDDING_MODEL,
to_query: bool = False,
) -> List[List[float]]:
'''
对文本进行向量化可选模型包括本地 embed_models 和支持 embeddings 的在线模型
@ -975,10 +991,10 @@ class ApiRequest:
return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data"))
def chat_feedback(
self,
message_id: str,
score: int,
reason: str = "",
self,
message_id: str,
score: int,
reason: str = "",
) -> int:
'''
反馈对话评价
@ -1015,9 +1031,9 @@ def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
return error message if error occured when requests API
'''
if (isinstance(data, dict)
and key in data
and "code" in data
and data["code"] == 200):
and key in data
and "code" in data
and data["code"] == 200):
return data[key]
return ""