Release: v0.2.5 (#1620)

* Optimize configs (#1474)

* remove llm_model_dict

* optimize configs

* fix get_model_path

* Change some default parameters; add a default configuration for Qianfan

* Update server_config.py.example

* fix merge conflict for #1474 (#1494)

* Fix incorrect ChatGPT api_base_url; users can now override the default api_base_url in the online-model section of model_config (#1496)

* Improve how the LLM model list is fetched and switched: (#1497)

1. Fetch available (not yet running) models more accurately
2. Improve the WebUI logic for displaying and switching the model list

* Update migrate.py and init_database.py to strengthen the knowledge-base migration tools (#1498); a usage sketch follows this list:

1. Add --update-in-db: rebuild vector stores from local files according to the records in the database
2. Add --increament: incrementally update vector stores from local files
3. Add --prune-db: after local files have been deleted, automatically clean up the corresponding vector stores
4. Add --prune-folder: clean up unused local files according to the database records
5. Remove --update-info-only; vector-store info is already kept in the database, so the operation added little value
6. Add --kb-name: every operation can target the specified knowledge bases; if omitted, all local knowledge bases are processed
7. Add test cases for knowledge-base migration
8. Remove the save_vector_store method from milvus_kb_service
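
A minimal sketch of how these flags map onto the migration helpers, based on the function names visible in the init_database.py diff further below (`folder2db`, `prune_db_docs`, `prune_folder_files`); the knowledge-base name "samples" is a placeholder:

```python
# Sketch only: mirrors what init_database.py does for each new flag.
# Function names are taken from server/knowledge_base/migrate.py as imported in the diff below;
# "samples" is a placeholder knowledge-base name.
from server.knowledge_base.migrate import (
    create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files,
)

create_tables()                                        # make sure the info tables exist

reset_tables()                                         # --recreate-vs first resets the tables...
folder2db(kb_names=["samples"], mode="recreate_vs")    # ...then rebuilds the vector stores from local files
folder2db(kb_names=["samples"], mode="update_in_db")   # --update-in-db: refresh vectors for files already in db
folder2db(kb_names=["samples"], mode="increament")     # --increament: add vectors only for new local files

prune_db_docs(["samples"])        # --prune-db: drop db docs whose local files were deleted
prune_folder_files(["samples"])   # --prune-folder: delete local files no longer tracked in db
```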

* feat: support volc fangzhou

* Make Volcano Ark (fangzhou) work properly; add error handling and test cases

* feat: support volc fangzhou (#1501)

* feat: support volc fangzhou

---------

Co-authored-by: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com>
Co-authored-by: liqiankun.1111 <liqiankun.1111@bytedance.com>

* First preliminary agent implementation (#1503)

* First preliminary agent implementation

* Add streaming parameter

* Modify weather.py

---------

Co-authored-by: zR <zRzRzRzRzRzRzR>

* Add configs/prompt_config.py so users can customize prompt templates: (#1504)

1. Two templates are included by default: one for plain LLM chat, one for knowledge-base and search-engine chat
2. server/utils.py provides get_prompt_template to fetch a named prompt template (hot reload supported)
3. The chat / knowledge_base_chat / search_engine_chat endpoints in api.py accept a prompt_name parameter (see the sketch below)
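
A hedged sketch of how the new `prompt_name` parameter might be passed to the knowledge-base chat endpoint; only `prompt_name` is taken from this changelog entry, while the other request fields, host, and port are assumptions for illustration:

```python
# Sketch only: the prompt_name parameter is the one added in this commit;
# the other field names ("query", "knowledge_base_name", "history") and the
# API address are assumptions, not taken from this diff.
import httpx

resp = httpx.post(
    "http://127.0.0.1:7861/chat/knowledge_base_chat",   # assumed default API server address
    json={
        "query": "本项目支持哪些向量库?",
        "knowledge_base_name": "samples",
        "history": [],
        "prompt_name": "knowledge_base_chat",  # selects a template from configs/prompt_config.py
    },
    timeout=60.0,
)
print(resp.text)
```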

* Add parameter adaptation for other models

* Support loading by passing in a vector-store name

* 1. Search-engine chat now supports history;
2. Fix a wrong history argument in knowledge-base chat: the user input was being appended to history. The bug was the WebUI fetching history twice; the knowledge-base chat API itself was fine.

* Add a switch for langchain logging

* move wrap_done & get_ChatOpenAI from server.chat.utils to server.utils (#1506)

* Fix wrong cache key in the faiss_pool knowledge-base cache (#1507)

* fix ReadMe anchor link (#1500)

* fix : Duplicate variable and function name (#1509)

Co-authored-by: Jim <zhangpengyi@taijihuabao.com>

* Update README.md

* fix #1519: work around an old streamlit-chatbox bug (the new version has compatibility issues); handle it in the WebUI for now and pin the chatbox version (#1525)

close #1519

* [New feature] Online LLM models now support Alibaba Cloud Tongyi Qianwen (#1534)

* feat: add qwen-api

* Make the Qwen API support the temperature parameter; add test cases

* List the online-API SDKs as optional dependencies

---------

Co-authored-by: liunux4odoo <liunux@qq.com>

* Handle the serialize-to-disk logic

* remove depends on volcengine

* update kb_doc_api: use Form instead of Body when upload file

* Switch all httpx requests to use a Client, improving efficiency and making it easier to set proxies later; see the sketch below (#1554)

Add the project's own services to the no-proxy list to avoid request errors from the fastchat server (has no effect on Windows).
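
A minimal sketch of the shared-client pattern this change introduces (the actual helper in server/utils.py is not shown in this diff); `HTTPX_DEFAULT_TIMEOUT` and the fastchat openai_api port 20000 come from the server_config diff below, while the NO_PROXY handling is illustrative:

```python
# Sketch only: one reusable httpx.Client instead of ad-hoc httpx requests.
# HTTPX_DEFAULT_TIMEOUT and port 20000 appear in the server_config diff below;
# the NO_PROXY environment variable is an illustrative way to keep local services
# out of any configured proxy (the commit notes this does not help on Windows).
import os
import httpx

HTTPX_DEFAULT_TIMEOUT = 300.0

os.environ["NO_PROXY"] = "localhost,127.0.0.1"          # keep local fastchat/api/webui traffic direct

client = httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT)    # connections are pooled and reused
models = client.get("http://127.0.0.1:20000/v1/models") # fastchat openai_api server
print(models.json())
```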

* update QR code

* update readme_en, readme, requirements_api, requirements, model_config.py.example: test baichuan2-7b; update related docs

* New features: 1. support the vllm inference acceleration framework; 2. update the supported-model list

* Update files: startup, model_config.py.example, server_config.py.example, FAQ

* 1. Finish debugging the vllm acceleration framework; 2. adjust the vllm dependency in requirements and requirements_api; 3. comment out the device: cpu setting for baichuan-7b in server_config

* 1. Update the notes about the vllm backend in the configs; 2. update requirements and requirements_api

* Add GPT-4-only agent features, to be extended over time; Chinese README written (#1611)

* Dev (#1613)

* Add GPT-4-only agent features, to be extended over time; Chinese README written

* Fix a bug reported in an issue

* Lower the minimum temperature to 0; negative values should not be allowed

* Adjust the minimum temperature

* fix: set vllm based on platform to avoid error on windows

* fix: langchain warnings for import from root

* Fix WebUI errors in knowledge-base rebuilding and the chat UI (#1615)

* Fix bug: when rebuilding a knowledge base from the WebUI, an unsupported file would break the whole endpoint; CHUNK_SIZE was not imported in migrate

* Fix: the expander in the WebUI chat page stayed in the running state; simplify how history messages are fetched

* Add the instruction template for the English bge embeddings, following the official docs (#1585)

Co-authored-by: zR <2448370773@qq.com>

* Dev (#1618)

* Add GPT-4-only agent features, to be extended over time; Chinese README written

* Fix a bug reported in an issue

* Lower the minimum temperature to 0; negative values should not be allowed

* Adjust the minimum temperature

* Add partial agent support and fix some bugs in the startup script

* Adjust the GPU-count configuration

* 1

1

* Fix configuration-file errors

* Update README; stability testing

* Update README 0928 (#1619)

* Add GPT-4-only agent features, to be extended over time; Chinese README written

* Fix a bug reported in an issue

* Lower the minimum temperature to 0; negative values should not be allowed

* Adjust the minimum temperature

* Add partial agent support and fix some bugs in the startup script

* Adjust the GPU-count configuration

* 1

1

* Fix configuration-file errors

* Update README; stability testing

* Update README

* fix readme

* Handle the serialize-to-disk logic

* update version number to v0.2.5

---------

Co-authored-by: qiankunli <qiankun.li@qq.com>
Co-authored-by: liqiankun.1111 <liqiankun.1111@bytedance.com>
Co-authored-by: zR <2448370773@qq.com>
Co-authored-by: glide-the <2533736852@qq.com>
Co-authored-by: Water Zheng <1499383852@qq.com>
Co-authored-by: Jim Zhang <dividi_z@163.com>
Co-authored-by: Jim <zhangpengyi@taijihuabao.com>
Co-authored-by: imClumsyPanda <littlepanda0716@gmail.com>
Co-authored-by: Leego <leegodev@hotmail.com>
Co-authored-by: hzg0601 <hzg0601@163.com>
Co-authored-by: WilliamChen-luckbob <58684828+WilliamChen-luckbob@users.noreply.github.com>
Committed by liunux4odoo on 2023-09-28 23:30:21 +08:00 (via GitHub)
parent db169f628c
commit ba8d0f8e17
69 changed files with 2941 additions and 813 deletions


@ -57,6 +57,25 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
---
## 环境最低要求
想顺利运行本代码,请按照以下的最低要求进行配置:
+ Python版本: >= 3.8.5, < 3.11
+ Cuda版本: >= 11.7, 且能顺利安装Python
如果想要顺利在GPU运行本地模型(int4版本),你至少需要以下的硬件配置:
+ chatglm2-6b & LLaMA-7B 最低显存要求: 7GB 推荐显卡: RTX 3060, RTX 2060
+ LLaMA-13B 最低显存要求: 11GB 推荐显卡: RTX 2060 12GB, RTX3060 12GB, RTX3080, RTXA2000
+ Qwen-14B-Chat 最低显存要求: 13GB 推荐显卡: RTX 3090
+ LLaMA-30B 最低显存要求: 22GB 推荐显卡RTX A5000,RTX 3090,RTX 4090,RTX 6000,Tesla V100,RTX Tesla P40
+ LLaMA-65B 最低显存要求: 40GB 推荐显卡A100,A40,A6000
如果是 int8 量化,则显存需求 x1.5;fp16 则 x2.5。
使用 fp16 推理 Qwen-7B-Chat 模型,则需要使用 16GB 显存。
以上仅为估算,实际情况以 nvidia-smi 占用为准。
## 变更日志
参见 [版本更新日志](https://github.com/imClumsyPanda/langchain-ChatGLM/releases)。
@ -112,27 +131,29 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
- [WizardLM/WizardCoder-15B-V1.0](https://huggingface.co/WizardLM/WizardCoder-15B-V1.0)
- [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
- [internlm/internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b)
- [Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat)
- [Qwen/Qwen-7B-Chat/Qwen-14B-Chat](https://huggingface.co/Qwen/)
- [HuggingFaceH4/starchat-beta](https://huggingface.co/HuggingFaceH4/starchat-beta)
- [FlagAlpha/Llama2-Chinese-13b-Chat](https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat) and others
- [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B)
- [all models of OpenOrca](https://huggingface.co/Open-Orca)
- [Spicyboros](https://huggingface.co/jondurbin/spicyboros-7b-2.2?not-for-all-audiences=true) + [airoboros 2.2](https://huggingface.co/jondurbin/airoboros-l2-13b-2.2)
- [VMware&#39;s OpenLLaMa OpenInstruct](https://huggingface.co/VMware/open-llama-7b-open-instruct)
- [baichuan2-7b/baichuan2-13b](https://huggingface.co/baichuan-inc)
- 任何 [EleutherAI](https://huggingface.co/EleutherAI) 的 pythia 模型,如 [pythia-6.9b](https://huggingface.co/EleutherAI/pythia-6.9b)
- 在以上模型基础上训练的任何 [Peft](https://github.com/huggingface/peft) 适配器。为了激活,模型路径中必须有 `peft` 。注意如果加载多个peft模型你可以通过在任何模型工作器中设置环境变量 `PEFT_SHARE_BASE_WEIGHTS=true` 来使它们共享基础模型的权重。
以上模型支持列表可能随 [FastChat](https://github.com/lm-sys/FastChat) 更新而持续更新,可参考 [FastChat 已支持模型列表](https://github.com/lm-sys/FastChat/blob/main/docs/model_support.md)。
除本地模型外,本项目也支持直接接入 OpenAI API、智谱AI等在线模型,具体设置可参考 `configs/model_configs.py.example` 中的 `llm_model_dict` 的配置信息。
在线 LLM 模型目前已支持:
在线 LLM 模型目前已支持:
- [ChatGPT](https://api.openai.com)
- [智谱AI](http://open.bigmodel.cn)
- [MiniMax](https://api.minimax.chat)
- [讯飞星火](https://xinghuo.xfyun.cn)
- [百度千帆](https://cloud.baidu.com/product/wenxinworkshop?track=dingbutonglan)
- [阿里云通义千问](https://dashscope.aliyun.com/)
项目中默认使用的 LLM 类型为 `THUDM/chatglm2-6b`,如需使用其他 LLM 类型,请在 [configs/model_config.py] 中对 `llm_model_dict` 和 `LLM_MODEL` 进行修改。
@ -157,9 +178,11 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
- [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese)
- [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh)
- [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh)
- [sensenova/piccolo-base-zh](https://huggingface.co/sensenova/piccolo-base-zh)
- [sensenova/piccolo-base-zh](https://huggingface.co/sensenova/piccolo-large-zh)
- [OpenAI/text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings)
项目中默认使用的 Embedding 类型为 `moka-ai/m3e-base`,如需使用其他 Embedding 类型,请在 [configs/model_config.py] 中对 `embedding_model_dict` 和 `EMBEDDING_MODEL` 进行修改。
项目中默认使用的 Embedding 类型为 `sensenova/piccolo-base-zh`,如需使用其他 Embedding 类型,请在 [configs/model_config.py] 中对 `embedding_model_dict` 和 `EMBEDDING_MODEL` 进行修改。
---
@ -187,15 +210,27 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
关于如何使用自定义分词器和贡献自己的分词器,可以参考[Text Splitter 贡献说明](docs/splitter.md)。
## Agent生态
### 基础的Agent
在本版本中,我们实现了一个简单的基于 OpenAI 的 React 的 Agent 模型,目前经过我们测试,仅有以下两个模型支持:
+ OpenAI GPT4
+ ChatGLM2-130B
目前版本的 Agent 仍然需要对提示词进行大量调试,调试位置:
### 构建自己的Agent工具
详见 [自定义Agent说明](docs/自定义Agent.md)
## Docker 部署
🐳 Docker 镜像地址: `registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.3)`
🐳 Docker 镜像地址: `registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5)`
```shell
docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.3
docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5
```
- 该版本镜像大小 `35.3GB`,使用 `v0.2.3`,以 `nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04` 为基础镜像
- 该版本镜像大小 `35.3GB`,使用 `v0.2.5`,以 `nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04` 为基础镜像
- 该版本内置两个 `embedding` 模型:`m3e-large`,`text2vec-bge-large-chinese`,默认启用后者,内置 `chatglm2-6b-32k`
- 该版本目标为方便一键部署使用,请确保您已经在Linux发行版上安装了NVIDIA驱动程序
- 请注意,您不需要在主机系统上安装CUDA工具包,但需要安装 `NVIDIA Driver` 以及 `NVIDIA Container Toolkit`,请参考[安装指南](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
@ -391,8 +426,8 @@ CUDA_VISIBLE_DEVICES=0,1 python startup.py -a
- [X] .csv
- [ ] .xlsx
- [ ] 分词及召回
- [ ] 接入不同类型 TextSplitter
- [ ] 优化依据中文标点符号设计的 ChineseTextSplitter
- [X] 接入不同类型 TextSplitter
- [X] 优化依据中文标点符号设计的 ChineseTextSplitter
- [ ] 重新实现上下文拼接召回
- [ ] 本地网页接入
- [ ] SQL 接入
@ -400,13 +435,17 @@ CUDA_VISIBLE_DEVICES=0,1 python startup.py -a
- [X] 搜索引擎接入
- [X] Bing 搜索
- [X] DuckDuckGo 搜索
- [ ] Agent 实现
- [X] Agent 实现
- [X] 基础React形式的Agent实现,包括调用计算器等
- [X] Langchain 自带的Agent实现和调用
- [ ] 更多模型的Agent支持
- [ ] 更多工具
- [X] LLM 模型接入
- [X] 支持通过调用 [FastChat](https://github.com/lm-sys/fastchat) api 调用 llm
- [ ] 支持 ChatGLM API 等 LLM API 的接入
- [X] 支持 ChatGLM API 等 LLM API 的接入
- [X] Embedding 模型接入
- [X] 支持调用 HuggingFace 中各开源 Emebdding 模型
- [ ] 支持 OpenAI Embedding API 等 Embedding API 的接入
- [X] 支持 OpenAI Embedding API 等 Embedding API 的接入
- [X] 基于 FastAPI 的 API 方式调用
- [X] Web UI
- [X] 基于 Streamlit 的 Web UI
@ -417,4 +456,12 @@ CUDA_VISIBLE_DEVICES=0,1 python startup.py -a
<img src="img/qr_code_64.jpg" alt="二维码" width="300" height="300" />
🎉 langchain-ChatGLM 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
🎉 langchain-Chatchat 项目微信交流群,如果你也对本项目感兴趣,欢迎加入群聊参与讨论交流。
## 关注我们
<img src="img/official_account.png" alt="图片" width="900" height="300" />
🎉 langchain-Chatchat 项目官方公众号,欢迎扫码关注。


@ -56,6 +56,25 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
---
## Environment Minimum Requirements
To run this code smoothly, please configure it according to the following minimum requirements:
+ Python version: >= 3.8.5, < 3.11
+ Cuda version: >= 11.7, with Python installed.
If you want to run a local model (int4 version) on a GPU without problems, you need at least the following hardware configuration:
+ chatglm2-6b & LLaMA-7B Minimum GPU memory requirement: 7GB Recommended graphics cards: RTX 3060, RTX 2060
+ LLaMA-13B Minimum GPU memory requirement: 11GB Recommended cards: RTX 2060 12GB, RTX3060 12GB, RTX3080, RTXA2000
+ Qwen-14B-Chat Minimum GPU memory requirement: 13GB Recommended graphics card: RTX 3090
+ LLaMA-30B Minimum GPU memory requirement: 22GB Recommended cards: RTX A5000, RTX 3090, RTX 4090, RTX 6000, Tesla V100, RTX Tesla P40
+ LLaMA-65B Minimum GPU memory requirement: 40GB Recommended cards: A100, A40, A6000
If int8 is used, multiply the memory requirement by 1.5; for fp16, by 2.5.
For example: running inference on the Qwen-7B-Chat model with fp16 requires 16GB of GPU memory.
The above are only estimates; check actual usage with nvidia-smi.
## Change Log
please refer to the [version change log](https://github.com/imClumsyPanda/langchain-ChatGLM/releases)
@ -105,18 +124,31 @@ The project use [FastChat](https://github.com/lm-sys/FastChat) to provide the AP
- [WizardLM/WizardCoder-15B-V1.0](https://huggingface.co/WizardLM/WizardCoder-15B-V1.0)
- [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
- [internlm/internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b)
- [Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat)
- [Qwen/Qwen-7B-Chat/Qwen-14B-Chat](https://huggingface.co/Qwen/)
- [HuggingFaceH4/starchat-beta](https://huggingface.co/HuggingFaceH4/starchat-beta)
- [FlagAlpha/Llama2-Chinese-13b-Chat](https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat) and other models of FlagAlpha
- [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B)
- [all models of OpenOrca](https://huggingface.co/Open-Orca)
- [Spicyboros](https://huggingface.co/jondurbin/spicyboros-7b-2.2?not-for-all-audiences=true) + [airoboros 2.2](https://huggingface.co/jondurbin/airoboros-l2-13b-2.2)
- [baichuan2-7b/baichuan2-13b](https://huggingface.co/baichuan-inc)
- [VMware&#39;s OpenLLaMa OpenInstruct](https://huggingface.co/VMware/open-llama-7b-open-instruct)
* Any [EleutherAI](https://huggingface.co/EleutherAI) pythia model such as [pythia-6.9b](https://huggingface.co/EleutherAI/pythia-6.9b)(任何 [EleutherAI](https://huggingface.co/EleutherAI) 的 pythia 模型,如 [pythia-6.9b](https://huggingface.co/EleutherAI/pythia-6.9b))
* Any [Peft](https://github.com/huggingface/peft) adapter trained on top of a model above. To activate, must have `peft` in the model path. Note: If loading multiple peft models, you can have them share the base model weights by setting the environment variable `PEFT_SHARE_BASE_WEIGHTS=true` in any model worker.
Please refer to `llm_model_dict` in `configs.model_configs.py.example` to invoke OpenAI API.
The above model support list may be updated continuously as [FastChat](https://github.com/lm-sys/FastChat) is updated; see the [FastChat Supported Models List](https://github.com/lm-sys/FastChat/blob/main/docs/model_support.md).
In addition to local models, this project also supports direct access to online models such as the OpenAI API, Zhipu AI, etc. For specific settings, please refer to the configuration information of `llm_model_dict` in `configs/model_configs.py.example`.
Online LLM models are currently supported:
- [ChatGPT](https://api.openai.com)
- [Zhipu AI](http://open.bigmodel.cn)
- [MiniMax](https://api.minimax.chat)
- [iFlytek Spark (Xinghuo)](https://xinghuo.xfyun.cn)
- [Baidu Qianfan](https://cloud.baidu.com/product/wenxinworkshop?track=dingbutonglan)
- [Alibaba Cloud Tongyi Qianwen](https://dashscope.aliyun.com/)
The default LLM type used in the project is `THUDM/chatglm2-6b`, if you need to use other LLM types, please modify `llm_model_dict` and `LLM_MODEL` in [configs/model_config.py].
### Supported Embedding models
@ -129,6 +161,8 @@ Following models are tested by developers with Embedding class of [HuggingFace](
- [BAAI/bge-base-zh](https://huggingface.co/BAAI/bge-base-zh)
- [BAAI/bge-large-zh](https://huggingface.co/BAAI/bge-large-zh)
- [BAAI/bge-large-zh-noinstruct](https://huggingface.co/BAAI/bge-large-zh-noinstruct)
- [sensenova/piccolo-base-zh](https://huggingface.co/sensenova/piccolo-base-zh)
- [sensenova/piccolo-large-zh](https://huggingface.co/sensenova/piccolo-large-zh)
- [shibing624/text2vec-base-chinese-sentence](https://huggingface.co/shibing624/text2vec-base-chinese-sentence)
- [shibing624/text2vec-base-chinese-paraphrase](https://huggingface.co/shibing624/text2vec-base-chinese-paraphrase)
- [shibing624/text2vec-base-multilingual](https://huggingface.co/shibing624/text2vec-base-multilingual)
@ -137,16 +171,24 @@ Following models are tested by developers with Embedding class of [HuggingFace](
- [GanymedeNil/text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese)
- [nghuyong/ernie-3.0-nano-zh](https://huggingface.co/nghuyong/ernie-3.0-nano-zh)
- [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh)
- [sensenova/piccolo-base-zh](https://huggingface.co/sensenova/piccolo-base-zh)
- [sensenova/piccolo-base-zh](https://huggingface.co/sensenova/piccolo-large-zh)
- [OpenAI/text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings)
The default Embedding type used in the project is `sensenova/piccolo-base-zh`. If you want to use another Embedding type, please modify `embedding_model_dict` and `EMBEDDING_MODEL` in [configs/model_config.py].
### Build your own Agent tool!
See [Custom Agent Instructions](docs/自定义Agent.md) for details.
---
## Docker Deployment
🐳 Docker image path: `registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.0)`
🐳 Docker image path: `registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5)`
```shell
docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.0
docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.5
```
- The image size of this version is `33.9GB`, using `v0.2.0`, with `nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04` as the base image
@ -328,9 +370,9 @@ Please refer to [FAQ](docs/FAQ.md)
- [ ] Structured documents
- [X] .csv
- [ ] .xlsx
- [ ] TextSplitter and Retriever
- [x] multiple TextSplitter
- [x] ChineseTextSplitter
- [ ] TextSplitter and Retriever
- [X] multiple TextSplitter
- [X] ChineseTextSplitter
- [ ] Reconstructed Context Retriever
- [ ] Webpage
- [ ] SQL
@ -338,7 +380,11 @@ Please refer to [FAQ](docs/FAQ.md)
- [X] Search Engines
- [X] Bing
- [X] DuckDuckGo
- [ ] Agent
- [X] Agent
- [X] Agent implementation in the form of basic React, including calls to calculators, etc.
- [X] Langchain's own Agent implementation and calls
- [ ] More Agent support for models
- [ ] More tools
- [X] LLM Models
- [X] [FastChat](https://github.com/lm-sys/fastchat) -based LLM Models
- [ ] Mutiply Remote LLM API
@ -348,3 +394,16 @@ Please refer to [FAQ](docs/FAQ.md)
- [X] FastAPI-based API
- [X] Web UI
- [X] Streamlit -based Web UI
---
## Wechat Group
<img src="img/qr_code_64.jpg" alt="QR Code" width="300" height="300" />
🎉 langchain-Chatchat project WeChat exchange group, if you are also interested in this project, welcome to join the group chat to participate in the discussion and exchange.
## Follow us
<img src="img/official_account.png" alt="image" width="900" height="300" />
🎉 Official WeChat account of the langchain-Chatchat project; scan the code to follow us.


@ -1,19 +1,12 @@
from langchain.chat_models import ChatOpenAI
from configs.model_config import llm_model_dict, LLM_MODEL
from langchain import LLMChain
from server.utils import get_ChatOpenAI
from configs.model_config import LLM_MODEL, TEMPERATURE
from langchain.chains import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
)
model = ChatOpenAI(
streaming=True,
verbose=True,
# callbacks=[callback],
openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL
)
model = get_ChatOpenAI(model_name=LLM_MODEL, temperature=TEMPERATURE)
human_prompt = "{input}"


@ -1,4 +1,8 @@
from .basic_config import *
from .model_config import *
from .kb_config import *
from .server_config import *
from .prompt_config import *
VERSION = "v0.2.4"
VERSION = "v0.2.5"


@ -0,0 +1,22 @@
import logging
import os
import langchain
# 是否显示详细日志
log_verbose = False
langchain.verbose = False
# 通常情况下不需要更改以下内容
# 日志格式
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT)
# 日志存储路径
LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
if not os.path.exists(LOG_PATH):
os.mkdir(LOG_PATH)


@ -0,0 +1,99 @@
import os
# 默认向量库类型。可选faiss, milvus, pg.
DEFAULT_VS_TYPE = "faiss"
# 缓存向量库数量针对FAISS
CACHED_VS_NUM = 1
# 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter)
CHUNK_SIZE = 250
# 知识库中相邻文本重合长度(不适用MarkdownHeaderTextSplitter)
OVERLAP_SIZE = 50
# 知识库匹配向量数量
VECTOR_SEARCH_TOP_K = 3
# 知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右
SCORE_THRESHOLD = 1
# 搜索引擎匹配结果数量
SEARCH_ENGINE_TOP_K = 3
# Bing 搜索必备变量
# 使用 Bing 搜索需要使用 Bing Subscription Key,需要在 Azure Portal 中申请试用 Bing Search
# 具体申请方式请见
# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/create-bing-search-service-resource
# 使用python创建bing api 搜索实例详见:
# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/quickstarts/rest/python
BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
# 注意不是bing Webmaster Tools的api key
# 此外如果是在服务器上报Failed to establish a new connection: [Errno 110] Connection timed out
# 是因为服务器加了防火墙需要联系管理员加白名单如果公司的服务器的话就别想了GG
BING_SUBSCRIPTION_KEY = ""
# 是否开启中文标题加强,以及标题增强的相关配置
# 通过增加标题判断判断哪些文本为标题并在metadata中进行标记
# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
ZH_TITLE_ENHANCE = False
# 通常情况下不需要更改以下内容
# 知识库默认存储路径
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
if not os.path.exists(KB_ROOT_PATH):
os.mkdir(KB_ROOT_PATH)
# 数据库默认存储路径。
# 如果使用sqlite可以直接修改DB_ROOT_PATH如果使用其它数据库请直接修改SQLALCHEMY_DATABASE_URI。
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"
# 可选向量库类型及对应配置
kbs_config = {
"faiss": {
},
"milvus": {
"host": "127.0.0.1",
"port": "19530",
"user": "",
"password": "",
"secure": False,
},
"pg": {
"connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
}
}
# TextSplitter配置项如果你不明白其中的含义就不要修改。
text_splitter_dict = {
"ChineseRecursiveTextSplitter": {
"source": "huggingface", ## 选择tiktoken则使用openai的方法
"tokenizer_name_or_path": "gpt2",
},
"SpacyTextSplitter": {
"source": "huggingface",
"tokenizer_name_or_path": "",
},
"RecursiveCharacterTextSplitter": {
"source": "tiktoken",
"tokenizer_name_or_path": "cl100k_base",
},
"MarkdownHeaderTextSplitter": {
"headers_to_split_on":
[
("#", "head1"),
("##", "head2"),
("###", "head3"),
("####", "head4"),
]
},
}
# TEXT_SPLITTER 名称
TEXT_SPLITTER_NAME = "ChineseRecursiveTextSplitter"
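
A hedged sketch of how the tiktoken-based entry above could be consumed with langchain's stock splitter; the project's own splitter-loading code is not part of this diff, so this is illustrative only:

```python
# Sketch only: consuming the "RecursiveCharacterTextSplitter" entry above via
# langchain's stock splitter. CHUNK_SIZE / OVERLAP_SIZE are the kb_config values;
# the project's actual splitter factory is not shown in this diff.
from langchain.text_splitter import RecursiveCharacterTextSplitter

CHUNK_SIZE, OVERLAP_SIZE = 250, 50

splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    encoding_name="cl100k_base",   # the tokenizer_name_or_path configured for the tiktoken source
    chunk_size=CHUNK_SIZE,
    chunk_overlap=OVERLAP_SIZE,
)
chunks = splitter.split_text("知识库中的一段较长文本,会被按 token 数切分成多个片段。")
```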


@ -1,63 +1,115 @@
import os
import logging
# 日志格式
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT)
# 是否显示详细日志
log_verbose = False
# 在以下字典中修改属性值以指定本地embedding模型存储位置
# 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese"
# 此处请写绝对路径
embedding_model_dict = {
"ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
"ernie-base": "nghuyong/ernie-3.0-base-zh",
"text2vec-base": "shibing624/text2vec-base-chinese",
"text2vec": "GanymedeNil/text2vec-large-chinese",
"text2vec-paraphrase": "shibing624/text2vec-base-chinese-paraphrase",
"text2vec-sentence": "shibing624/text2vec-base-chinese-sentence",
"text2vec-multilingual": "shibing624/text2vec-base-multilingual",
"text2vec-bge-large-chinese": "shibing624/text2vec-bge-large-chinese",
"m3e-small": "moka-ai/m3e-small",
"m3e-base": "moka-ai/m3e-base",
"m3e-large": "moka-ai/m3e-large",
"bge-small-zh": "BAAI/bge-small-zh",
"bge-base-zh": "BAAI/bge-base-zh",
"bge-large-zh": "BAAI/bge-large-zh",
"bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
"piccolo-base-zh": "sensenova/piccolo-base-zh",
"piccolo-large-zh": "sensenova/piccolo-large-zh",
"text-embedding-ada-002": os.environ.get("OPENAI_API_KEY")
# 可以指定一个绝对路径统一存放所有的Embedding和LLM模型。
# 每个模型可以是一个单独的目录,也可以是某个目录下的二级子目录
MODEL_ROOT_PATH = ""
# 在以下字典中修改属性值以指定本地embedding模型存储位置。支持3种设置方法
# 1、将对应的值修改为模型绝对路径
# 2、不修改此处的值以 text2vec 为例):
# 2.1 如果{MODEL_ROOT_PATH}下存在如下任一子目录:
# - text2vec
# - GanymedeNil/text2vec-large-chinese
# - text2vec-large-chinese
# 2.2 如果以上本地路径不存在则使用huggingface模型
MODEL_PATH = {
"embed_model": {
"ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
"ernie-base": "nghuyong/ernie-3.0-base-zh",
"text2vec-base": "shibing624/text2vec-base-chinese",
"text2vec": "GanymedeNil/text2vec-large-chinese",
"text2vec-paraphrase": "shibing624/text2vec-base-chinese-paraphrase",
"text2vec-sentence": "shibing624/text2vec-base-chinese-sentence",
"text2vec-multilingual": "shibing624/text2vec-base-multilingual",
"text2vec-bge-large-chinese": "shibing624/text2vec-bge-large-chinese",
"m3e-small": "moka-ai/m3e-small",
"m3e-base": "moka-ai/m3e-base",
"m3e-large": "moka-ai/m3e-large",
"bge-small-zh": "BAAI/bge-small-zh",
"bge-base-zh": "BAAI/bge-base-zh",
"bge-large-zh": "BAAI/bge-large-zh",
"bge-large-zh-noinstruct": "BAAI/bge-large-zh-noinstruct",
"piccolo-base-zh": "sensenova/piccolo-base-zh",
"piccolo-large-zh": "sensenova/piccolo-large-zh",
"text-embedding-ada-002": "your OPENAI_API_KEY",
},
# TODO: add all supported llm models
"llm_model": {
# 以下部分模型并未完全测试仅根据fastchat和vllm模型的模型列表推定支持
"chatglm-6b": "THUDM/chatglm-6b",
"chatglm2-6b": "THUDM/chatglm2-6b",
"chatglm2-6b-int4": "THUDM/chatglm2-6b-int4",
"chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
"baichuan2-13b": "baichuan-inc/Baichuan-13B-Chat",
"baichuan2-7b":"baichuan-inc/Baichuan2-7B-Chat",
"baichuan-7b": "baichuan-inc/Baichuan-7B",
"baichuan-13b": "baichuan-inc/Baichuan-13B",
'baichuan-13b-chat':'baichuan-inc/Baichuan-13B-Chat',
"aquila-7b":"BAAI/Aquila-7B",
"aquilachat-7b":"BAAI/AquilaChat-7B",
"internlm-7b":"internlm/internlm-7b",
"internlm-chat-7b":"internlm/internlm-chat-7b",
"falcon-7b":"tiiuae/falcon-7b",
"falcon-40b":"tiiuae/falcon-40b",
"falcon-rw-7b":"tiiuae/falcon-rw-7b",
"gpt2":"gpt2",
"gpt2-xl":"gpt2-xl",
"gpt-j-6b":"EleutherAI/gpt-j-6b",
"gpt4all-j":"nomic-ai/gpt4all-j",
"gpt-neox-20b":"EleutherAI/gpt-neox-20b",
"pythia-12b":"EleutherAI/pythia-12b",
"oasst-sft-4-pythia-12b-epoch-3.5":"OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
"dolly-v2-12b":"databricks/dolly-v2-12b",
"stablelm-tuned-alpha-7b":"stabilityai/stablelm-tuned-alpha-7b",
"Llama-2-13b-hf":"meta-llama/Llama-2-13b-hf",
"Llama-2-70b-hf":"meta-llama/Llama-2-70b-hf",
"open_llama_13b":"openlm-research/open_llama_13b",
"vicuna-13b-v1.3":"lmsys/vicuna-13b-v1.3",
"koala":"young-geng/koala",
"mpt-7b":"mosaicml/mpt-7b",
"mpt-7b-storywriter":"mosaicml/mpt-7b-storywriter",
"mpt-30b":"mosaicml/mpt-30b",
"opt-66b":"facebook/opt-66b",
"opt-iml-max-30b":"facebook/opt-iml-max-30b",
"Qwen-7B":"Qwen/Qwen-7B",
"Qwen-14B":"Qwen/Qwen-14B",
"Qwen-7B-Chat":"Qwen/Qwen-7B-Chat",
"Qwen-14B-Chat":"Qwen/Qwen-14B-Chat",
},
}
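
A hedged sketch of the lookup order described in the comments above MODEL_PATH; the project's real resolver (get_model_path in server/utils.py, mentioned in the changelog) is not part of this diff:

```python
# Sketch only: the lookup order described by the comments above.
# The real implementation is server.utils.get_model_path, not shown in this diff.
import os

def resolve_model_path(name: str) -> str:
    paths = {**MODEL_PATH["embed_model"], **MODEL_PATH["llm_model"]}
    value = paths.get(name, name)
    if os.path.isdir(value):                      # 1. the configured value is already a local path
        return value
    if MODEL_ROOT_PATH:
        for candidate in (name, value, value.split("/")[-1]):
            path = os.path.join(MODEL_ROOT_PATH, candidate)
            if os.path.isdir(path):               # 2. a sub-directory of MODEL_ROOT_PATH
                return path
    return value                                  # 3. fall back to the huggingface repo id
```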
# 选用的 Embedding 名称
EMBEDDING_MODEL = "m3e-base"
EMBEDDING_MODEL = "m3e-base" # 可以尝试最新的嵌入式sota模型piccolo-large-zh
# Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
EMBEDDING_DEVICE = "auto"
llm_model_dict = {
"chatglm-6b": {
"local_model_path": "THUDM/chatglm-6b",
"api_base_url": "http://localhost:8888/v1", # "name"修改为fastchat服务中的"api_base_url"
"api_key": "EMPTY"
},
# LLM 名称
LLM_MODEL = "chatglm2-6b"
"chatglm2-6b": {
"local_model_path": "THUDM/chatglm2-6b",
"api_base_url": "http://localhost:8888/v1", # URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致
"api_key": "EMPTY"
},
# LLM 运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
LLM_DEVICE = "auto"
"chatglm2-6b-32k": {
"local_model_path": "THUDM/chatglm2-6b-32k", # "THUDM/chatglm2-6b-32k",
"api_base_url": "http://localhost:8888/v1", # "URL需要与运行fastchat服务端的server_config.FSCHAT_OPENAI_API一致
"api_key": "EMPTY"
},
# 历史对话轮数
HISTORY_LEN = 3
# LLM通用对话参数
TEMPERATURE = 0.7
# TOP_P = 0.95 # ChatOpenAI暂不支持该参数
ONLINE_LLM_MODEL = {
# 调用chatgpt时如果报出 urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='api.openai.com', port=443):
# Max retries exceeded with url: /v1/chat/completions
# 则需要将urllib3版本修改为1.25.11
@ -75,28 +127,25 @@ llm_model_dict = {
# 比如: "openai_proxy": 'http://127.0.0.1:4780'
"gpt-3.5-turbo": {
"api_base_url": "https://api.openai.com/v1",
"api_key": "",
"openai_proxy": ""
"api_key": "your OPENAI_API_KEY",
"openai_proxy": "your OPENAI_PROXY",
},
# 线上模型。当前支持智谱AI。
# 如果没有设置有效的local_model_path则认为是在线模型API。
# 请在server_config中为每个在线API设置不同的端口
# 线上模型。请在server_config中为每个在线API设置不同的端口
# 具体注册及api key获取请前往 http://open.bigmodel.cn
"zhipu-api": {
"api_base_url": "http://127.0.0.1:8888/v1",
"api_key": "",
"provider": "ChatGLMWorker",
"version": "chatglm_pro", # 可选包括 "chatglm_lite", "chatglm_std", "chatglm_pro"
"provider": "ChatGLMWorker",
},
# 具体注册及api key获取请前往 https://api.minimax.chat/
"minimax-api": {
"api_base_url": "http://127.0.0.1:8888/v1",
"group_id": "",
"api_key": "",
"is_pro": False,
"provider": "MiniMaxWorker",
},
# 具体注册及api key获取请前往 https://xinghuo.xfyun.cn/
"xinghuo-api": {
"api_base_url": "http://127.0.0.1:8888/v1",
"APPID": "",
"APISecret": "",
"api_key": "",
@ -105,140 +154,77 @@ llm_model_dict = {
},
# 百度千帆 API申请方式请参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/4lilb2lpf
"qianfan-api": {
"version": "ernie-bot", # 当前支持 "ernie-bot" 或 "ernie-bot-turbo" 更多的见文档模型支持列表中千帆部分。
"version_url": "", # 可以不填写version直接填写在千帆申请模型发布的API地址
"api_base_url": "http://127.0.0.1:8888/v1",
"version": "ernie-bot-turbo", # 当前支持 "ernie-bot" 或 "ernie-bot-turbo" 更多的见官方文档。
"version_url": "", # 也可以不填写version直接填写在千帆申请模型发布的API地址
"api_key": "",
"secret_key": "",
"provider": "QianFanWorker",
}
}
# LLM 名称
LLM_MODEL = "chatglm2-6b"
# 历史对话轮数
HISTORY_LEN = 3
# LLM通用对话参数
TEMPERATURE = 0.7
# TOP_P = 0.95 # ChatOpenAI暂不支持该参数
# LLM 运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
LLM_DEVICE = "auto"
# TextSplitter
text_splitter_dict = {
"ChineseRecursiveTextSplitter": {
"source": "",
"tokenizer_name_or_path": "",
},
"SpacyTextSplitter": {
"source": "huggingface",
"tokenizer_name_or_path": "gpt2",
# 火山方舟 API文档参考 https://www.volcengine.com/docs/82379
"fangzhou-api": {
"version": "chatglm-6b-model", # 当前支持 "chatglm-6b-model" 更多的见文档模型支持列表中方舟部分。
"version_url": "", # 可以不填写version直接填写在方舟申请模型发布的API地址
"api_key": "",
"secret_key": "",
"provider": "FangZhouWorker",
},
"RecursiveCharacterTextSplitter": {
"source": "tiktoken",
"tokenizer_name_or_path": "cl100k_base",
},
"MarkdownHeaderTextSplitter": {
"headers_to_split_on":
[
("#", "head1"),
("##", "head2"),
("###", "head3"),
("####", "head4"),
]
# 阿里云通义千问 API文档参考 https://help.aliyun.com/zh/dashscope/developer-reference/api-details
"qwen-api": {
"version": "qwen-turbo", # 可选包括 "qwen-turbo", "qwen-plus"
"api_key": "", # 请在阿里云控制台模型服务灵积API-KEY管理页面创建
"provider": "QwenWorker",
},
}
# TEXT_SPLITTER 名称
TEXT_SPLITTER = "ChineseRecursiveTextSplitter"
# 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter)
CHUNK_SIZE = 250
# 知识库中相邻文本重合长度(不适用MarkdownHeaderTextSplitter)
OVERLAP_SIZE = 0
# 日志存储路径
LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
if not os.path.exists(LOG_PATH):
os.mkdir(LOG_PATH)
# 知识库默认存储路径
KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
if not os.path.exists(KB_ROOT_PATH):
os.mkdir(KB_ROOT_PATH)
# 数据库默认存储路径。
# 如果使用sqlite可以直接修改DB_ROOT_PATH如果使用其它数据库请直接修改SQLALCHEMY_DATABASE_URI。
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"
# 可选向量库类型及对应配置
kbs_config = {
"faiss": {
},
"milvus": {
"host": "127.0.0.1",
"port": "19530",
"user": "",
"password": "",
"secure": False,
},
"pg": {
"connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
}
}
# 默认向量库类型。可选faiss, milvus, pg.
DEFAULT_VS_TYPE = "faiss"
# 缓存向量库数量
CACHED_VS_NUM = 1
# 知识库匹配向量数量
VECTOR_SEARCH_TOP_K = 3
# 知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右
SCORE_THRESHOLD = 1
# 搜索引擎匹配结果数量
SEARCH_ENGINE_TOP_K = 3
# 通常情况下不需要更改以下内容
# nltk 模型存储路径
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")
# 基于本地知识问答的提示词模版使用Jinja2语法简单点就是用双大括号代替f-string的单大括号
PROMPT_TEMPLATE = """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
<已知信息>{{ context }}</已知信息>
VLLM_MODEL_DICT = {
"aquila-7b":"BAAI/Aquila-7B",
"aquilachat-7b":"BAAI/AquilaChat-7B",
<问题>{{ question }}</问题>"""
"baichuan-7b": "baichuan-inc/Baichuan-7B",
"baichuan-13b": "baichuan-inc/Baichuan-13B",
'baichuan-13b-chat':'baichuan-inc/Baichuan-13B-Chat',
# 注意bloom系列的tokenizer与model是分离的因此虽然vllm支持但与fschat框架不兼容
# "bloom":"bigscience/bloom",
# "bloomz":"bigscience/bloomz",
# "bloomz-560m":"bigscience/bloomz-560m",
# "bloomz-7b1":"bigscience/bloomz-7b1",
# "bloomz-1b7":"bigscience/bloomz-1b7",
# API 是否开启跨域默认为False如果需要开启请设置为True
# is open cross domain
OPEN_CROSS_DOMAIN = False
"internlm-7b":"internlm/internlm-7b",
"internlm-chat-7b":"internlm/internlm-chat-7b",
"falcon-7b":"tiiuae/falcon-7b",
"falcon-40b":"tiiuae/falcon-40b",
"falcon-rw-7b":"tiiuae/falcon-rw-7b",
"gpt2":"gpt2",
"gpt2-xl":"gpt2-xl",
"gpt-j-6b":"EleutherAI/gpt-j-6b",
"gpt4all-j":"nomic-ai/gpt4all-j",
"gpt-neox-20b":"EleutherAI/gpt-neox-20b",
"pythia-12b":"EleutherAI/pythia-12b",
"oasst-sft-4-pythia-12b-epoch-3.5":"OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
"dolly-v2-12b":"databricks/dolly-v2-12b",
"stablelm-tuned-alpha-7b":"stabilityai/stablelm-tuned-alpha-7b",
"Llama-2-13b-hf":"meta-llama/Llama-2-13b-hf",
"Llama-2-70b-hf":"meta-llama/Llama-2-70b-hf",
"open_llama_13b":"openlm-research/open_llama_13b",
"vicuna-13b-v1.3":"lmsys/vicuna-13b-v1.3",
"koala":"young-geng/koala",
"mpt-7b":"mosaicml/mpt-7b",
"mpt-7b-storywriter":"mosaicml/mpt-7b-storywriter",
"mpt-30b":"mosaicml/mpt-30b",
"opt-66b":"facebook/opt-66b",
"opt-iml-max-30b":"facebook/opt-iml-max-30b",
# Bing 搜索必备变量
# 使用 Bing 搜索需要使用 Bing Subscription Key,需要在 Azure Portal 中申请试用 Bing Search
# 具体申请方式请见
# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/create-bing-search-service-resource
# 使用python创建bing api 搜索实例详见:
# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/quickstarts/rest/python
BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
# 注意不是bing Webmaster Tools的api key
"Qwen-7B":"Qwen/Qwen-7B",
"Qwen-14B":"Qwen/Qwen-14B",
"Qwen-7B-Chat":"Qwen/Qwen-7B-Chat",
"Qwen-14B-Chat":"Qwen/Qwen-14B-Chat",
# 此外如果是在服务器上报Failed to establish a new connection: [Errno 110] Connection timed out
# 是因为服务器加了防火墙需要联系管理员加白名单如果公司的服务器的话就别想了GG
BING_SUBSCRIPTION_KEY = ""
# 是否开启中文标题加强,以及标题增强的相关配置
# 通过增加标题判断判断哪些文本为标题并在metadata中进行标记
# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
ZH_TITLE_ENHANCE = False
}


@ -0,0 +1,23 @@
# prompt模板使用Jinja2语法简单点就是用双大括号代替f-string的单大括号
# 本配置文件支持热加载修改prompt模板后无需重启服务。
# LLM对话支持的变量
# - input: 用户输入内容
# 知识库和搜索引擎对话支持的变量:
# - context: 从检索结果拼接的知识文本
# - question: 用户提出的问题
PROMPT_TEMPLATES = {
# LLM对话模板
"llm_chat": "{{ input }}",
# 基于本地知识问答的提示词模
"knowledge_base_chat": """<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。 </指令>
<已知信息>{{ context }}</已知信息>
<问题>{{ question }}</问题>""",
}
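
These templates use Jinja2 syntax (double braces), as the comment at the top of the file notes. A minimal sketch of the substitution they rely on, assuming the jinja2 package is available (the project's get_prompt_template helper adds hot reload on top of this):

```python
# Sketch only: the Jinja2 substitution the templates above rely on.
# The project's get_prompt_template (server/utils.py) adds hot reload on top of this.
from jinja2 import Template

template = Template(PROMPT_TEMPLATES["knowledge_base_chat"])
prompt = template.render(
    context="(此处为检索到的知识片段)",
    question="什么是向量库?",
)
print(prompt)
```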


@ -1,5 +1,5 @@
from .model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE
import httpx
import sys
from configs.model_config import LLM_DEVICE
# httpx 请求默认超时时间(秒)。如果加载模型或对话较慢,出现超时错误,可以适当加大该值。
HTTPX_DEFAULT_TIMEOUT = 300.0
@ -8,8 +8,8 @@ HTTPX_DEFAULT_TIMEOUT = 300.0
# is open cross domain
OPEN_CROSS_DOMAIN = False
# 各服务器默认绑定host
DEFAULT_BIND_HOST = "127.0.0.1"
# 各服务器默认绑定host。如改为"0.0.0.0"需要修改下方所有XX_SERVER的host
DEFAULT_BIND_HOST = "0.0.0.0"
# webui.py server
WEBUI_SERVER = {
@ -26,25 +26,27 @@ API_SERVER = {
# fastchat openai_api server
FSCHAT_OPENAI_API = {
"host": DEFAULT_BIND_HOST,
"port": 8888, # model_config.llm_model_dict中模型配置的api_base_url需要与这里一致。
"port": 20000,
}
# fastchat model_worker server
# 这些模型必须是在model_config.llm_model_dict中正确配置的。
# 这些模型必须是在model_config.MODEL_PATH或ONLINE_MODEL中正确配置的。
# 在启动startup.py时可用通过`--model-worker --model-name xxxx`指定模型不指定则为LLM_MODEL
FSCHAT_MODEL_WORKERS = {
# 所有模型共用的默认配置,可在模型专项配置或llm_model_dict中进行覆盖。
# 所有模型共用的默认配置,可在模型专项配置中进行覆盖。
"default": {
"host": DEFAULT_BIND_HOST,
"port": 20002,
"device": LLM_DEVICE,
# False,'vllm',使用的推理加速框架,使用vllm如果出现HuggingFace通信问题参见doc/FAQ
"infer_turbo": "vllm" if sys.platform.startswith("linux") else False,
# 多卡加载需要配置的参数
# "gpus": None, # 使用的GPU以str的格式指定如"0,1"
# model_worker多卡加载需要配置的参数
# "gpus": None, # 使用的GPU以str的格式指定如"0,1"如失效请使用CUDA_VISIBLE_DEVICES="0,1"等形式指定
# "num_gpus": 1, # 使用GPU的数量
# "max_gpu_memory": "20GiB", # 每个GPU占用的最大显存
# 以下为非常用参数,可根据需要配置
# 以下为model_worker非常用参数,可根据需要配置
# "load_8bit": False, # 开启8bit量化
# "cpu_offloading": None,
# "gptq_ckpt": None,
@ -60,21 +62,55 @@ FSCHAT_MODEL_WORKERS = {
# "stream_interval": 2,
# "no_register": False,
# "embed_in_truncate": False,
# 以下为vllm_woker配置参数,注意使用vllm必须有gpu仅在Linux测试通过
# tokenizer = model_path # 如果tokenizer与model_path不一致在此处添加
# 'tokenizer_mode':'auto',
# 'trust_remote_code':True,
# 'download_dir':None,
# 'load_format':'auto',
# 'dtype':'auto',
# 'seed':0,
# 'worker_use_ray':False,
# 'pipeline_parallel_size':1,
# 'tensor_parallel_size':1,
# 'block_size':16,
# 'swap_space':4 , # GiB
# 'gpu_memory_utilization':0.90,
# 'max_num_batched_tokens':2560,
# 'max_num_seqs':256,
# 'disable_log_stats':False,
# 'conv_template':None,
# 'limit_worker_concurrency':5,
# 'no_register':False,
# 'num_gpus': 1
# 'engine_use_ray': False,
# 'disable_log_requests': False
},
"baichuan-7b": { # 使用default中的IP和端口
"device": "cpu",
# 可以如下示例方式更改默认配置
# "baichuan-7b": { # 使用default中的IP和端口
# "device": "cpu",
# },
"zhipu-api": { # 请为每个要运行的在线API设置不同的端口
"port": 21001,
},
"zhipu-api": { # 请为每个在线API设置不同的端口
"port": 20003,
"minimax-api": {
"port": 21002,
},
"minimax-api": { # 请为每个在线API设置不同的端口
"port": 20004,
},
"xinghuo-api": { # 请为每个在线API设置不同的端口
"port": 20005,
"xinghuo-api": {
"port": 21003,
},
"qianfan-api": {
"port": 20006,
"port": 21004,
},
"fangzhou-api": {
"port": 21005,
},
"qwen-api": {
"port": 21006,
},
}


@ -107,7 +107,7 @@ embedding_model_dict = {
Q9: 执行 `python cli_demo.py`过程中,显卡内存爆了,提示 "OutOfMemoryError: CUDA out of memory"
A9: 将 `VECTOR_SEARCH_TOP_K``LLM_HISTORY_LEN` 的值调低,比如 `VECTOR_SEARCH_TOP_K = 5``LLM_HISTORY_LEN = 2`,这样由 `query``context` 拼接得到的 `prompt` 会变短,会减少内存的占用。或者打开量化,请在 [configs/model_config.py](../configs/model_config.py) 文件中,对`LOAD_IN_8BIT`参数进行修改
A9: 将 `VECTOR_SEARCH_TOP_K``LLM_HISTORY_LEN` 的值调低,比如 `VECTOR_SEARCH_TOP_K = 5``LLM_HISTORY_LEN = 2`,这样由 `query``context` 拼接得到的 `prompt` 会变短,会减少内存的占用。或者打开量化,请在 [configs/model_config.py](../configs/model_config.py) 文件中,对 `LOAD_IN_8BIT`参数进行修改
---
@ -171,7 +171,6 @@ Q14: 修改配置中路径后,加载 text2vec-large-chinese 依然提示 `WARN
A14: 尝试更换 embedding如 text2vec-base-chinese请在 [configs/model_config.py](../configs/model_config.py) 文件中,修改 `text2vec-base`参数为本地路径,绝对路径或者相对路径均可
---
Q15: 使用pg向量库建表报错
@ -182,4 +181,43 @@ A15: 需要手动安装对应的vector扩展(连接pg执行 CREATE EXTENSION IF
Q16: pymilvus 连接超时
A16.pymilvus版本需要匹配和milvus对应否则会超时参考pymilvus==2.1.3
A16.pymilvus版本需要匹配和milvus对应否则会超时参考pymilvus==2.1.3
Q16: 使用vllm推理加速框架时,已经下载了模型,但出现HuggingFace通信问题
A16: 参照如下代码,修改python环境下 /site-packages/vllm/model_executor/weight_utils.py 文件的 prepare_hf_model_weights 函数中如下对应代码
```python
if not is_local:
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
model_path_temp = os.path.join(
os.getenv("HOME"),
".cache/huggingface/hub",
"models--" + model_name_or_path.replace("/", "--"),
"snapshots/",
)
downloaded = False
if os.path.exists(model_path_temp):
temp_last_dir = os.listdir(model_path_temp)[-1]
model_path_temp = os.path.join(model_path_temp, temp_last_dir)
base_pattern = os.path.join(model_path_temp, "pytorch_model*.bin")
files = glob.glob(base_pattern)
if len(files) > 0:
downloaded = True
if downloaded:
hf_folder = model_path_temp
else:
with get_lock(model_name_or_path, cache_dir):
hf_folder = snapshot_download(model_name_or_path,
allow_patterns=allow_patterns,
cache_dir=cache_dir,
tqdm_class=Disabledtqdm)
else:
hf_folder = model_name_or_path
```

docs/自定义Agent.md (new file)

@ -0,0 +1,80 @@
## 自定义属于自己的Agent
### 1. 创建自己的Agent工具
+ 开发者在```server/agent```文件中创建一个自己的文件,并将其添加到```tools.py```中。这样就完成了Tools的设定。
+ 当您创建了一个```custom_agent.py```文件,其中包含一个```work```函数,那么您需要在```tools.py```中添加如下代码:
```python
from custom_agent import work
Tool.from_function(
func=work,
name="该函数的名字",
description=""
)
```
+ 请注意如果你确定在某一个工程中不会使用到某个工具可以将其从Tools中移除降低模型分类错误导致使用错误工具的风险。
### 2. 修改 custom_template.py文件
开发者需要根据自己选择的大模型,设定适合该模型的 Agent Prompt 和自定义返回格式。
在我们的代码中提供了默认的两种方式一种是适配于GPT和Qwen的提示词
```python
"""
Answer the following questions as best you can. You have access to the following tools:
{tools}
Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin!
history:
{history}
Question: {input}
Thought: {agent_scratchpad}
"""
```
另一种是适配于GLM-130B的提示词
```python
"""
尽可能地回答以下问题。你可以使用以下工具:{tools}
请按照以下格式进行:
Question: 需要你回答的输入问题
Thought: 你应该总是思考该做什么
Action: 需要使用的工具,应该是[{tool_names}]中的一个
Action Input: 传入工具的内容
Observation: 行动的结果
... (这个Thought/Action/Action Input/Observation可以重复N次)
Thought: 我现在知道最后的答案
Final Answer: 对原始输入问题的最终答案
现在开始!
之前的对话:
{history}
New question: {input}
Thought: {agent_scratchpad}
"""
```
### 3. 局限性
1. 在我们的实验中,小于70B级别的模型若不经过微调,很难达到较好的效果。因此,我们建议开发者使用大于70B级别的模型进行微调,以达到更好的效果。
2. 由于Agent的脆弱性,temperature参数的设置对于模型的效果有很大的影响。我们建议开发者在使用自定义Agent时,对于不同的模型将其设置成0.1以下,以达到更好的效果。
3. 即使使用了大于70B级别的模型,开发者也应该在Prompt上进行深度优化,以让模型能成功地选择工具并完成任务。
### 4. 我们已经支持的Agent
我们为开发者编写了三个运用大模型执行的Agent分别是
1. 翻译工具:实现对输入的任意语言翻译。
2. 数学工具:使用 LLMMathChain 实现数学计算。
3. 天气工具:使用自定义的 LLMWeatherChain 实现天气查询,调用和风天气 API。
4. 我们支持 Langchain 支持的 Agent 工具,在代码中我们已经提供了 Shell 和 Google Search 两个工具的实现。

img/official_account.png (new image file); two existing binary image files changed (not shown)

@ -1,44 +1,92 @@
from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, recreate_all_vs, list_kbs_from_folder
from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files
from configs.model_config import NLTK_DATA_PATH
import nltk
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
from startup import dump_server_info
from datetime import datetime
import sys
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.formatter_class = argparse.RawTextHelpFormatter
parser = argparse.ArgumentParser(description="please specify only one operation at a time.")
parser.add_argument(
"-r",
"--recreate-vs",
action="store_true",
help=('''
recreate all vector store.
recreate vector store.
use this option if you have copied document files to the content folder, but the vector store has not been populated or DEFAULT_VS_TYPE/EMBEDDING_MODEL changed.
if your vector store is ready with the configs, just skip this option to fill info to database only.
'''
)
)
args = parser.parse_args()
parser.add_argument(
"-u",
"--update-in-db",
action="store_true",
help=('''
update vector store for files exist in database.
use this option if you want to recreate vectors for files exist in db and skip files exist in local folder only.
'''
)
)
parser.add_argument(
"-i",
"--increament",
action="store_true",
help=('''
update vector store for files exist in local folder and not exist in database.
use this option if you want to create vectors increamentally.
'''
)
)
parser.add_argument(
"--prune-db",
action="store_true",
help=('''
delete docs in database that not existed in local folder.
it is used to delete database docs after user deleted some doc files in file browser
'''
)
)
parser.add_argument(
"--prune-folder",
action="store_true",
help=('''
delete doc files in local folder that not existed in database.
it is used to free local disk space by deleting unused doc files.
'''
)
)
parser.add_argument(
"--kb-name",
type=str,
nargs="+",
default=[],
help=("specify knowledge base names to operate on. default is all folders exist in KB_ROOT_PATH.")
)
dump_server_info()
start_time = datetime.now()
if args.recreate_vs:
reset_tables()
print("database talbes reseted")
print("recreating all vector stores")
recreate_all_vs()
if len(sys.argv) <= 1:
parser.print_help()
else:
create_tables()
print("database talbes created")
print("filling kb infos to database")
for kb in list_kbs_from_folder():
folder2db(kb, "fill_info_only")
args = parser.parse_args()
start_time = datetime.now()
end_time = datetime.now()
print(f"总计用时: {end_time-start_time}")
create_tables() # confirm tables exist
if args.recreate_vs:
reset_tables()
print("database talbes reseted")
print("recreating all vector stores")
folder2db(kb_names=args.kb_name, mode="recreate_vs")
elif args.update_in_db:
folder2db(kb_names=args.kb_name, mode="update_in_db")
elif args.increament:
folder2db(kb_names=args.kb_name, mode="increament")
elif args.prune_db:
prune_db_docs(args.kb_name)
elif args.prune_folder:
prune_folder_files(args.kb_name)
end_time = datetime.now()
print(f"总计用时: {end_time-start_time}")


@ -1,10 +1,12 @@
langchain==0.0.287
fschat[model_worker]==0.2.28
langchain>=0.0.302
fschat[model_worker]==0.2.29
openai
sentence_transformers
transformers>=4.31.0
torch~=2.0.0
fastapi~=0.99.1
transformers>=4.33.0
torch>=2.0.1
torchvision
torchaudio
fastapi>=0.103.1
nltk~=3.8.1
uvicorn~=0.23.1
starlette~=0.27.0
@ -23,6 +25,12 @@ pathlib
pytest
scikit-learn
numexpr
vllm==0.1.7; sys_platform == "linux"
# online api libs
# zhipuai
# dashscope>=1.10.0 # qwen
# qianfan
# volcengine>=1.0.106 # fangzhou
# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3
@ -34,9 +42,13 @@ pandas~=2.0.3
streamlit>=1.26.0
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
streamlit-chatbox >=1.1.6, <=1.1.7
streamlit-chatbox>=1.1.9
streamlit-aggrid>=0.3.4.post3
httpx~=0.24.1
watchdog
tqdm
websockets
tiktoken
einops
scipy
transformers_stream_generator==0.0.4


@ -1,10 +1,12 @@
langchain==0.0.287
fschat[model_worker]==0.2.28
langchain>=0.0.302
fschat[model_worker]==0.2.29
openai
sentence_transformers
transformers>=4.31.0
torch~=2.0.0
fastapi~=0.99.1
transformers>=4.33.0
torch >=2.0.1
torchvision
torchaudio
fastapi>=0.103.1
nltk~=3.8.1
uvicorn~=0.23.1
starlette~=0.27.0
@ -17,12 +19,19 @@ accelerate
spacy
PyMuPDF==1.22.5
rapidocr_onnxruntime>=1.3.2
requests
pathlib
pytest
scikit-learn
numexpr
vllm==0.1.7; sys_platform == "linux"
# online api libs
# zhipuai
# dashscope>=1.10.0 # qwen
# qianfan
# volcengine>=1.0.106 # fangzhou
# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3


@ -3,7 +3,7 @@ pandas~=2.0.3
streamlit>=1.26.0
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
streamlit-chatbox >=1.1.6, <=1.1.7
streamlit-chatbox>=1.1.9
streamlit-aggrid>=0.3.4.post3
httpx~=0.24.1
nltk

server/agent/callbacks.py (new file)

@ -0,0 +1,109 @@
from uuid import UUID
from langchain.callbacks import AsyncIteratorCallbackHandler
import json
import asyncio
from typing import Any, Dict, List, Optional
from langchain.schema import AgentFinish, AgentAction
from langchain.schema.output import LLMResult
def dumps(obj: Dict) -> str:
return json.dumps(obj, ensure_ascii=False)
class Status:
start: int = 1
running: int = 2
complete: int = 3
agent_action: int = 4
agent_finish: int = 5
error: int = 6
make_tool: int = 7
class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
def __init__(self):
super().__init__()
self.queue = asyncio.Queue()
self.done = asyncio.Event()
self.cur_tool = {}
self.out = True
async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
parent_run_id: UUID | None = None, tags: List[str] | None = None,
metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
self.cur_tool = {
"tool_name": serialized["name"],
"input_str": input_str,
"output_str": "",
"status": Status.agent_action,
"run_id": run_id.hex,
"llm_token": "",
"final_answer": "",
"error": "",
}
self.queue.put_nowait(dumps(self.cur_tool))
async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
tags: List[str] | None = None, **kwargs: Any) -> None:
self.out = True
self.cur_tool.update(
status=Status.agent_finish,
output_str=output.replace("Answer:", ""),
)
self.queue.put_nowait(dumps(self.cur_tool))
async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
self.cur_tool.update(
status=Status.error,
error=str(error),
)
self.queue.put_nowait(dumps(self.cur_tool))
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
if token:
if "Action" in token:
self.out = False
self.cur_tool.update(
status=Status.running,
llm_token="\n\n",
)
self.queue.put_nowait(dumps(self.cur_tool))
if self.out:
self.cur_tool.update(
status=Status.running,
llm_token=token,
)
self.queue.put_nowait(dumps(self.cur_tool))
async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
self.cur_tool.update(
status=Status.start,
llm_token="",
)
self.queue.put_nowait(dumps(self.cur_tool))
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
self.out = True
self.cur_tool.update(
status=Status.complete,
llm_token="",
)
self.queue.put_nowait(dumps(self.cur_tool))
async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
self.out = True
self.cur_tool.update(
status=Status.error,
error=str(error),
)
self.queue.put_nowait(dumps(self.cur_tool))
async def on_agent_finish(
self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
self.cur_tool = {}


@ -0,0 +1,64 @@
from langchain.agents import Tool, AgentOutputParser
from langchain.prompts import StringPromptTemplate
from typing import List, Union
from langchain.schema import AgentAction, AgentFinish
import re
class CustomPromptTemplate(StringPromptTemplate):
# The template to use
template: str
# The list of tools available
tools: List[Tool]
def format(self, **kwargs) -> str:
# Get the intermediate steps (AgentAction, Observation tuples)
# Format them in a particular way
intermediate_steps = kwargs.pop("intermediate_steps")
thoughts = ""
for action, observation in intermediate_steps:
thoughts += action.log
thoughts += f"\nObservation: {observation}\nThought: "
# Set the agent_scratchpad variable to that value
kwargs["agent_scratchpad"] = thoughts
# Create a tools variable from the list of tools provided
kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
# Create a list of tool names for the tools provided
kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
return self.template.format(**kwargs)
class CustomOutputParser(AgentOutputParser):
def parse(self, llm_output: str) -> AgentFinish | AgentAction | str:
# Check if agent should finish
if "Final Answer:" in llm_output:
return AgentFinish(
# Return values is generally always a dictionary with a single `output` key
# It is not recommended to try anything else at the moment :)
return_values={"output": llm_output.replace("Final Answer:", "").strip()},
log=llm_output,
)
# Parse out the action and action input
regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
match = re.search(regex, llm_output, re.DOTALL)
if not match:
return AgentFinish(
return_values={"output": f"调用agent失败: `{llm_output}`"},
log=llm_output,
)
action = match.group(1).strip()
action_input = match.group(2)
# Return the action and action input
try:
ans = AgentAction(
tool=action,
tool_input=action_input.strip(" ").strip('"'),
log=llm_output
)
return ans
except:
return AgentFinish(
return_values={"output": f"调用agent失败: `{llm_output}`"},
log=llm_output,
)


@ -0,0 +1,8 @@
import os
os.environ["GOOGLE_CSE_ID"] = ""
os.environ["GOOGLE_API_KEY"] = ""
from langchain.tools import GoogleSearchResults
def google_search(query: str):
tool = GoogleSearchResults()
return tool.run(tool_input=query)

server/agent/math.py (new file)

@ -0,0 +1,70 @@
from langchain.prompts import PromptTemplate
from langchain.chains import LLMMathChain
from server.utils import wrap_done, get_ChatOpenAI
from configs.model_config import LLM_MODEL, TEMPERATURE
from langchain.chat_models import ChatOpenAI
from langchain.callbacks.manager import CallbackManagerForToolRun
_PROMPT_TEMPLATE = """将数学问题翻译成可以使用Python的numexpr库执行的表达式。使用运行此代码的输出来回答问题。
问题: ${{包含数学问题的问题}}
```text
${{解决问题的单行数学表达式}}
```
...numexpr.evaluate(query)...
```output
${{运行代码的输出}}
```
答案: ${{答案}}
这是两个例子
问题: 37593 * 67是多少
```text
37593 * 67
```
...numexpr.evaluate("37593 * 67")...
```output
2518731
答案: 2518731
问题: 37593的五次方根是多少
```text
37593**(1/5)
```
...numexpr.evaluate("37593**(1/5)")...
```output
8.222831614237718
答案: 8.222831614237718
问题: 2的平方是多少
```text
2 ** 2
```
...numexpr.evaluate("2 ** 2")...
```output
4
答案: 4
现在这是我的问题
问题: {question}
"""
PROMPT = PromptTemplate(
input_variables=["question"],
template=_PROMPT_TEMPLATE,
)
def calculate(query: str):
model = get_ChatOpenAI(
streaming=False,
model_name=LLM_MODEL,
temperature=TEMPERATURE,
)
llm_math = LLMMathChain.from_llm(model, verbose=True, prompt=PROMPT)
ans = llm_math.run(query)
return ans

server/agent/shell.py (new file)

@ -0,0 +1,5 @@
from langchain.tools import ShellTool
def shell(query: str):
tool = ShellTool()
return tool.run(tool_input=query)

server/agent/tools.py (new file)

@ -0,0 +1,40 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from server.agent.math import calculate
from server.agent.translator import translate
from server.agent.weather import weathercheck
from server.agent.shell import shell
from server.agent.google_search import google_search
from langchain.agents import Tool
tools = [
Tool.from_function(
func=calculate,
name="计算器工具",
description="进行简单的数学运算"
),
Tool.from_function(
func=translate,
name="翻译工具",
description="翻译各种语言"
),
Tool.from_function(
func=weathercheck,
name="天气查询工具",
description="查询天气",
),
Tool.from_function(
func=shell,
name="shell工具",
description="使用命令行工具输出",
),
Tool.from_function(
func=google_search,
name="谷歌搜索工具",
description="使用谷歌搜索",
)
]
tool_names = [tool.name for tool in tools]


@ -0,0 +1,55 @@
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
import sys
import os
from server.utils import get_ChatOpenAI
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from langchain.chains.llm_math.prompt import PROMPT
from configs.model_config import LLM_MODEL,TEMPERATURE
_PROMPT_TEMPLATE = '''
# 指令
接下来作为一个专业的翻译专家当我给出句子或段落时你将提供通顺且具有可读性的对应语言的翻译注意
1. 确保翻译结果流畅且易于理解
2. 无论提供的是陈述句或疑问句只进行翻译
3. 不添加与原文无关的内容
原文: ${{用户需要翻译的原文和目标语言}}
{question}
```output
${{翻译结果}}
```
答案: ${{答案}}
以下是两个例子
问题: 翻译13成英语
```text
13 英语
```output
thirteen
问题: 翻译 我爱你 成法语
```text
我爱你 法语
```output
Je t'aime.
'''
PROMPT = PromptTemplate(
input_variables=["question"],
template=_PROMPT_TEMPLATE,
)
def translate(query: str):
model = get_ChatOpenAI(
streaming=False,
model_name=LLM_MODEL,
temperature=TEMPERATURE,
)
llm_translate = LLMChain(llm=model, prompt=PROMPT)
ans = llm_translate.run(query)
return ans

server/agent/weather.py (new file)

@ -0,0 +1,365 @@
## 使用和风天气API查询天气
from __future__ import annotations
## 单独运行的时候需要添加
import sys
import os
# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from server.utils import get_ChatOpenAI
import re
import warnings
from typing import Dict
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.pydantic_v1 import Extra, root_validator
from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
import requests
from typing import List, Any, Optional
from configs.model_config import LLM_MODEL, TEMPERATURE
## 使用和风天气API查询天气
KEY = ""
def get_city_info(location, adm, key):
base_url = 'https://geoapi.qweather.com/v2/city/lookup?'
params = {'location': location, 'adm': adm, 'key': key}
response = requests.get(base_url, params=params)
data = response.json()
return data
from datetime import datetime
def format_weather_data(data):
hourly_forecast = data['hourly']
formatted_data = ''
for forecast in hourly_forecast:
# 将预报时间转换为datetime对象
forecast_time = datetime.strptime(forecast['fxTime'], '%Y-%m-%dT%H:%M%z')
# 获取预报时间的时区
forecast_tz = forecast_time.tzinfo
# 获取当前时间(使用预报时间的时区)
now = datetime.now(forecast_tz)
# 计算预报日期与当前日期的差值
days_diff = (forecast_time.date() - now.date()).days
if days_diff == 0:
forecast_date_str = '今天'
elif days_diff == 1:
forecast_date_str = '明天'
elif days_diff == 2:
forecast_date_str = '后天'
else:
forecast_date_str = str(days_diff) + '天后'
forecast_time_str = forecast_date_str + ' ' + forecast_time.strftime('%H:%M')
# 计算预报时间与当前时间的差值
time_diff = forecast_time - now
# 将差值转换为小时
hours_diff = time_diff.total_seconds() // 3600
if hours_diff < 1:
hours_diff_str = '1小时后'
elif hours_diff >= 24:
# 如果超过24小时转换为天数
days_diff = hours_diff // 24
hours_diff_str = str(int(days_diff)) + '天后'
else:
hours_diff_str = str(int(hours_diff)) + '小时后'
# 将预报时间和当前时间的差值添加到输出中
formatted_data += '预报时间: ' + hours_diff_str + '\n'
formatted_data += '具体时间: ' + forecast_time_str + '\n'
formatted_data += '温度: ' + forecast['temp'] + '°C\n'
formatted_data += '天气: ' + forecast['text'] + '\n'
formatted_data += '风向: ' + forecast['windDir'] + '\n'
formatted_data += '风速: ' + forecast['windSpeed'] + '\n'
formatted_data += '湿度: ' + forecast['humidity'] + '%\n'
formatted_data += '降水概率: ' + forecast['pop'] + '%\n'
# formatted_data += '降水量: ' + forecast['precip'] + 'mm\n'
formatted_data += '\n\n'
return formatted_data
def get_weather(key, location_id, time: str = "24"):
if time:
url = "https://devapi.qweather.com/v7/weather/" + time + "h?"
else:
time = "3" # 免费订阅只能查看3天的天气
url = "https://devapi.qweather.com/v7/weather/" + time + "d?"
params = {
'location': location_id,
'key': key,
}
response = requests.get(url, params=params)
data = response.json()
return format_weather_data(data)
def split_query(query):
parts = query.split()
location = parts[0] if parts[0] != 'None' else parts[1]
adm = parts[1]
time = parts[2]
return location, adm, time
def weather(query):
location, adm, time = split_query(query)
key = KEY
if time != "None" and int(time) > 24:
return "只能查看24小时内的天气,无法回答"
if time == "None":
time = "24" # 免费的版本只能24小时内的天气
if key == "":
return "请先在代码中填入和风天气API Key"
city_info = get_city_info(location=location, adm=adm, key=key)
location_id = city_info['location'][0]['id']
weather_data = get_weather(key=key, location_id=location_id, time=time)
return weather_data
class LLMWeatherChain(Chain):
llm_chain: LLMChain
llm: Optional[BaseLanguageModel] = None
"""[Deprecated] LLM wrapper to use."""
prompt: BasePromptTemplate
"""[Deprecated] Prompt to use."""
input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
if "llm" in values:
warnings.warn(
"Directly instantiating an LLMWeatherChain with an llm is deprecated. "
"Please instantiate with llm_chain argument or using the from_llm "
"class method."
)
if "llm_chain" not in values and values["llm"] is not None:
prompt = values.get("prompt", PROMPT)
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
return values
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Expect output key.
:meta private:
"""
return [self.output_key]
def _evaluate_expression(self, expression: str) -> str:
try:
output = weather(expression)
except Exception as e:
output = "输入的信息有误,请再次尝试"
# raise ValueError(f"错误: {expression},输入的信息不对")
return output
def _process_llm_result(
self, llm_output: str, run_manager: CallbackManagerForChainRun
) -> Dict[str, str]:
run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
if text_match:
expression = text_match.group(1)
output = self._evaluate_expression(expression)
run_manager.on_text("\nAnswer: ", verbose=self.verbose)
run_manager.on_text(output, color="yellow", verbose=self.verbose)
answer = "Answer: " + output
elif llm_output.startswith("Answer:"):
answer = llm_output
elif "Answer:" in llm_output:
answer = "Answer: " + llm_output.split("Answer:")[-1]
else:
raise ValueError(f"unknown format from LLM: {llm_output}")
return {self.output_key: answer}
async def _aprocess_llm_result(
self,
llm_output: str,
run_manager: AsyncCallbackManagerForChainRun,
) -> Dict[str, str]:
await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
if text_match:
expression = text_match.group(1)
output = self._evaluate_expression(expression)
await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
await run_manager.on_text(output, color="yellow", verbose=self.verbose)
answer = "Answer: " + output
elif llm_output.startswith("Answer:"):
answer = llm_output
elif "Answer:" in llm_output:
answer = "Answer: " + llm_output.split("Answer:")[-1]
else:
raise ValueError(f"unknown format from LLM: {llm_output}")
return {self.output_key: answer}
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
_run_manager.on_text(inputs[self.input_key])
llm_output = self.llm_chain.predict(
question=inputs[self.input_key],
stop=["```output"],
callbacks=_run_manager.get_child(),
)
return self._process_llm_result(llm_output, _run_manager)
async def _acall(
self,
inputs: Dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
await _run_manager.on_text(inputs[self.input_key])
llm_output = await self.llm_chain.apredict(
question=inputs[self.input_key],
stop=["```output"],
callbacks=_run_manager.get_child(),
)
return await self._aprocess_llm_result(llm_output, _run_manager)
@property
def _chain_type(self) -> str:
return "llm_weather_chain"
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
prompt: BasePromptTemplate,
**kwargs: Any,
) -> LLMWeatherChain:
llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(llm_chain=llm_chain, prompt=prompt, **kwargs)
from langchain.prompts import PromptTemplate
_PROMPT_TEMPLATE = """用户将会向您咨询天气问题,您不需要自己回答天气问题,而是从用户提问中提取出区、市和时间三个元素,然后使用我为您编写好的工具进行查询并返回结果,格式为 区+市+时间,每个元素用空格隔开。如果缺少信息,则用 None 代替。
问题: ${{用户的问题}}
```text
${{拆分的区市和时间}}
```
... weather(提取后的关键字用空格隔开)...
```output
${{提取后的答案}}
```
答案: ${{答案}}
这是两个例子
问题: 上海浦东未来1小时天气情况
```text
浦东 上海 1
```
...weather(浦东 上海 1)...
```output
预报时间: 1小时后
具体时间: 今天 18:00
温度: 24°C
天气: 多云
风向: 西南风
风速: 7
湿度: 88%
降水概率: 16%
Answer:
预报时间: 1小时后
具体时间: 今天 18:00
温度: 24°C
天气: 多云
风向: 西南风
风速: 7
湿度: 88%
降水概率: 16%
问题: 北京市朝阳区未来24小时天气如何
```text
朝阳 北京 24
```
...weather(朝阳 北京 24)...
```output
预报时间: 23小时后
具体时间: 明天 17:00
温度: 26°C
天气:
风向: 西南风
风速: 11
湿度: 65%
降水概率: 20%
Answer:
预报时间: 23小时后
具体时间: 明天 17:00
温度: 26°C
天气:
风向: 西南风
风速: 11
湿度: 65%
降水概率: 20%
现在这是我的问题
问题: {question}
"""
PROMPT = PromptTemplate(
input_variables=["question"],
template=_PROMPT_TEMPLATE,
)
def weathercheck(query: str):
model = get_ChatOpenAI(
streaming=False,
model_name=LLM_MODEL,
temperature=TEMPERATURE,
)
llm_weather = LLMWeatherChain.from_llm(model, verbose=True, prompt=PROMPT)
ans = llm_weather.run(query)
return ans
if __name__ == '__main__':
## 检测api是否能正确返回
query = "上海浦东未来1小时天气情况"
# ans = weathercheck(query)
ans = weather("浦东 上海 1")
print(ans)
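split_query expects exactly three whitespace-separated tokens (区 市 小时数, with None for a missing part), which is also the format the prompt asks the model to emit. A direct check, assuming a valid QWeather key has been filled into KEY:

from server.agent.weather import weather, weathercheck

# Hit the QWeather API directly, bypassing the LLM (requires KEY to be set).
print(weather("朝阳 北京 24"))

# Full chain: the LLM first extracts "区 市 时间" from natural language.
print(weathercheck("北京市朝阳区未来24小时天气如何"))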

View File

@ -12,12 +12,12 @@ import uvicorn
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat,
search_engine_chat)
search_engine_chat, agent_chat)
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
update_docs, download_doc, recreate_vector_store,
search_docs, DocumentWithScore)
from server.llm_api import list_llm_models, change_llm_model, stop_llm_model
from server.llm_api import list_running_models, list_config_models, change_llm_model, stop_llm_model
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
from typing import List
@ -67,6 +67,10 @@ def create_app():
tags=["Chat"],
summary="与搜索引擎对话")(search_engine_chat)
app.post("/chat/agent_chat",
tags=["Chat"],
summary="与agent对话")(agent_chat)
# Tag: Knowledge Base Management
app.get("/knowledge_base/list_knowledge_bases",
tags=["Knowledge Base Management"],
@ -125,20 +129,25 @@ def create_app():
)(recreate_vector_store)
# LLM模型相关接口
app.post("/llm_model/list_models",
tags=["LLM Model Management"],
summary="列出当前已加载的模型",
)(list_llm_models)
app.post("/llm_model/list_running_models",
tags=["LLM Model Management"],
summary="列出当前已加载的模型",
)(list_running_models)
app.post("/llm_model/list_config_models",
tags=["LLM Model Management"],
summary="列出configs已配置的模型",
)(list_config_models)
app.post("/llm_model/stop",
tags=["LLM Model Management"],
summary="停止指定的LLM模型(Model Worker)",
)(stop_llm_model)
tags=["LLM Model Management"],
summary="停止指定的LLM模型(Model Worker)",
)(stop_llm_model)
app.post("/llm_model/change",
tags=["LLM Model Management"],
summary="切换指定的LLM模型(Model Worker)",
)(change_llm_model)
tags=["LLM Model Management"],
summary="切换指定的LLM模型(Model Worker)",
)(change_llm_model)
return app
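A sketch of calling the renamed model-management endpoints; the base URL assumes the API server is reachable locally on port 7861, adjust to your deployment:

import httpx

api_base = "http://127.0.0.1:7861"  # assumed local API server address

with httpx.Client() as client:
    running = client.post(f"{api_base}/llm_model/list_running_models", json={}).json()
    configured = client.post(f"{api_base}/llm_model/list_config_models", json={}).json()
    print("running:", running.get("data"))
    print("configured:", configured.get("data"))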

View File

@ -2,3 +2,4 @@ from .chat import chat
from .knowledge_base_chat import knowledge_base_chat
from .openai_chat import openai_chat
from .search_engine_chat import search_engine_chat
from .agent_chat import agent_chat

126
server/chat/agent_chat.py Normal file
View File

@ -0,0 +1,126 @@
from langchain.memory import ConversationBufferWindowMemory
from server.agent.tools import tools, tool_names
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status, dumps
from langchain.agents import AgentExecutor, LLMSingleActionAgent
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
from fastapi import Body
from fastapi.responses import StreamingResponse
from configs.model_config import LLM_MODEL, TEMPERATURE, HISTORY_LEN
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from langchain.chains import LLMChain
from typing import AsyncIterable, Optional
import asyncio
from typing import List
from server.chat.utils import History
import json
async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
history: List[History] = Body([],
description="历史对话",
examples=[[
{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant", "content": "虎头虎脑"}]]
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
prompt_name: str = Body("agent_chat",
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
):
history = [History.from_data(h) for h in history]
async def agent_chat_iterator(
query: str,
history: Optional[List[History]],
model_name: str = LLM_MODEL,
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
callback = CustomAsyncIteratorCallbackHandler()
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
)
prompt_template = CustomPromptTemplate(
template=get_prompt_template(prompt_name),
tools=tools,
input_variables=["input", "intermediate_steps", "history"]
)
output_parser = CustomOutputParser()
llm_chain = LLMChain(llm=model, prompt=prompt_template)
agent = LLMSingleActionAgent(
llm_chain=llm_chain,
output_parser=output_parser,
stop=["Observation:", "Observation:\n", "<|im_end|>"], # Qwen模型中使用这个
# stop=["Observation:", "Observation:\n"], # 其他模型,注意模板
allowed_tools=tool_names,
)
# 把history转成agent的memory
memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
for message in history:
# 检查消息的角色
if message.role == 'user':
# 添加用户消息
memory.chat_memory.add_user_message(message.content)
else:
# 添加AI消息
memory.chat_memory.add_ai_message(message.content)
agent_executor = AgentExecutor.from_agent_and_tools(agent=agent,
tools=tools,
verbose=True,
memory=memory,
)
input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
task = asyncio.create_task(wrap_done(
agent_executor.acall(query, callbacks=[callback], include_run_info=True),
callback.done),
)
if stream:
async for chunk in callback.aiter():
tools_use = []
# Use server-sent-events to stream the response
data = json.loads(chunk)
if data["status"] == Status.error:
tools_use.append("工具调用失败:\n" + data["error"])
yield json.dumps({"tools": tools_use}, ensure_ascii=False)
yield json.dumps({"answer": "(工具调用失败,请查看工具栏报错) \n\n"}, ensure_ascii=False)
if data["status"] == Status.start or data["status"] == Status.complete:
continue
if data["status"] == Status.agent_action:
yield json.dumps({"answer": "(正在使用工具,请注意工具栏变化) \n\n"}, ensure_ascii=False)
if data["status"] == Status.agent_finish:
tools_use.append("工具名称: " + data["tool_name"])
tools_use.append("工具输入: " + data["input_str"])
tools_use.append("工具输出: " + data["output_str"])
yield json.dumps({"tools": tools_use}, ensure_ascii=False)
yield json.dumps({"answer": data["llm_token"]}, ensure_ascii=False)
else:
pass
# agent必须要stream=True,这部分暂时没有完成
# result = []
# async for chunk in callback.aiter():
# data = json.loads(chunk)
# status = data["status"]
# if status == Status.start:
# result.append(chunk)
# elif status == Status.running:
# result[-1]["llm_token"] += chunk["llm_token"]
# elif status == Status.complete:
# result[-1]["status"] = Status.complete
# elif status == Status.agent_finish:
# result.append(chunk)
# elif status == Status.agent_finish:
# pass
# yield dumps(result)
await task
return StreamingResponse(agent_chat_iterator(query=query,
history=history,
model_name=model_name,
prompt_name=prompt_name),
media_type="text/event-stream")
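A client-side sketch for the stream above: each yielded chunk is a standalone JSON object carrying either an "answer" token or a "tools" list. The server address is assumed, and chunk boundaries are assumed to line up with the server's yields:

import json
import httpx

api_base = "http://127.0.0.1:7861"  # assumed local API server address
payload = {"query": "北京市朝阳区未来24小时天气如何", "history": [], "stream": True}

with httpx.Client(timeout=None) as client:
    with client.stream("POST", f"{api_base}/chat/agent_chat", json=payload) as r:
        for chunk in r.iter_text():
            try:
                data = json.loads(chunk)
            except json.JSONDecodeError:
                continue  # partial chunk, skipped in this sketch
            if "answer" in data:
                print(data["answer"], end="", flush=True)
            if "tools" in data:
                print("\n".join(data["tools"]))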

View File

@ -1,15 +1,15 @@
from fastapi import Body
from fastapi.responses import StreamingResponse
from configs.model_config import llm_model_dict, LLM_MODEL, TEMPERATURE
from server.chat.utils import wrap_done
from langchain.chat_models import ChatOpenAI
from langchain import LLMChain
from configs import LLM_MODEL, TEMPERATURE
from server.utils import wrap_done, get_ChatOpenAI
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from typing import List
from server.chat.utils import History
from server.utils import get_prompt_template
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
@ -21,29 +21,26 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
prompt_name: str = Body("llm_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
history = [History.from_data(h) for h in history]
async def chat_iterator(query: str,
history: List[History] = [],
model_name: str = LLM_MODEL,
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
streaming=True,
verbose=True,
callbacks=[callback],
openai_api_key=llm_model_dict[model_name]["api_key"],
openai_api_base=llm_model_dict[model_name]["api_base_url"],
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
openai_proxy=llm_model_dict[model_name].get("openai_proxy")
callbacks=[callback],
)
input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
prompt_template = get_prompt_template(prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
chain = LLMChain(prompt=chat_prompt, llm=model)
@ -66,5 +63,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
await task
return StreamingResponse(chat_iterator(query, history, model_name),
return StreamingResponse(chat_iterator(query=query,
history=history,
model_name=model_name,
prompt_name=prompt_name),
media_type="text/event-stream")

View File

@ -1,12 +1,9 @@
from fastapi import Body, Request
from fastapi.responses import StreamingResponse
from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
TEMPERATURE)
from server.chat.utils import wrap_done
from server.utils import BaseResponse
from langchain.chat_models import ChatOpenAI
from langchain import LLMChain
from configs import (LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE)
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, List, Optional
import asyncio
@ -33,7 +30,8 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
request: Request = None,
):
@ -44,27 +42,22 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
history = [History.from_data(h) for h in history]
async def knowledge_base_chat_iterator(query: str,
kb: KBService,
top_k: int,
history: Optional[List[History]],
model_name: str = LLM_MODEL,
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
streaming=True,
verbose=True,
callbacks=[callback],
openai_api_key=llm_model_dict[model_name]["api_key"],
openai_api_base=llm_model_dict[model_name]["api_base_url"],
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
openai_proxy=llm_model_dict[model_name].get("openai_proxy")
callbacks=[callback],
)
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
context = "\n".join([doc.page_content for doc in docs])
input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
prompt_template = get_prompt_template(prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
@ -102,5 +95,9 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
await task
return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history, model_name),
return StreamingResponse(knowledge_base_chat_iterator(query=query,
top_k=top_k,
history=history,
model_name=model_name,
prompt_name=prompt_name),
media_type="text/event-stream")

View File

@ -1,7 +1,8 @@
from fastapi.responses import StreamingResponse
from typing import List
import openai
from configs.model_config import llm_model_dict, LLM_MODEL, logger, log_verbose
from configs import LLM_MODEL, logger, log_verbose
from server.utils import get_model_worker_config, fschat_openai_api_address
from pydantic import BaseModel
@ -23,9 +24,10 @@ class OpenAiChatMsgIn(BaseModel):
async def openai_chat(msg: OpenAiChatMsgIn):
openai.api_key = llm_model_dict[LLM_MODEL]["api_key"]
config = get_model_worker_config(msg.model)
openai.api_key = config.get("api_key", "EMPTY")
print(f"{openai.api_key=}")
openai.api_base = llm_model_dict[LLM_MODEL]["api_base_url"]
openai.api_base = config.get("api_base_url", fschat_openai_api_address())
print(f"{openai.api_base=}")
print(msg)

View File

@ -1,14 +1,12 @@
from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY,
LLM_MODEL, SEARCH_ENGINE_TOP_K, TEMPERATURE)
from fastapi import Body
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K,
PROMPT_TEMPLATE, TEMPERATURE)
from server.chat.utils import wrap_done
from server.utils import BaseResponse
from langchain.chat_models import ChatOpenAI
from langchain import LLMChain
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
@ -73,7 +71,8 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@ -88,23 +87,20 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
top_k: int,
history: Optional[List[History]],
model_name: str = LLM_MODEL,
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
streaming=True,
verbose=True,
callbacks=[callback],
openai_api_key=llm_model_dict[model_name]["api_key"],
openai_api_base=llm_model_dict[model_name]["api_base_url"],
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
openai_proxy=llm_model_dict[model_name].get("openai_proxy")
callbacks=[callback],
)
docs = await lookup_search_engine(query, search_engine_name, top_k)
context = "\n".join([doc.page_content for doc in docs])
input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
prompt_template = get_prompt_template(prompt_name)
input_msg = History(role="user", content=prompt_template).to_msg_template(False)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_template() for i in history] + [input_msg])
@ -135,5 +131,10 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
ensure_ascii=False)
await task
return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history, model_name),
return StreamingResponse(search_engine_chat_iterator(query=query,
search_engine_name=search_engine_name,
top_k=top_k,
history=history,
model_name=model_name,
prompt_name=prompt_name),
media_type="text/event-stream")

View File

@ -1,22 +1,7 @@
import asyncio
from typing import Awaitable, List, Tuple, Dict, Union
from pydantic import BaseModel, Field
from langchain.prompts.chat import ChatMessagePromptTemplate
from configs import logger, log_verbose
async def wrap_done(fn: Awaitable, event: asyncio.Event):
"""Wrap an awaitable with an event to signal when it's done or an exception is raised."""
try:
await fn
except Exception as e:
# TODO: handle exception
msg = f"Caught exception: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
finally:
# Signal the aiter to stop.
event.set()
from typing import List, Tuple, Dict, Union
class History(BaseModel):

View File

@ -2,7 +2,7 @@ from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from configs.model_config import SQLALCHEMY_DATABASE_URI
from configs import SQLALCHEMY_DATABASE_URI
import json

View File

@ -3,7 +3,7 @@ from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_base_repository import list_kbs_from_db
from configs.model_config import EMBEDDING_MODEL, logger, log_verbose
from configs import EMBEDDING_MODEL, logger, log_verbose
from fastapi import Body

View File

@ -4,9 +4,9 @@ from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
import threading
from configs.model_config import (CACHED_VS_NUM, EMBEDDING_MODEL, CHUNK_SIZE,
embedding_model_dict, logger, log_verbose)
from server.utils import embedding_device
from configs import (EMBEDDING_MODEL, CHUNK_SIZE, CACHED_VS_NUM,
logger, log_verbose)
from server.utils import embedding_device, get_model_path
from contextlib import contextmanager
from collections import OrderedDict
from typing import List, Any, Union, Tuple
@ -22,7 +22,11 @@ class ThreadSafeObject:
def __repr__(self) -> str:
cls = type(self).__name__
return f"<{cls}: key: {self._key}, obj: {self._obj}>"
return f"<{cls}: key: {self.key}, obj: {self._obj}>"
@property
def key(self):
return self._key
@contextmanager
def acquire(self, owner: str = "", msg: str = ""):
@ -30,13 +34,13 @@ class ThreadSafeObject:
try:
self._lock.acquire()
if self._pool is not None:
self._pool._cache.move_to_end(self._key)
self._pool._cache.move_to_end(self.key)
if log_verbose:
logger.info(f"{owner} 开始操作:{self._key}{msg}")
logger.info(f"{owner} 开始操作:{self.key}{msg}")
yield self._obj
finally:
if log_verbose:
logger.info(f"{owner} 结束操作:{self._key}{msg}")
logger.info(f"{owner} 结束操作:{self.key}{msg}")
self._lock.release()
def start_loading(self):
@ -118,15 +122,24 @@ class EmbeddingsPool(CachePool):
with item.acquire(msg="初始化"):
self.atomic.release()
if model == "text-embedding-ada-002": # openai text-embedding-ada-002
embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE)
embeddings = OpenAIEmbeddings(openai_api_key=get_model_path(model), chunk_size=CHUNK_SIZE)
elif 'bge-' in model:
if 'zh' in model:
# for chinese model
query_instruction = "为这个句子生成表示以用于检索相关文章:"
elif 'en' in model:
# for english model
query_instruction = "Represent this sentence for searching relevant passages:"
else:
# maybe ReRanker or else, just use empty string instead
query_instruction = ""
embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model],
model_kwargs={'device': device},
query_instruction="为这个句子生成表示以用于检索相关文章:")
model_kwargs={'device': device},
query_instruction=query_instruction)
if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding
embeddings.query_instruction = ""
else:
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device})
embeddings = HuggingFaceEmbeddings(model_name=get_model_path(model), model_kwargs={'device': device})
item.obj = embeddings
item.finish_loading()
else:

View File

@ -7,7 +7,7 @@ import os
class ThreadSafeFaiss(ThreadSafeObject):
def __repr__(self) -> str:
cls = type(self).__name__
return f"<{cls}: key: {self._key}, obj: {self._obj}, docs_count: {self.docs_count()}>"
return f"<{cls}: key: {self.key}, obj: {self._obj}, docs_count: {self.docs_count()}>"
def docs_count(self) -> int:
return len(self._obj.docstore._dict)
@ -17,7 +17,7 @@ class ThreadSafeFaiss(ThreadSafeObject):
if not os.path.isdir(path) and create_path:
os.makedirs(path)
ret = self._obj.save_local(path)
logger.info(f"已将向量库 {self._key} 保存到磁盘")
logger.info(f"已将向量库 {self.key} 保存到磁盘")
return ret
def clear(self):
@ -27,7 +27,7 @@ class ThreadSafeFaiss(ThreadSafeObject):
if ids:
ret = self._obj.delete(ids)
assert len(self._obj.docstore._dict) == 0
logger.info(f"已将向量库 {self._key} 清空")
logger.info(f"已将向量库 {self.key} 清空")
return ret
@ -58,21 +58,22 @@ class _FaissPool(CachePool):
class KBFaissPool(_FaissPool):
def load_vector_store(
self,
kb_name: str,
create: bool = True,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
self,
kb_name: str,
vector_name: str = "vector_store",
create: bool = True,
embed_model: str = EMBEDDING_MODEL,
embed_device: str = embedding_device(),
) -> ThreadSafeFaiss:
self.atomic.acquire()
cache = self.get(kb_name)
cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
if cache is None:
item = ThreadSafeFaiss(kb_name, pool=self)
self.set(kb_name, item)
item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
self.set((kb_name, vector_name), item)
with item.acquire(msg="初始化"):
self.atomic.release()
logger.info(f"loading vector store in '{kb_name}' from disk.")
vs_path = get_vs_path(kb_name)
logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.")
vs_path = get_vs_path(kb_name, vector_name)
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device)
@ -89,7 +90,7 @@ class KBFaissPool(_FaissPool):
item.finish_loading()
else:
self.atomic.release()
return self.get(kb_name)
return self.get((kb_name, vector_name))
class MemoFaissPool(_FaissPool):
@ -144,7 +145,7 @@ if __name__ == "__main__":
if r == 3: # delete docs
logger.warning(f"清除 {vs_name} by {name}")
kb_faiss_pool.get(vs_name).clear()
threads = []
for n in range(1, 30):
t = threading.Thread(target=worker,
@ -152,6 +153,6 @@ if __name__ == "__main__":
daemon=True)
t.start()
threads.append(t)
for t in threads:
t.join()
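With the cache now keyed by (kb_name, vector_name), loading and reading a store looks roughly like this; the module path and the "samples" knowledge base are assumptions for illustration:

from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool  # assumed module path

# Load (or fetch from cache) the default vector store of knowledge base "samples".
vs = kb_faiss_pool.load_vector_store(kb_name="samples", vector_name="vector_store")

# acquire() yields the underlying FAISS object while holding the thread lock.
with vs.acquire(msg="示例读取") as store:
    print(len(store.docstore._dict), "docs in store")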

View File

@ -1,10 +1,10 @@
import os
import urllib
from fastapi import File, Form, Body, Query, UploadFile
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
logger, log_verbose,)
from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
logger, log_verbose,)
from server.utils import BaseResponse, ListResponse, run_in_thread_pool
from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder,get_file_path,
files2docs_in_thread, KnowledgeFile)
@ -122,10 +122,10 @@ def upload_docs(files: List[UploadFile] = File(..., description="上传文件,
knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
override: bool = Form(False, description="覆盖已有文件"),
to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
docs: Json = Form({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
docs: Json = Form({}, description="自定义的docs需要转为json字符串", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
'''
@ -205,12 +205,12 @@ def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]),
def update_docs(
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=["file_name"]),
file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]]),
chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
docs: Json = Body({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
docs: Json = Body({}, description="自定义的docs需要转为json字符串", examples=[{"test.txt": [Document(page_content="custom doc")]}]),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
'''
@ -323,6 +323,7 @@ def recreate_vector_store(
chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
):
'''
recreate vector store from the content.
@ -366,5 +367,7 @@ def recreate_vector_store(
"msg": msg,
})
i += 1
if not not_refresh_vs_cache:
kb.save_vector_store()
return StreamingResponse(output(), media_type="text/event-stream")
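Because chunk_size/chunk_overlap/zh_title_enhance are now Form fields and docs must be a JSON string, a client upload looks roughly like this; the /knowledge_base/upload_docs path and the local port are assumptions:

import json
import httpx

api_base = "http://127.0.0.1:7861"  # assumed local API server address

files = [("files", ("test.txt", open("test.txt", "rb"), "text/plain"))]
data = {
    "knowledge_base_name": "samples",
    "override": True,
    "to_vector_store": True,
    "chunk_size": 250,          # example values, not necessarily project defaults
    "chunk_overlap": 50,
    "zh_title_enhance": False,
    "docs": json.dumps({}),     # custom docs must be sent as a JSON string
    "not_refresh_vs_cache": False,
}

r = httpx.post(f"{api_base}/knowledge_base/upload_docs", data=data, files=files)
print(r.json())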

View File

@ -18,8 +18,8 @@ from server.db.repository.knowledge_file_repository import (
list_docs_from_db,
)
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
EMBEDDING_MODEL)
from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
EMBEDDING_MODEL)
from server.knowledge_base.utils import (
get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
list_kbs_from_folder, list_files_from_folder,

View File

@ -1,7 +1,7 @@
import os
import shutil
from configs.model_config import (
from configs import (
KB_ROOT_PATH,
SCORE_THRESHOLD,
logger, log_verbose,
@ -18,18 +18,21 @@ from server.utils import torch_gc
class FaissKBService(KBService):
vs_path: str
kb_path: str
vector_name: str = "vector_store"
def vs_type(self) -> str:
return SupportedVSType.FAISS
def get_vs_path(self):
return os.path.join(self.get_kb_path(), "vector_store")
return os.path.join(self.get_kb_path(), self.vector_name)
def get_kb_path(self):
return os.path.join(KB_ROOT_PATH, self.kb_name)
def load_vector_store(self) -> ThreadSafeFaiss:
return kb_faiss_pool.load_vector_store(kb_name=self.kb_name, embed_model=self.embed_model)
return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
vector_name=self.vector_name,
embed_model=self.embed_model)
def save_vector_store(self):
self.load_vector_store().save(self.vs_path)

View File

@ -7,7 +7,7 @@ from langchain.schema import Document
from langchain.vectorstores import Milvus
from sklearn.preprocessing import normalize
from configs.model_config import SCORE_THRESHOLD, kbs_config
from configs import SCORE_THRESHOLD, kbs_config
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
score_threshold_process
@ -22,9 +22,9 @@ class MilvusKBService(KBService):
from pymilvus import Collection
return Collection(milvus_name)
def save_vector_store(self):
if self.milvus.col:
self.milvus.col.flush()
# def save_vector_store(self):
# if self.milvus.col:
# self.milvus.col.flush()
def get_doc_by_id(self, id: str) -> Optional[Document]:
if self.milvus.col:

View File

@ -7,7 +7,7 @@ from langchain.vectorstores import PGVector
from langchain.vectorstores.pgvector import DistanceStrategy
from sqlalchemy import text
from configs.model_config import EMBEDDING_DEVICE, kbs_config
from configs import kbs_config
from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
score_threshold_process

View File

@ -1,9 +1,10 @@
from configs.model_config import (EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
logger, log_verbose)
from configs import (EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
CHUNK_SIZE, OVERLAP_SIZE,
logger, log_verbose)
from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
list_files_from_folder,files2docs_in_thread,
KnowledgeFile,)
from server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_file_repository import add_file_to_db
from server.db.base import Base, engine
import os
@ -33,33 +34,23 @@ def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]:
def folder2db(
kb_name: str,
mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"],
kb_names: List[str],
mode: Literal["recreate_vs", "update_in_db", "increament"],
vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
embed_model: str = EMBEDDING_MODEL,
chunk_size: int = -1,
chunk_overlap: int = -1,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
zh_title_enhance: bool = ZH_TITLE_ENHANCE,
):
'''
use existing files in the local folder to populate the database and/or the vector store.
set parameter `mode` to:
recreate_vs: recreate all vector stores and fill info into the database using existing files in the local folder
fill_info_only: do not create vector store, fill info into db using existing files only
fill_info_only(disabled): do not create vector store, fill info into db using existing files only
update_in_db: update the vector store and database info using only local files that already exist in the database
increament: create vector store and database info only for local files that do not yet exist in the database
'''
kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model)
kb.create_kb()
if mode == "recreate_vs":
files_count = kb.count_files()
print(f"知识库 {kb_name} 中共有 {files_count} 个文档。\n即将清除向量库。")
kb.clear_vs()
files_count = kb.count_files()
print(f"清理后,知识库 {kb_name} 中共有 {files_count} 个文档。")
kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name))
def files2vs(kb_name: str, kb_files: List[KnowledgeFile]):
for success, result in files2docs_in_thread(kb_files,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
@ -68,84 +59,77 @@ def folder2db(
_, filename, docs = result
print(f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档")
kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
kb_file.splited_docs = docs
kb.add_doc(kb_file=kb_file, not_refresh_vs_cache=True)
else:
print(result)
kb.save_vector_store()
elif mode == "fill_info_only":
files = list_files_from_folder(kb_name)
kb_files = file_to_kbfile(kb_name, files)
for kb_file in kb_files:
add_file_to_db(kb_file)
print(f"已将 {kb_name}/{kb_file.filename} 添加到数据库")
elif mode == "update_in_db":
files = kb.list_files()
kb_files = file_to_kbfile(kb_name, files)
kb_names = kb_names or list_kbs_from_folder()
for kb_name in kb_names:
kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model)
kb.create_kb()
for kb_file in kb_files:
kb.update_doc(kb_file, not_refresh_vs_cache=True)
kb.save_vector_store()
elif mode == "increament":
db_files = kb.list_files()
folder_files = list_files_from_folder(kb_name)
files = list(set(folder_files) - set(db_files))
kb_files = file_to_kbfile(kb_name, files)
for success, result in files2docs_in_thread(kb_files,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
zh_title_enhance=zh_title_enhance):
if success:
_, filename, docs = result
print(f"正在将 {kb_name}/{filename} 添加到向量库")
kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
kb.add_doc(kb_file=kb_file, docs=docs, not_refresh_vs_cache=True)
else:
print(result)
kb.save_vector_store()
else:
print(f"unsupported migrate mode: {mode}")
# 清除向量库,从本地文件重建
if mode == "recreate_vs":
kb.clear_vs()
kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name))
files2vs(kb_name, kb_files)
kb.save_vector_store()
# # 不做文件内容的向量化,仅将文件元信息存到数据库
# # 由于现在数据库存了很多与文本切分相关的信息,单纯存储文件信息意义不大,该功能取消。
# elif mode == "fill_info_only":
# files = list_files_from_folder(kb_name)
# kb_files = file_to_kbfile(kb_name, files)
# for kb_file in kb_files:
# add_file_to_db(kb_file)
# print(f"已将 {kb_name}/{kb_file.filename} 添加到数据库")
# 以数据库中文件列表为基准,利用本地文件更新向量库
elif mode == "update_in_db":
files = kb.list_files()
kb_files = file_to_kbfile(kb_name, files)
files2vs(kb_name, kb_files)
kb.save_vector_store()
# 对比本地目录与数据库中的文件列表,进行增量向量化
elif mode == "increament":
db_files = kb.list_files()
folder_files = list_files_from_folder(kb_name)
files = list(set(folder_files) - set(db_files))
kb_files = file_to_kbfile(kb_name, files)
files2vs(kb_name, kb_files)
kb.save_vector_store()
else:
print(f"unsupported migrate mode: {mode}")
def recreate_all_vs(
vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
embed_mode: str = EMBEDDING_MODEL,
**kwargs: Any,
):
def prune_db_docs(kb_names: List[str]):
'''
used to recreate a vector store or change current vector store to another type or embed_model
delete docs in the database that no longer exist in the local folder.
it is used to delete database docs after the user deleted some doc files in the file browser
'''
for kb_name in list_kbs_from_folder():
folder2db(kb_name, "recreate_vs", vs_type, embed_mode, **kwargs)
for kb_name in kb_names:
kb = KBServiceFactory.get_service_by_name(kb_name)
if kb and kb.exists():
files_in_db = kb.list_files()
files_in_folder = list_files_from_folder(kb_name)
files = list(set(files_in_db) - set(files_in_folder))
kb_files = file_to_kbfile(kb_name, files)
for kb_file in kb_files:
kb.delete_doc(kb_file, not_refresh_vs_cache=True)
print(f"success to delete docs for file: {kb_name}/{kb_file.filename}")
kb.save_vector_store()
def prune_db_files(kb_name: str):
'''
delete files in the database that no longer exist in the local folder.
it is used to delete database files after the user deleted some doc files in the file browser
'''
kb = KBServiceFactory.get_service_by_name(kb_name)
if kb.exists():
files_in_db = kb.list_files()
files_in_folder = list_files_from_folder(kb_name)
files = list(set(files_in_db) - set(files_in_folder))
kb_files = file_to_kbfile(kb_name, files)
for kb_file in kb_files:
kb.delete_doc(kb_file, not_refresh_vs_cache=True)
kb.save_vector_store()
return kb_files
def prune_folder_files(kb_name: str):
def prune_folder_files(kb_names: List[str]):
'''
delete doc files in the local folder that no longer exist in the database.
it is used to free local disk space by deleting unused doc files.
'''
kb = KBServiceFactory.get_service_by_name(kb_name)
if kb.exists():
files_in_db = kb.list_files()
files_in_folder = list_files_from_folder(kb_name)
files = list(set(files_in_folder) - set(files_in_db))
for file in files:
os.remove(get_file_path(kb_name, file))
return files
for kb_name in kb_names:
kb = KBServiceFactory.get_service_by_name(kb_name)
if kb and kb.exists():
files_in_db = kb.list_files()
files_in_folder = list_files_from_folder(kb_name)
files = list(set(files_in_folder) - set(files_in_db))
for file in files:
os.remove(get_file_path(kb_name, file))
print(f"success to delete file: {kb_name}/{file}")
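With the new list-based signatures, the migration helpers are driven roughly like this (module path and the "samples" knowledge base assumed):

from server.knowledge_base.migrate import folder2db, prune_db_docs, prune_folder_files  # assumed module path

# Rebuild the vector store of one knowledge base from its local files.
folder2db(kb_names=["samples"], mode="recreate_vs")

# Drop database/vector-store records for files that were deleted locally.
prune_db_docs(["samples"])

# Delete local files that are no longer tracked in the database.
prune_folder_files(["samples"])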

View File

@ -2,18 +2,17 @@ import os
from transformers import AutoTokenizer
from configs.model_config import (
from configs import (
EMBEDDING_MODEL,
KB_ROOT_PATH,
CHUNK_SIZE,
OVERLAP_SIZE,
ZH_TITLE_ENHANCE,
logger,
log_verbose,
text_splitter_dict,
llm_model_dict,
LLM_MODEL,
TEXT_SPLITTER
logger,
log_verbose,
text_splitter_dict,
LLM_MODEL,
TEXT_SPLITTER_NAME,
)
import importlib
from text_splitter import zh_title_enhance as func_zh_title_enhance
@ -23,7 +22,7 @@ from langchain.text_splitter import TextSplitter
from pathlib import Path
import json
from concurrent.futures import ThreadPoolExecutor
from server.utils import run_in_thread_pool, embedding_device
from server.utils import run_in_thread_pool, embedding_device, get_model_worker_config
import io
from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
import chardet
@ -44,8 +43,8 @@ def get_doc_path(knowledge_base_name: str):
return os.path.join(get_kb_path(knowledge_base_name), "content")
def get_vs_path(knowledge_base_name: str):
return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
def get_vs_path(knowledge_base_name: str, vector_name: str):
return os.path.join(get_kb_path(knowledge_base_name), vector_name)
def get_file_path(knowledge_base_name: str, doc_name: str):
@ -190,9 +189,10 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri
def make_text_splitter(
splitter_name: str = TEXT_SPLITTER,
splitter_name: str = TEXT_SPLITTER_NAME,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
llm_model: str = LLM_MODEL,
):
"""
根据参数获取特定的分词器
@ -228,8 +228,9 @@ def make_text_splitter(
)
elif text_splitter_dict[splitter_name]["source"] == "huggingface": ## 从huggingface加载
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
config = get_model_worker_config(llm_model)
text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
llm_model_dict[LLM_MODEL]["local_model_path"]
config.get("model_path")
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
from transformers import GPT2TokenizerFast
@ -281,7 +282,7 @@ class KnowledgeFile:
self.docs = None
self.splited_docs = None
self.document_loader_name = get_LoaderClass(self.ext)
self.text_splitter_name = TEXT_SPLITTER
self.text_splitter_name = TEXT_SPLITTER_NAME
def file2docs(self, refresh: bool=False):
if self.docs is None or refresh:
@ -372,18 +373,23 @@ def files2docs_in_thread(
kwargs_list = []
for i, file in enumerate(files):
kwargs = {}
if isinstance(file, tuple) and len(file) >= 2:
file = KnowledgeFile(filename=file[0], knowledge_base_name=file[1])
elif isinstance(file, dict):
filename = file.pop("filename")
kb_name = file.pop("kb_name")
kwargs = file
file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
kwargs["file"] = file
kwargs["chunk_size"] = chunk_size
kwargs["chunk_overlap"] = chunk_overlap
kwargs["zh_title_enhance"] = zh_title_enhance
kwargs_list.append(kwargs)
try:
if isinstance(file, tuple) and len(file) >= 2:
filename=file[0]
kb_name=file[1]
file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
elif isinstance(file, dict):
filename = file.pop("filename")
kb_name = file.pop("kb_name")
kwargs.update(file)
file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
kwargs["file"] = file
kwargs["chunk_size"] = chunk_size
kwargs["chunk_overlap"] = chunk_overlap
kwargs["zh_title_enhance"] = zh_title_enhance
kwargs_list.append(kwargs)
except Exception as e:
yield False, (kb_name, filename, str(e))
for result in run_in_thread_pool(func=file2docs, params=kwargs_list, pool=pool):
yield result
@ -398,4 +404,4 @@ if __name__ == "__main__":
pprint(docs[-1])
docs = kb_file.file2text()
pprint(docs[-1])
pprint(docs[-1])

View File

@ -1,10 +1,10 @@
from fastapi import Body
from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
from server.utils import BaseResponse, fschat_controller_address
import httpx
from server.utils import BaseResponse, fschat_controller_address, list_llm_models, get_httpx_client
def list_llm_models(
def list_running_models(
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]),
placeholder: str = Body(None, description="该参数未使用,占位用"),
) -> BaseResponse:
@ -13,8 +13,9 @@ def list_llm_models(
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(controller_address + "/list_models")
return BaseResponse(data=r.json()["models"])
with get_httpx_client() as client:
r = client.post(controller_address + "/list_models")
return BaseResponse(data=r.json()["models"])
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
@ -24,6 +25,13 @@ def list_llm_models(
msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
def list_config_models() -> BaseResponse:
'''
从本地获取configs中配置的模型列表
'''
return BaseResponse(data=list_llm_models())
def stop_llm_model(
model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
@ -34,11 +42,12 @@ def stop_llm_model(
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(
controller_address + "/release_worker",
json={"model_name": model_name},
)
return r.json()
with get_httpx_client() as client:
r = client.post(
controller_address + "/release_worker",
json={"model_name": model_name},
)
return r.json()
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)
@ -57,12 +66,13 @@ def change_llm_model(
'''
try:
controller_address = controller_address or fschat_controller_address()
r = httpx.post(
controller_address + "/release_worker",
json={"model_name": model_name, "new_model_name": new_model_name},
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
)
return r.json()
with get_httpx_client() as client:
r = client.post(
controller_address + "/release_worker",
json={"model_name": model_name, "new_model_name": new_model_name},
timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model
)
return r.json()
except Exception as e:
logger.error(f'{e.__class__.__name__}: {e}',
exc_info=e if log_verbose else None)

View File

@ -2,3 +2,5 @@ from .zhipu import ChatGLMWorker
from .minimax import MiniMaxWorker
from .xinghuo import XingHuoWorker
from .qianfan import QianFanWorker
from .fangzhou import FangZhouWorker
from .qwen import QwenWorker

View File

@ -1,4 +1,4 @@
from configs.model_config import LOG_PATH
from configs.basic_config import LOG_PATH
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.model_worker import BaseModelWorker
@ -92,5 +92,5 @@ class ApiModelWorker(BaseModelWorker):
if content := msg[len(ai_start):].strip():
result.append({"role": ai_role, "content": content})
else:
raise RuntimeError(f"unknow role in msg: {msg}")
raise RuntimeError(f"unknown role in msg: {msg}")
return result

View File

@ -0,0 +1,122 @@
from server.model_workers.base import ApiModelWorker
from configs.model_config import TEMPERATURE
from fastchat import conversation as conv
import sys
import json
from pprint import pprint
from server.utils import get_model_worker_config
from typing import List, Literal, Dict
def request_volc_api(
messages: List[Dict],
model_name: str = "fangzhou-api",
version: str = "chatglm-6b-model",
temperature: float = TEMPERATURE,
api_key: str = None,
secret_key: str = None,
):
from volcengine.maas import MaasService, MaasException, ChatRole
maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
config = get_model_worker_config(model_name)
version = version or config.get("version")
version_url = config.get("version_url")
api_key = api_key or config.get("api_key")
secret_key = secret_key or config.get("secret_key")
maas.set_ak(api_key)
maas.set_sk(secret_key)
# document: "https://www.volcengine.com/docs/82379/1099475"
req = {
"model": {
"name": version,
},
"parameters": {
# 这里的参数仅为示例,具体可用的参数请参考具体模型的 API 说明
"max_new_tokens": 1000,
"temperature": temperature,
},
"messages": messages,
}
try:
resps = maas.stream_chat(req)
for resp in resps:
yield resp
except MaasException as e:
print(e)
class FangZhouWorker(ApiModelWorker):
"""
火山方舟
"""
SUPPORT_MODELS = ["chatglm-6b-model"]
def __init__(
self,
*,
version: Literal["chatglm-6b-model"] = "chatglm-6b-model",
model_names: List[str] = ["fangzhou-api"],
controller_addr: str,
worker_addr: str,
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384) # TODO: 不同的模型有不同的大小
super().__init__(**kwargs)
config = self.get_config()
self.version = version
self.api_key = config.get("api_key")
self.secret_key = config.get("secret_key")
self.conv = conv.Conversation(
name=self.model_names[0],
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
messages=[],
roles=["user", "assistant", "system"],
sep="\n### ",
stop_str="###",
)
def generate_stream_gate(self, params):
super().generate_stream_gate(params)
messages = self.prompt_to_messages(params["prompt"])
text = ""
for resp in request_volc_api(messages=messages,
model_name=self.model_names[0],
version=self.version,
temperature=params.get("temperature", TEMPERATURE),
):
error = resp.error
if error.code_n > 0:
data = {"error_code": error.code_n, "text": error.message}
elif chunk := resp.choice.message.content:
text += chunk
data = {"error_code": 0, "text": text}
yield json.dumps(data, ensure_ascii=False).encode() + b"\0"
def get_embeddings(self, params):
# TODO: 支持embeddings
print("embedding")
print(params)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = FangZhouWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:21005",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=21005)
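A direct-call sketch, assuming the volcengine SDK is installed and api_key/secret_key are configured for the fangzhou-api worker:

from server.model_workers.fangzhou import request_volc_api

messages = [{"role": "user", "content": "你好"}]

# Streams MaaS responses; keys are read from the fangzhou-api worker config when not passed in.
for resp in request_volc_api(messages=messages, model_name="fangzhou-api"):
    if resp.error.code_n > 0:
        print("error:", resp.error.message)
    else:
        print(resp.choice.message.content, end="")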

View File

@ -2,7 +2,7 @@ from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv
import sys
import json
import httpx
from server.utils import get_httpx_client
from pprint import pprint
from typing import List, Dict
@ -63,22 +63,23 @@ class MiniMaxWorker(ApiModelWorker):
}
print("request data sent to minimax:")
pprint(data)
response = httpx.stream("POST",
self.BASE_URL.format(pro=pro, group_id=group_id),
headers=headers,
json=data)
with response as r:
text = ""
for e in r.iter_text():
if e.startswith("data: "): # 真是优秀的返回
data = json.loads(e[6:])
if not data.get("usage"):
if choices := data.get("choices"):
chunk = choices[0].get("delta", "").strip()
if chunk:
print(chunk)
text += chunk
yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0"
with get_httpx_client() as client:
response = client.stream("POST",
self.BASE_URL.format(pro=pro, group_id=group_id),
headers=headers,
json=data)
with response as r:
text = ""
for e in r.iter_text():
if e.startswith("data: "): # 真是优秀的返回
data = json.loads(e[6:])
if not data.get("usage"):
if choices := data.get("choices"):
chunk = choices[0].get("delta", "").strip()
if chunk:
print(chunk)
text += chunk
yield json.dumps({"error_code": 0, "text": text}, ensure_ascii=False).encode() + b"\0"
def get_embeddings(self, params):
# TODO: 支持embeddings
@ -93,8 +94,8 @@ if __name__ == "__main__":
worker = MiniMaxWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:20004",
worker_addr="http://127.0.0.1:21002",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=20003)
uvicorn.run(app, port=21002)

View File

@ -5,7 +5,7 @@ import sys
import json
import httpx
from cachetools import cached, TTLCache
from server.utils import get_model_worker_config
from server.utils import get_model_worker_config, get_httpx_client
from typing import List, Literal, Dict
@ -54,7 +54,8 @@ def get_baidu_access_token(api_key: str, secret_key: str) -> str:
url = "https://aip.baidubce.com/oauth/2.0/token"
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
try:
return httpx.get(url, params=params).json().get("access_token")
with get_httpx_client() as client:
return client.get(url, params=params).json().get("access_token")
except Exception as e:
print(f"failed to get token from baidu: {e}")
@ -72,7 +73,10 @@ def request_qianfan_api(
version_url = config.get("version_url")
access_token = get_baidu_access_token(config.get("api_key"), config.get("secret_key"))
if not access_token:
raise RuntimeError(f"failed to get access token. have you set the correct api_key and secret key?")
yield {
"error_code": 403,
"error_msg": "failed to get access token. have you set the correct api_key and secret key?",
}
return
url = BASE_URL.format(
model_version=version_url or MODEL_VERSIONS[version],
@ -88,14 +92,15 @@ def request_qianfan_api(
'Accept': 'application/json',
}
with httpx.stream("POST", url, headers=headers, json=payload) as response:
for line in response.iter_lines():
if not line.strip():
continue
if line.startswith("data: "):
line = line[6:]
resp = json.loads(line)
yield resp
with get_httpx_client() as client:
with client.stream("POST", url, headers=headers, json=payload) as response:
for line in response.iter_lines():
if not line.strip():
continue
if line.startswith("data: "):
line = line[6:]
resp = json.loads(line)
yield resp
class QianFanWorker(ApiModelWorker):
@ -165,8 +170,8 @@ if __name__ == "__main__":
worker = QianFanWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:20006",
worker_addr="http://127.0.0.1:21004"
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=20006)
uvicorn.run(app, port=21004)

View File

@ -0,0 +1,123 @@
import json
import sys
from configs import TEMPERATURE
from http import HTTPStatus
from typing import List, Literal, Dict
from fastchat import conversation as conv
from server.model_workers.base import ApiModelWorker
from server.utils import get_model_worker_config
def request_qwen_api(
messages: List[Dict[str, str]],
api_key: str = None,
version: str = "qwen-turbo",
temperature: float = TEMPERATURE,
model_name: str = "qwen-api",
):
import dashscope
config = get_model_worker_config(model_name)
api_key = api_key or config.get("api_key")
version = version or config.get("version")
gen = dashscope.Generation()
responses = gen.call(
model=version,
temperature=temperature,
api_key=api_key,
messages=messages,
result_format='message', # set the result is message format.
stream=True,
)
text = ""
for resp in responses:
if resp.status_code != HTTPStatus.OK:
yield {
"code": resp.status_code,
"text": "api not response correctly",
}
if resp["status_code"] == 200:
if choices := resp["output"]["choices"]:
yield {
"code": 200,
"text": choices[0]["message"]["content"],
}
else:
yield {
"code": resp["status_code"],
"text": resp["message"],
}
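A minimal usage sketch for request_qwen_api as defined above (illustrative only; it assumes a valid DashScope api_key is configured for "qwen-api" in the model configs, and each yielded text may be cumulative depending on the SDK's streaming behaviour):
# sketch: consume the streaming generator and print each yielded chunk
from server.model_workers.qwen import request_qwen_api

if __name__ == "__main__":
    messages = [{"role": "user", "content": "你好"}]
    for resp in request_qwen_api(messages, version="qwen-turbo", temperature=0.7):
        if resp["code"] == 200:
            print(resp["text"])
        else:
            print(f"error {resp['code']}: {resp['text']}")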
class QwenWorker(ApiModelWorker):
def __init__(
self,
*,
version: Literal["qwen-turbo", "qwen-plus"] = "qwen-turbo",
model_names: List[str] = ["qwen-api"],
controller_addr: str,
worker_addr: str,
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 16384)
super().__init__(**kwargs)
# TODO: 确认模板是否需要修改
self.conv = conv.Conversation(
name=self.model_names[0],
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
messages=[],
roles=["user", "assistant", "system"],
sep="\n### ",
stop_str="###",
)
config = self.get_config()
self.api_key = config.get("api_key")
self.version = version
def generate_stream_gate(self, params):
messages = self.prompt_to_messages(params["prompt"])
for resp in request_qwen_api(messages=messages,
api_key=self.api_key,
version=self.version,
temperature=params.get("temperature")):
if resp["code"] == 200:
yield json.dumps({
"error_code": 0,
"text": resp["text"]
},
ensure_ascii=False
).encode() + b"\0"
else:
yield json.dumps({
"error_code": resp["code"],
"text": resp["text"]
},
ensure_ascii=False
).encode() + b"\0"
def get_embeddings(self, params):
# TODO: 支持embeddings
print("embedding")
print(params)
if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app
worker = QwenWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:20007",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=20007)

View File

@ -94,8 +94,8 @@ if __name__ == "__main__":
worker = XingHuoWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:20005",
worker_addr="http://127.0.0.1:21003",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=20005)
uvicorn.run(app, port=21003)

View File

@ -67,8 +67,8 @@ if __name__ == "__main__":
worker = ChatGLMWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:20003",
worker_addr="http://127.0.0.1:21001",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=20003)
uvicorn.run(app, port=21001)

View File

@ -4,16 +4,57 @@ from typing import List
from fastapi import FastAPI
from pathlib import Path
import asyncio
from configs.model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE, EMBEDDING_DEVICE, logger, log_verbose
from configs.server_config import FSCHAT_MODEL_WORKERS
from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL,
logger, log_verbose,
FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Literal, Optional, Callable, Generator, Dict, Any
from langchain.chat_models import ChatOpenAI
import httpx
from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
thread_pool = ThreadPoolExecutor(os.cpu_count())
async def wrap_done(fn: Awaitable, event: asyncio.Event):
"""Wrap an awaitable with a event to signal when it's done or an exception is raised."""
try:
await fn
except Exception as e:
# TODO: handle exception
msg = f"Caught exception: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
finally:
# Signal the aiter to stop.
event.set()
def get_ChatOpenAI(
model_name: str,
temperature: float,
streaming: bool = True,
callbacks: List[Callable] = [],
verbose: bool = True,
**kwargs: Any,
) -> ChatOpenAI:
config = get_model_worker_config(model_name)
model = ChatOpenAI(
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
openai_api_key=config.get("api_key", "EMPTY"),
openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
model_name=model_name,
temperature=temperature,
openai_proxy=config.get("openai_proxy"),
**kwargs
)
return model
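A minimal sketch of calling get_ChatOpenAI (the model name and temperature below are assumptions; it presumes the fastchat OpenAI-compatible endpoint or the configured api_base_url is reachable):
from server.utils import get_ChatOpenAI

llm = get_ChatOpenAI(model_name="chatglm2-6b", temperature=0.7, streaming=False)
# ChatOpenAI is a regular langchain chat model, so predict() works for a quick check
print(llm.predict("你好"))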
class BaseResponse(BaseModel):
code: int = pydantic.Field(200, description="API status code")
msg: str = pydantic.Field("success", description="API status message")
@ -197,22 +238,71 @@ def MakeFastAPIOffline(
)
# 从model_config中获取模型信息
def list_embed_models() -> List[str]:
'''
get names of configured embedding models
'''
return list(MODEL_PATH["embed_model"])
def list_llm_models() -> Dict[str, List[str]]:
'''
get names of configured llm models with different types.
return {config_type: [model_name, ...], ...}
'''
workers = list(FSCHAT_MODEL_WORKERS)
if "default" in workers:
workers.remove("default")
return {
"local": list(MODEL_PATH["llm_model"]),
"online": list(ONLINE_LLM_MODEL),
"worker": workers,
}
def get_model_path(model_name: str, type: str = None) -> Optional[str]:
if type in MODEL_PATH:
paths = MODEL_PATH[type]
else:
paths = {}
for v in MODEL_PATH.values():
paths.update(v)
if path_str := paths.get(model_name): # 以 "chatglm-6b": "THUDM/chatglm-6b-new" 为例,以下都是支持的路径
path = Path(path_str)
if path.is_dir(): # 任意绝对路径
return str(path)
root_path = Path(MODEL_ROOT_PATH)
if root_path.is_dir():
path = root_path / model_name
if path.is_dir(): # use key, {MODEL_ROOT_PATH}/chatglm-6b
return str(path)
path = root_path / path_str
if path.is_dir(): # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new
return str(path)
path = root_path / path_str.split("/")[-1]
if path.is_dir(): # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new
return str(path)
return path_str # THUDM/chatglm-6b-new
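An illustrative walk-through of the lookup order above, assuming MODEL_PATH["llm_model"] maps "chatglm-6b" to "THUDM/chatglm-6b-new":
from server.utils import get_model_path

# resolution order (first existing directory wins):
# 1. the configured value itself, if it is an absolute directory
# 2. {MODEL_ROOT_PATH}/chatglm-6b              (the key)
# 3. {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new    (the value)
# 4. {MODEL_ROOT_PATH}/chatglm-6b-new          (last segment of the value)
# 5. otherwise the raw value is returned for huggingface to download
print(get_model_path("chatglm-6b", type="llm_model"))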
# 从server_config中获取服务信息
def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
def get_model_worker_config(model_name: str = None) -> dict:
'''
加载model worker的配置项
优先级:FSCHAT_MODEL_WORKERS[model_name] > llm_model_dict[model_name] > FSCHAT_MODEL_WORKERS["default"]
优先级:FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"]
'''
from configs.model_config import ONLINE_LLM_MODEL
from configs.server_config import FSCHAT_MODEL_WORKERS
from server import model_workers
from configs.model_config import llm_model_dict
config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
config.update(llm_model_dict.get(model_name, {}))
config.update(ONLINE_LLM_MODEL.get(model_name, {}))
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
# 如果没有设置local_model_path则认为是在线模型API
if not config.get("local_model_path"):
# 在线模型API
if model_name in ONLINE_LLM_MODEL:
config["online_api"] = True
if provider := config.get("provider"):
try:
@ -222,13 +312,14 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
config["device"] = llm_device(config.get("device") or LLM_DEVICE)
config["model_path"] = get_model_path(model_name)
config["device"] = llm_device(config.get("device"))
return config
def get_all_model_worker_configs() -> dict:
result = {}
model_names = set(llm_model_dict.keys()) | set(FSCHAT_MODEL_WORKERS.keys())
model_names = set(FSCHAT_MODEL_WORKERS.keys())
for name in model_names:
if name != "default":
result[name] = get_model_worker_config(name)
@ -256,7 +347,7 @@ def fschat_openai_api_address() -> str:
host = FSCHAT_OPENAI_API["host"]
port = FSCHAT_OPENAI_API["port"]
return f"http://{host}:{port}"
return f"http://{host}:{port}/v1"
def api_address() -> str:
@ -275,19 +366,74 @@ def webui_address() -> str:
return f"http://{host}:{port}"
def set_httpx_timeout(timeout: float = None):
def get_prompt_template(name: str) -> Optional[str]:
'''
设置httpx默认timeout
httpx默认timeout是5秒,在请求LLM回答时不够用
从prompt_config中加载模板内容
'''
from configs import prompt_config
import importlib
importlib.reload(prompt_config) # TODO: 检查configs/prompt_config.py文件有修改再重新加载
return prompt_config.PROMPT_TEMPLATES.get(name)
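A usage sketch, assuming configs/prompt_config.py defines a "llm_chat" entry in PROMPT_TEMPLATES:
from server.utils import get_prompt_template

template = get_prompt_template("llm_chat")
if template is not None:
    # prompt_config is re-imported on every call, so edits to the file take effect immediately
    print(template)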
def set_httpx_config(
timeout: float = HTTPX_DEFAULT_TIMEOUT,
proxy: Union[str, Dict] = None,
):
'''
设置httpx默认timeout。httpx默认timeout是5秒,在请求LLM回答时不够用。
将本项目相关服务加入无代理列表,避免fastchat的服务器请求错误。(windows下无效)
对于chatgpt等在线API,如要使用代理需要手动配置;搜索引擎的代理如何处置还需考虑。
'''
import httpx
from configs.server_config import HTTPX_DEFAULT_TIMEOUT
import os
timeout = timeout or HTTPX_DEFAULT_TIMEOUT
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
# 在进程范围内设置系统级代理
proxies = {}
if isinstance(proxy, str):
for n in ["http", "https", "all"]:
proxies[n + "_proxy"] = proxy
elif isinstance(proxy, dict):
for n in ["http", "https", "all"]:
if p:= proxy.get(n):
proxies[n + "_proxy"] = p
elif p:= proxy.get(n + "_proxy"):
proxies[n + "_proxy"] = p
for k, v in proxies.items():
os.environ[k] = v
# set host to bypass proxy
no_proxy = [x.strip() for x in os.environ.get("no_proxy", "").split(",") if x.strip()]
no_proxy += [
# do not use proxy for localhost
"http://127.0.0.1",
"http://localhost",
]
# do not use proxy for user deployed fastchat servers
for x in [
fschat_controller_address(),
fschat_model_worker_address(),
fschat_openai_api_address(),
]:
host = ":".join(x.split(":")[:2])
if host not in no_proxy:
no_proxy.append(host)
os.environ["NO_PROXY"] = ",".join(no_proxy)
# TODO: 简单的清除系统代理不是个好的选择,影响太多。似乎修改代理服务器的bypass列表更好。
# patch requests to use custom proxies instead of system settings
# def _get_proxies():
# return {}
# import urllib.request
# urllib.request.getproxies = _get_proxies
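A usage sketch for set_httpx_config; the timeout and proxy values are assumptions, not project defaults:
from server.utils import set_httpx_config

# lengthen the global httpx timeout and route outbound requests through a local proxy;
# the fastchat controller/worker/openai-api addresses are added to NO_PROXY automatically
set_httpx_config(timeout=300.0, proxy="http://127.0.0.1:7890")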
# 自动检查torch可用的设备。分布式部署时不运行LLM的机器上可以不装torch
def detect_device() -> Literal["cuda", "mps", "cpu"]:
@ -302,13 +448,15 @@ def detect_device() -> Literal["cuda", "mps", "cpu"]:
return "cpu"
def llm_device(device: str = LLM_DEVICE) -> Literal["cuda", "mps", "cpu"]:
def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
device = device or LLM_DEVICE
if device not in ["cuda", "mps", "cpu"]:
device = detect_device()
return device
def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]:
def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
device = device or EMBEDDING_DEVICE
if device not in ["cuda", "mps", "cpu"]:
device = detect_device()
return device
@ -333,3 +481,51 @@ def run_in_thread_pool(
for obj in as_completed(tasks):
yield obj.result()
def get_httpx_client(
use_async: bool = False,
proxies: Union[str, Dict] = None,
timeout: float = HTTPX_DEFAULT_TIMEOUT,
**kwargs,
) -> Union[httpx.Client, httpx.AsyncClient]:
'''
helper to get an httpx client with default proxies that bypass local addresses.
'''
default_proxies = {
# do not use proxy for localhost
"all://127.0.0.1": None,
"all://localhost": None,
}
# do not use proxy for user deployed fastchat servers
for x in [
fschat_controller_address(),
fschat_model_worker_address(),
fschat_openai_api_address(),
]:
host = ":".join(x.split(":")[:2])
default_proxies.update({host: None})
# get proxies from system environment
default_proxies.update({
"http://": os.environ.get("http_proxy"),
"https://": os.environ.get("https_proxy"),
"all://": os.environ.get("all_proxy"),
})
for host in os.environ.get("no_proxy", "").split(","):
if host := host.strip():
default_proxies.update({host: None})
# merge default proxies with user provided proxies
if isinstance(proxies, str):
proxies = {"all://": proxies}
if isinstance(proxies, dict):
default_proxies.update(proxies)
# construct Client
kwargs.update(timeout=timeout, proxies=default_proxies)
if use_async:
return httpx.AsyncClient(**kwargs)
else:
return httpx.Client(**kwargs)
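A minimal sketch of the helper above (the target URL is illustrative):
from server.utils import get_httpx_client

with get_httpx_client(timeout=60.0) as client:
    # system proxies are honoured, but local fastchat services are bypassed
    r = client.get("https://example.com/api")
    print(r.status_code)

# async variant:
# async with get_httpx_client(use_async=True) as client:
#     r = await client.get("https://example.com/api")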

View File

@ -17,12 +17,21 @@ except:
pass
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
logger, log_verbose, TEXT_SPLITTER
from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
FSCHAT_OPENAI_API, HTTPX_DEFAULT_TIMEOUT)
from configs import (
LOG_PATH,
log_verbose,
logger,
LLM_MODEL,
EMBEDDING_MODEL,
TEXT_SPLITTER_NAME,
FSCHAT_CONTROLLER,
FSCHAT_OPENAI_API,
API_SERVER,
WEBUI_SERVER,
HTTPX_DEFAULT_TIMEOUT,
)
from server.utils import (fschat_controller_address, fschat_model_worker_address,
fschat_openai_api_address, set_httpx_timeout,
fschat_openai_api_address, set_httpx_config, get_httpx_client,
get_model_worker_config, get_all_model_worker_configs,
MakeFastAPIOffline, FastAPI, llm_device, embedding_device)
import argparse
@ -49,112 +58,162 @@ def create_controller_app(
def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
"""
kwargs包含的字段如下:
host:
port:
model_names:[`model_name`]
controller_address:
worker_address:
对于online_api:
online_api:True
worker_class: `provider`
对于离线模型:
model_path: `model_name_or_path`,huggingface的repo-id或本地路径
device:`LLM_DEVICE`
"""
import fastchat.constants
fastchat.constants.LOGDIR = LOG_PATH
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger
from fastchat.serve.model_worker import worker_id, logger
import argparse
import threading
import fastchat.serve.model_worker
logger.setLevel(log_level)
# workaround to make program exit with Ctrl+c
# it should be deleted after pr is merged by fastchat
def _new_init_heart_beat(self):
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
)
self.heart_beat_thread.start()
ModelWorker.init_heart_beat = _new_init_heart_beat
parser = argparse.ArgumentParser()
args = parser.parse_args([])
# default args. should be deleted after pr is merged by fastchat
args.gpus = None
args.max_gpu_memory = "20GiB"
args.load_8bit = False
args.cpu_offloading = None
args.gptq_ckpt = None
args.gptq_wbits = 16
args.gptq_groupsize = -1
args.gptq_act_order = False
args.awq_ckpt = None
args.awq_wbits = 16
args.awq_groupsize = -1
args.num_gpus = 1
args.model_names = []
args.conv_template = None
args.limit_worker_concurrency = 5
args.stream_interval = 2
args.no_register = False
args.embed_in_truncate = False
for k, v in kwargs.items():
setattr(args, k, v)
if args.gpus:
if args.num_gpus is None:
args.num_gpus = len(args.gpus.split(','))
if len(args.gpus.split(",")) < args.num_gpus:
raise ValueError(
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
# 在线模型API
if worker_class := kwargs.get("worker_class"):
from fastchat.serve.model_worker import app
worker = worker_class(model_names=args.model_names,
controller_addr=args.controller_address,
worker_addr=args.worker_address)
sys.modules["fastchat.serve.model_worker"].worker = worker
# 本地模型
else:
# workaround to make program exit with Ctrl+c
# it should be deleted after pr is merged by fastchat
def _new_init_heart_beat(self):
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
from configs.model_config import VLLM_MODEL_DICT
if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
import fastchat.serve.vllm_worker
from fastchat.serve.vllm_worker import VLLMWorker,app
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs,EngineArgs
args.tokenizer = args.model_path # 如果tokenizer与model_path不一致,在此处添加
args.tokenizer_mode = 'auto'
args.trust_remote_code= True
args.download_dir= None
args.load_format = 'auto'
args.dtype = 'auto'
args.seed = 0
args.worker_use_ray = False
args.pipeline_parallel_size = 1
args.tensor_parallel_size = 1
args.block_size = 16
args.swap_space = 4 # GiB
args.gpu_memory_utilization = 0.90
args.max_num_batched_tokens = 2560
args.max_num_seqs = 256
args.disable_log_stats = False
args.conv_template = None
args.limit_worker_concurrency = 5
args.no_register = False
args.num_gpus = 1 # vllm worker的切分是tensor并行,这里填写显卡的数量
args.engine_use_ray = False
args.disable_log_requests = False
if args.model_path:
args.model = args.model_path
if args.num_gpus > 1:
args.tensor_parallel_size = args.num_gpus
for k, v in kwargs.items():
setattr(args, k, v)
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
worker = VLLMWorker(
controller_addr = args.controller_address,
worker_addr = args.worker_address,
worker_id = worker_id,
model_path = args.model_path,
model_names = args.model_names,
limit_worker_concurrency = args.limit_worker_concurrency,
no_register = args.no_register,
llm_engine = engine,
conv_template = args.conv_template,
)
sys.modules["fastchat.serve.vllm_worker"].engine = engine
sys.modules["fastchat.serve.vllm_worker"].worker = worker
else:
from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker
args.gpus = "0" # GPU的编号,如果有多个GPU可以设置为"0,1,2,3"
args.max_gpu_memory = "20GiB"
args.num_gpus = 1 # model worker的切分是model并行,这里填写显卡的数量
args.load_8bit = False
args.cpu_offloading = None
args.gptq_ckpt = None
args.gptq_wbits = 16
args.gptq_groupsize = -1
args.gptq_act_order = False
args.awq_ckpt = None
args.awq_wbits = 16
args.awq_groupsize = -1
args.model_names = []
args.conv_template = None
args.limit_worker_concurrency = 5
args.stream_interval = 2
args.no_register = False
args.embed_in_truncate = False
for k, v in kwargs.items():
setattr(args, k, v)
if args.gpus:
if args.num_gpus is None:
args.num_gpus = len(args.gpus.split(','))
if len(args.gpus.split(",")) < args.num_gpus:
raise ValueError(
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
gptq_config = GptqConfig(
ckpt=args.gptq_ckpt or args.model_path,
wbits=args.gptq_wbits,
groupsize=args.gptq_groupsize,
act_order=args.gptq_act_order,
)
awq_config = AWQConfig(
ckpt=args.awq_ckpt or args.model_path,
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
)
self.heart_beat_thread.start()
ModelWorker.init_heart_beat = _new_init_heart_beat
worker = ModelWorker(
controller_addr=args.controller_address,
worker_addr=args.worker_address,
worker_id=worker_id,
model_path=args.model_path,
model_names=args.model_names,
limit_worker_concurrency=args.limit_worker_concurrency,
no_register=args.no_register,
device=args.device,
num_gpus=args.num_gpus,
max_gpu_memory=args.max_gpu_memory,
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
embed_in_truncate=args.embed_in_truncate,
)
sys.modules["fastchat.serve.model_worker"].args = args
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
gptq_config = GptqConfig(
ckpt=args.gptq_ckpt or args.model_path,
wbits=args.gptq_wbits,
groupsize=args.gptq_groupsize,
act_order=args.gptq_act_order,
)
awq_config = AWQConfig(
ckpt=args.awq_ckpt or args.model_path,
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
)
worker = ModelWorker(
controller_addr=args.controller_address,
worker_addr=args.worker_address,
worker_id=worker_id,
model_path=args.model_path,
model_names=args.model_names,
limit_worker_concurrency=args.limit_worker_concurrency,
no_register=args.no_register,
device=args.device,
num_gpus=args.num_gpus,
max_gpu_memory=args.max_gpu_memory,
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
embed_in_truncate=args.embed_in_truncate,
)
sys.modules["fastchat.serve.model_worker"].args = args
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
sys.modules["fastchat.serve.model_worker"].worker = worker
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
app.title = f"FastChat LLM Server ({args.model_names[0]})"
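An illustrative set of kwargs for create_model_worker_app, matching the docstring at the top of the function; the model name, paths and ports are assumptions (in normal use they come from get_model_worker_config):
from startup import create_model_worker_app

# local (offline) model worker
app = create_model_worker_app(
    log_level="INFO",
    model_names=["chatglm2-6b"],
    controller_address="http://127.0.0.1:20001",
    worker_address="http://127.0.0.1:20002",
    model_path="THUDM/chatglm2-6b",
    device="cuda",
)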
@ -194,7 +253,6 @@ def create_openai_api_app(
def _set_app_event(app: FastAPI, started_event: mp.Event = None):
@app.on_event("startup")
async def on_startup():
set_httpx_timeout()
if started_event is not None:
started_event.set()
@ -205,6 +263,8 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
from fastapi import Body
import time
import sys
from server.utils import set_httpx_config
set_httpx_config()
app = create_controller_app(
dispatch_method=FSCHAT_CONTROLLER.get("dispatch_method"),
@ -216,7 +276,7 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
@app.post("/release_worker")
def release_worker(
model_name: str = Body(..., description="要释放模型的名称", samples=["chatglm-6b"]),
# worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[fschat_controller_address()]),
# worker_address: str = Body(None, description="要释放模型的地址,与名称二选一", samples=[FSCHAT_CONTROLLER_address()]),
new_model_name: str = Body(None, description="释放后加载该模型"),
keep_origin: bool = Body(False, description="不释放原模型,加载新模型")
) -> Dict:
@ -242,15 +302,16 @@ def run_controller(log_level: str = "INFO", started_event: mp.Event = None):
logger.error(msg)
return {"code": 500, "msg": msg}
r = httpx.post(worker_address + "/release",
json={"new_model_name": new_model_name, "keep_origin": keep_origin})
if r.status_code != 200:
msg = f"failed to release model: {model_name}"
logger.error(msg)
return {"code": 500, "msg": msg}
with get_httpx_client() as client:
r = client.post(worker_address + "/release",
json={"new_model_name": new_model_name, "keep_origin": keep_origin})
if r.status_code != 200:
msg = f"failed to release model: {model_name}"
logger.error(msg)
return {"code": 500, "msg": msg}
if new_model_name:
timer = HTTPX_DEFAULT_TIMEOUT * 2 # wait for new model_worker register
timer = HTTPX_DEFAULT_TIMEOUT # wait for new model_worker register
while timer > 0:
models = app._controller.list_models()
if new_model_name in models:
@ -290,6 +351,8 @@ def run_model_worker(
import uvicorn
from fastapi import Body
import sys
from server.utils import set_httpx_config
set_httpx_config()
kwargs = get_model_worker_config(model_name)
host = kwargs.pop("host")
@ -297,7 +360,7 @@ def run_model_worker(
kwargs["model_names"] = [model_name]
kwargs["controller_address"] = controller_address or fschat_controller_address()
kwargs["worker_address"] = fschat_model_worker_address(model_name)
model_path = kwargs.get("local_model_path", "")
model_path = kwargs.get("model_path", "")
kwargs["model_path"] = model_path
app = create_model_worker_app(log_level=log_level, **kwargs)
@ -328,6 +391,8 @@ def run_model_worker(
def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
import uvicorn
import sys
from server.utils import set_httpx_config
set_httpx_config()
controller_addr = fschat_controller_address()
app = create_openai_api_app(controller_addr, log_level=log_level) # TODO: not support keys yet.
@ -344,6 +409,8 @@ def run_openai_api(log_level: str = "INFO", started_event: mp.Event = None):
def run_api_server(started_event: mp.Event = None):
from server.api import create_app
import uvicorn
from server.utils import set_httpx_config
set_httpx_config()
app = create_app()
_set_app_event(app, started_event)
@ -355,6 +422,9 @@ def run_api_server(started_event: mp.Event = None):
def run_webui(started_event: mp.Event = None):
from server.utils import set_httpx_config
set_httpx_config()
host = WEBUI_SERVER["host"]
port = WEBUI_SERVER["port"]
@ -418,7 +488,7 @@ def parse_args() -> argparse.ArgumentParser:
"-c",
"--controller",
type=str,
help="specify controller address the worker is registered to. default is server_config.FSCHAT_CONTROLLER",
help="specify controller address the worker is registered to. default is FSCHAT_CONTROLLER",
dest="controller_address",
)
parser.add_argument(
@ -470,19 +540,18 @@ def dump_server_info(after_start=False, args=None):
if args and args.model_name:
models = args.model_name
print(f"当前使用的分词器:{TEXT_SPLITTER}")
print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}")
print(f"当前启动的LLM模型{models} @ {llm_device()}")
for model in models:
pprint(llm_model_dict[model])
pprint(get_model_worker_config(model))
print(f"当前Embbedings模型 {EMBEDDING_MODEL} @ {embedding_device()}")
if after_start:
print("\n")
print(f"服务端运行信息:")
if args.openai_api:
print(f" OpenAI API Server: {fschat_openai_api_address()}/v1")
print(" (请确认llm_model_dict中配置的api_base_url与上面地址一致。)")
print(f" OpenAI API Server: {fschat_openai_api_address()}")
if args.api:
print(f" Chatchat API Server: {api_address()}")
if args.webui:

View File

@ -0,0 +1,40 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from configs import LLM_MODEL, TEMPERATURE
from server.utils import get_ChatOpenAI
from langchain.chains import LLMChain
from langchain.agents import LLMSingleActionAgent, AgentExecutor
from server.agent.tools import tools, tool_names
from langchain.memory import ConversationBufferWindowMemory
memory = ConversationBufferWindowMemory(k=5)
model = get_ChatOpenAI(
model_name=LLM_MODEL,
temperature=TEMPERATURE,
)
from server.agent.custom_template import CustomOutputParser, prompt
output_parser = CustomOutputParser()
llm_chain = LLMChain(llm=model, prompt=prompt)
agent = LLMSingleActionAgent(
llm_chain=llm_chain,
output_parser=output_parser,
stop=["\nObservation:"],
allowed_tools=tool_names
)
agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, memory=memory, verbose=True)
import pytest
@pytest.mark.parametrize("text_prompt",
["北京市朝阳区未来24小时天气如何", # 天气功能函数
"计算 (2 + 2312312)/4 是多少?", # 计算功能函数
"翻译这句话成中文:Life is the art of drawing sufficient conclusions from insufficient premises."] # 翻译功能函数
)
def test_different_agent_function(text_prompt):
try:
text_answer = agent_executor.run(text_prompt)
assert text_answer is not None
except Exception as e:
pytest.fail(f"agent_function failed with {text_prompt}, error: {str(e)}")

View File

@ -6,7 +6,7 @@ from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.utils import api_address
from configs.model_config import VECTOR_SEARCH_TOP_K
from configs import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path, get_file_path
from pprint import pprint

View File

@ -6,7 +6,7 @@ from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.utils import api_address
from configs.model_config import VECTOR_SEARCH_TOP_K
from configs import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path, get_file_path
from webui_pages.utils import ApiRequest

View File

@ -6,21 +6,19 @@ from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from configs.server_config import FSCHAT_MODEL_WORKERS
from configs.model_config import LLM_MODEL, llm_model_dict
from configs.model_config import LLM_MODEL
from server.utils import api_address, get_model_worker_config
from pprint import pprint
import random
from typing import List
def get_configured_models():
def get_configured_models() -> List[str]:
model_workers = list(FSCHAT_MODEL_WORKERS)
if "default" in model_workers:
model_workers.remove("default")
llm_dict = list(llm_model_dict)
return model_workers, llm_dict
return model_workers
api_base_url = api_address()
@ -56,12 +54,9 @@ def test_change_model(api="/llm_model/change"):
running_models = get_running_models()
assert len(running_models) > 0
model_workers, llm_dict = get_configured_models()
model_workers = get_configured_models()
availabel_new_models = set(model_workers) - set(running_models)
if len(availabel_new_models) == 0:
availabel_new_models = set(llm_dict) - set(running_models)
availabel_new_models = list(availabel_new_models)
availabel_new_models = list(set(model_workers) - set(running_models))
assert len(availabel_new_models) > 0
print(availabel_new_models)

View File

@ -4,7 +4,7 @@ import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent))
from configs.model_config import BING_SUBSCRIPTION_KEY
from configs import BING_SUBSCRIPTION_KEY
from server.utils import api_address
from pprint import pprint
@ -91,7 +91,7 @@ def test_knowledge_chat(api="/chat/knowledge_base_chat"):
print("=" * 30 + api + " output" + "="*30)
for line in response.iter_content(None, decode_unicode=True):
data = json.loads(line)
if "anser" in data:
if "answer" in data:
print(data["answer"], end="", flush=True)
assert "docs" in data and len(data["docs"]) > 0
pprint(data["docs"])
@ -114,7 +114,7 @@ def test_search_engine_chat(api="/chat/search_engine_chat"):
assert data["msg"] == f"要使用Bing搜索引擎需要设置 `BING_SUBSCRIPTION_KEY`"
print("\n")
print("=" * 30 + api + " by {se} output" + "="*30)
print("=" * 30 + api + f" by {se} output" + "="*30)
for line in response.iter_content(None, decode_unicode=True):
data = json.loads(line)
if "answer" in data:

View File

@ -4,7 +4,7 @@ from transformers import AutoTokenizer
import sys
sys.path.append("../..")
from configs.model_config import (
from configs import (
CHUNK_SIZE,
OVERLAP_SIZE
)

View File

@ -0,0 +1,22 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.model_workers.fangzhou import request_volc_api
from pprint import pprint
import pytest
@pytest.mark.parametrize("version", ["chatglm-6b-model"])
def test_fangzhou(version):
messages = [{"role": "user", "content": "hello"}]
print("\n" + version + "\n")
i = 1
for x in request_volc_api(messages, version=version):
print(type(x))
pprint(x)
if chunk := x.choice.message.content:
print(chunk)
assert x.choice.message
i += 1

View File

@ -8,7 +8,7 @@ from pprint import pprint
import pytest
@pytest.mark.parametrize("version", MODEL_VERSIONS.keys())
@pytest.mark.parametrize("version", list(MODEL_VERSIONS.keys())[:2])
def test_qianfan(version):
messages = [{"role": "user", "content": "你好"}]
print("\n" + version + "\n")

View File

@ -0,0 +1,19 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.model_workers.qwen import request_qwen_api
from pprint import pprint
import pytest
@pytest.mark.parametrize("version", ["qwen-turbo"])
def test_qwen(version):
messages = [{"role": "user", "content": "hello"}]
print("\n" + version + "\n")
for x in request_qwen_api(messages, version=version):
print(type(x))
pprint(x)
assert x["code"] == 200

tests/test_migrate.py
View File

@ -0,0 +1,139 @@
from pathlib import Path
from pprint import pprint
import os
import shutil
import sys
root_path = Path(__file__).parent.parent
sys.path.append(str(root_path))
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.knowledge_base.utils import get_kb_path, get_doc_path, KnowledgeFile
from server.knowledge_base.migrate import folder2db, prune_db_docs, prune_folder_files
# setup test knowledge base
kb_name = "test_kb_for_migrate"
test_files = {
"faq.md": str(root_path / "docs" / "faq.md"),
"install.md": str(root_path / "docs" / "install.md"),
}
kb_path = get_kb_path(kb_name)
doc_path = get_doc_path(kb_name)
if not os.path.isdir(doc_path):
os.makedirs(doc_path)
for k, v in test_files.items():
shutil.copy(v, os.path.join(doc_path, k))
def test_recreate_vs():
folder2db([kb_name], "recreate_vs")
kb = KBServiceFactory.get_service_by_name(kb_name)
assert kb.exists()
files = kb.list_files()
print(files)
for name in test_files:
assert name in files
path = os.path.join(doc_path, name)
# list docs based on file name
docs = kb.list_docs(file_name=name)
assert len(docs) > 0
pprint(docs[0])
for doc in docs:
assert doc.metadata["source"] == path
# list docs base on metadata
docs = kb.list_docs(metadata={"source": path})
assert len(docs) > 0
for doc in docs:
assert doc.metadata["source"] == path
def test_increament():
kb = KBServiceFactory.get_service_by_name(kb_name)
kb.clear_vs()
assert kb.list_files() == []
assert kb.list_docs() == []
folder2db([kb_name], "increament")
files = kb.list_files()
print(files)
for f in test_files:
assert f in files
docs = kb.list_docs(file_name=f)
assert len(docs) > 0
pprint(docs[0])
for doc in docs:
assert doc.metadata["source"] == os.path.join(doc_path, f)
def test_prune_db():
del_file, keep_file = list(test_files)[:2]
os.remove(os.path.join(doc_path, del_file))
prune_db_docs([kb_name])
kb = KBServiceFactory.get_service_by_name(kb_name)
files = kb.list_files()
print(files)
assert del_file not in files
assert keep_file in files
docs = kb.list_docs(file_name=del_file)
assert len(docs) == 0
docs = kb.list_docs(file_name=keep_file)
assert len(docs) > 0
pprint(docs[0])
shutil.copy(test_files[del_file], os.path.join(doc_path, del_file))
def test_prune_folder():
del_file, keep_file = list(test_files)[:2]
kb = KBServiceFactory.get_service_by_name(kb_name)
# delete docs for file
kb.delete_doc(KnowledgeFile(del_file, kb_name))
files = kb.list_files()
print(files)
assert del_file not in files
assert keep_file in files
docs = kb.list_docs(file_name=del_file)
assert len(docs) == 0
docs = kb.list_docs(file_name=keep_file)
assert len(docs) > 0
docs = kb.list_docs(file_name=del_file)
assert len(docs) == 0
assert os.path.isfile(os.path.join(doc_path, del_file))
# prune folder
prune_folder_files([kb_name])
# check result
assert not os.path.isfile(os.path.join(doc_path, del_file))
assert os.path.isfile(os.path.join(doc_path, keep_file))
def test_drop_kb():
kb = KBServiceFactory.get_service_by_name(kb_name)
kb.drop_kb()
assert not kb.exists()
assert not os.path.isdir(kb_path)
kb = KBServiceFactory.get_service_by_name(kb_name)
assert kb is None

View File

@ -1,15 +1,13 @@
import streamlit as st
from configs.server_config import FSCHAT_MODEL_WORKERS
from webui_pages.utils import *
from streamlit_chatbox import *
from datetime import datetime
from server.chat.search_engine_chat import SEARCH_ENGINES
import os
from configs.model_config import LLM_MODEL, TEMPERATURE
from configs import LLM_MODEL, TEMPERATURE
from server.utils import get_model_worker_config
from typing import List, Dict
chat_box = ChatBox(
assistant_avatar=os.path.join(
"img",
@ -18,30 +16,24 @@ chat_box = ChatBox(
)
def get_messages_history(history_len: int) -> List[Dict]:
def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]:
'''
返回消息历史
content_in_expander控制是否返回expander元素中的内容,一般导出的时候可以选上;传入LLM的history不需要
'''
def filter(msg):
'''
针对当前简单文本对话,只返回每条消息的第一个element的内容
'''
content = [x._content for x in msg["elements"] if x._output_method in ["markdown", "text"]]
content = [x for x in msg["elements"] if x._output_method in ["markdown", "text"]]
if not content_in_expander:
content = [x for x in content if not x._in_expander]
content = [x.content for x in content]
return {
"role": msg["role"],
"content": content[0] if content else "",
"content": "\n\n".join(content),
}
# workaround before upgrading streamlit-chatbox.
def stop(h):
return False
history = chat_box.filter_history(history_len=100000, filter=filter, stop=stop)
user_count = 0
i = 1
for i in range(1, len(history) + 1):
if history[-i]["role"] == "user":
user_count += 1
if user_count >= history_len:
break
return history[-i:]
return chat_box.filter_history(history_len=history_len, filter=filter)
def dialogue_page(api: ApiRequest):
@ -63,6 +55,7 @@ def dialogue_page(api: ApiRequest):
["LLM 对话",
"知识库问答",
"搜索引擎问答",
"自定义Agent问答",
],
index=1,
on_change=on_mode_change,
@ -71,8 +64,9 @@ def dialogue_page(api: ApiRequest):
def on_llm_change():
config = get_model_worker_config(llm_model)
if not config.get("online_api"): # 只有本地model_worker可以切换模型
if not config.get("online_api"): # 只有本地model_worker可以切换模型
st.session_state["prev_llm_model"] = llm_model
st.session_state["cur_llm_model"] = st.session_state.llm_model
def llm_model_format_func(x):
if x in running_models:
@ -80,25 +74,32 @@ def dialogue_page(api: ApiRequest):
return x
running_models = api.list_running_models()
available_models = []
config_models = api.list_config_models()
for x in running_models:
if x in config_models:
config_models.remove(x)
llm_models = running_models + config_models
cur_model = st.session_state.get("cur_llm_model", LLM_MODEL)
index = llm_models.index(cur_model)
for models in config_models.values():
for m in models:
if m not in running_models:
available_models.append(m)
llm_models = running_models + available_models
index = llm_models.index(st.session_state.get("cur_llm_model", LLM_MODEL))
llm_model = st.selectbox("选择LLM模型",
llm_models,
index,
format_func=llm_model_format_func,
on_change=on_llm_change,
# key="llm_model",
)
llm_models,
index,
format_func=llm_model_format_func,
on_change=on_llm_change,
key="llm_model",
)
if (st.session_state.get("prev_llm_model") != llm_model
and not get_model_worker_config(llm_model).get("online_api")):
and not get_model_worker_config(llm_model).get("online_api")
and llm_model not in running_models):
with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
st.session_state["cur_llm_model"] = llm_model
prev_model = st.session_state.get("prev_llm_model")
r = api.change_llm_model(prev_model, llm_model)
if msg := check_error_msg(r):
st.error(msg)
elif msg := check_success_msg(r):
st.success(msg)
st.session_state["prev_llm_model"] = llm_model
temperature = st.slider("Temperature", 0.0, 1.0, TEMPERATURE, 0.05)
history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)
@ -143,17 +144,42 @@ def dialogue_page(api: ApiRequest):
text = ""
r = api.chat_chat(prompt, history=history, model=llm_model, temperature=temperature)
for t in r:
if error_msg := check_error_msg(t): # check whether error occurred
if error_msg := check_error_msg(t):  # check whether error occurred
st.error(error_msg)
break
text += t
chat_box.update_msg(text)
chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标
elif dialogue_mode == "自定义Agent问答":
chat_box.ai_say([
f"正在思考和寻找工具 ...",])
text = ""
element_index = 0
for d in api.agent_chat(prompt,
history=history,
model=llm_model,
temperature=temperature):
try:
d = json.loads(d)
except:
pass
if error_msg := check_error_msg(d): # check whether error occurred
st.error(error_msg)
elif chunk := d.get("answer"):
text += chunk
chat_box.update_msg(text, element_index=0)
elif chunk := d.get("tools"):
element_index += 1
chat_box.insert_msg(Markdown("...", in_expander=True, title="使用工具...", state="complete"))
chat_box.update_msg("\n\n".join(d.get("tools", [])), element_index=element_index, streaming=False)
chat_box.update_msg(text, element_index=0, streaming=False)
elif dialogue_mode == "知识库问答":
history = get_messages_history(history_len)
chat_box.ai_say([
f"正在查询知识库 `{selected_kb}` ...",
Markdown("...", in_expander=True, title="知识库匹配结果"),
Markdown("...", in_expander=True, title="知识库匹配结果", state="complete"),
])
text = ""
for d in api.knowledge_base_chat(prompt,
@ -173,12 +199,13 @@ def dialogue_page(api: ApiRequest):
elif dialogue_mode == "搜索引擎问答":
chat_box.ai_say([
f"正在执行 `{search_engine}` 搜索...",
Markdown("...", in_expander=True, title="网络搜索结果"),
Markdown("...", in_expander=True, title="网络搜索结果", state="complete"),
])
text = ""
for d in api.search_engine_chat(prompt,
search_engine_name=search_engine,
top_k=se_top_k,
history=history,
model=llm_model,
temperature=temperature):
if error_msg := check_error_msg(d): # check whether error occurred

View File

@ -6,9 +6,10 @@ import pandas as pd
from server.knowledge_base.utils import get_file_path, LOADER_DICT
from server.knowledge_base.kb_service.base import get_kb_details, get_kb_file_details
from typing import Literal, Dict, Tuple
from configs.model_config import (embedding_model_dict, kbs_config,
EMBEDDING_MODEL, DEFAULT_VS_TYPE,
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
from configs import (kbs_config,
EMBEDDING_MODEL, DEFAULT_VS_TYPE,
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
from server.utils import list_embed_models
import os
import time
@ -94,7 +95,7 @@ def knowledge_base_page(api: ApiRequest):
key="vs_type",
)
embed_models = list(embedding_model_dict.keys())
embed_models = list_embed_models()
embed_model = cols[1].selectbox(
"Embedding 模型",

View File

@ -1,12 +1,11 @@
# 该文件包含webui通用工具可以被不同的webui使用
from typing import *
from pathlib import Path
from configs.model_config import (
from configs import (
EMBEDDING_MODEL,
DEFAULT_VS_TYPE,
KB_ROOT_PATH,
LLM_MODEL,
llm_model_dict,
HISTORY_LEN,
TEMPERATURE,
SCORE_THRESHOLD,
@ -15,9 +14,10 @@ from configs.model_config import (
ZH_TITLE_ENHANCE,
VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K,
FSCHAT_MODEL_WORKERS,
HTTPX_DEFAULT_TIMEOUT,
logger, log_verbose,
)
from configs.server_config import HTTPX_DEFAULT_TIMEOUT
import httpx
import asyncio
from server.chat.openai_chat import OpenAiChatMsgIn
@ -26,7 +26,7 @@ import contextlib
import json
import os
from io import BytesIO
from server.utils import run_async, iter_over_async, set_httpx_timeout, api_address
from server.utils import run_async, iter_over_async, set_httpx_config, api_address, get_httpx_client
from configs.model_config import NLTK_DATA_PATH
import nltk
@ -35,7 +35,7 @@ from pprint import pprint
KB_ROOT_PATH = Path(KB_ROOT_PATH)
set_httpx_timeout()
set_httpx_config()
class ApiRequest:
@ -53,6 +53,8 @@ class ApiRequest:
self.base_url = base_url
self.timeout = timeout
self.no_remote_api = no_remote_api
self._client = get_httpx_client()
self._aclient = get_httpx_client(use_async=True)
if no_remote_api:
logger.warn("将来可能取消对no_remote_api的支持,更新版本时请注意。")
@ -79,9 +81,9 @@ class ApiRequest:
while retry > 0:
try:
if stream:
return httpx.stream("GET", url, params=params, **kwargs)
return self._client.stream("GET", url, params=params, **kwargs)
else:
return httpx.get(url, params=params, **kwargs)
return self._client.get(url, params=params, **kwargs)
except Exception as e:
msg = f"error when get {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
@ -98,18 +100,18 @@ class ApiRequest:
) -> Union[httpx.Response, None]:
url = self._parse_url(url)
kwargs.setdefault("timeout", self.timeout)
async with httpx.AsyncClient() as client:
while retry > 0:
try:
if stream:
return await client.stream("GET", url, params=params, **kwargs)
else:
return await client.get(url, params=params, **kwargs)
except Exception as e:
msg = f"error when aget {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
retry -= 1
while retry > 0:
try:
if stream:
return await self._aclient.stream("GET", url, params=params, **kwargs)
else:
return await self._aclient.get(url, params=params, **kwargs)
except Exception as e:
msg = f"error when aget {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
retry -= 1
def post(
self,
@ -124,11 +126,10 @@ class ApiRequest:
kwargs.setdefault("timeout", self.timeout)
while retry > 0:
try:
# return requests.post(url, data=data, json=json, stream=stream, **kwargs)
if stream:
return httpx.stream("POST", url, data=data, json=json, **kwargs)
return self._client.stream("POST", url, data=data, json=json, **kwargs)
else:
return httpx.post(url, data=data, json=json, **kwargs)
return self._client.post(url, data=data, json=json, **kwargs)
except Exception as e:
msg = f"error when post {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
@ -146,18 +147,18 @@ class ApiRequest:
) -> Union[httpx.Response, None]:
url = self._parse_url(url)
kwargs.setdefault("timeout", self.timeout)
async with httpx.AsyncClient() as client:
while retry > 0:
try:
if stream:
return await client.stream("POST", url, data=data, json=json, **kwargs)
else:
return await client.post(url, data=data, json=json, **kwargs)
except Exception as e:
msg = f"error when apost {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
retry -= 1
while retry > 0:
try:
if stream:
return await self._aclient.stream("POST", url, data=data, json=json, **kwargs)
else:
return await self._aclient.post(url, data=data, json=json, **kwargs)
except Exception as e:
msg = f"error when apost {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
retry -= 1
def delete(
self,
@ -173,9 +174,9 @@ class ApiRequest:
while retry > 0:
try:
if stream:
return httpx.stream("DELETE", url, data=data, json=json, **kwargs)
return self._client.stream("DELETE", url, data=data, json=json, **kwargs)
else:
return httpx.delete(url, data=data, json=json, **kwargs)
return self._client.delete(url, data=data, json=json, **kwargs)
except Exception as e:
msg = f"error when delete {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
@ -193,18 +194,18 @@ class ApiRequest:
) -> Union[httpx.Response, None]:
url = self._parse_url(url)
kwargs.setdefault("timeout", self.timeout)
async with httpx.AsyncClient() as client:
while retry > 0:
try:
if stream:
return await client.stream("DELETE", url, data=data, json=json, **kwargs)
else:
return await client.delete(url, data=data, json=json, **kwargs)
except Exception as e:
msg = f"error when adelete {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
retry -= 1
while retry > 0:
try:
if stream:
return await self._aclient.stream("DELETE", url, data=data, json=json, **kwargs)
else:
return await self._aclient.delete(url, data=data, json=json, **kwargs)
except Exception as e:
msg = f"error when adelete {url}: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
retry -= 1
def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False):
'''
@ -315,6 +316,7 @@ class ApiRequest:
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
prompt_name: str = "llm_chat",
no_remote_api: bool = None,
):
'''
@ -323,6 +325,41 @@ class ApiRequest:
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = {
"query": query,
"history": history,
"stream": stream,
"model_name": model,
"temperature": temperature,
"prompt_name": prompt_name,
}
print(f"received input message:")
pprint(data)
if no_remote_api:
from server.chat.chat import chat
response = run_async(chat(**data))
return self._fastapi_stream2generator(response)
else:
response = self.post("/chat/chat", json=data, stream=True)
return self._httpx_stream2generator(response)
def agent_chat(
self,
query: str,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
no_remote_api: bool = None,
):
'''
对应api.py/chat/agent_chat 接口
'''
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = {
"query": query,
"history": history,
@ -335,11 +372,11 @@ class ApiRequest:
pprint(data)
if no_remote_api:
from server.chat.chat import chat
response = run_async(chat(**data))
from server.chat.agent_chat import agent_chat
response = run_async(agent_chat(**data))
return self._fastapi_stream2generator(response)
else:
response = self.post("/chat/chat", json=data, stream=True)
response = self.post("/chat/agent_chat", json=data, stream=True)
return self._httpx_stream2generator(response)
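A minimal sketch of calling the new agent_chat wrapper from client code; the base_url and model name are assumptions:
from webui_pages.utils import ApiRequest

api = ApiRequest(base_url="http://127.0.0.1:7861")
# each chunk is a JSON string streamed back by /chat/agent_chat
for chunk in api.agent_chat("北京今天天气如何", model="gpt-4", temperature=0.7):
    print(chunk, end="", flush=True)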
def knowledge_base_chat(
@ -352,6 +389,7 @@ class ApiRequest:
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
prompt_name: str = "knowledge_base_chat",
no_remote_api: bool = None,
):
'''
@ -370,6 +408,7 @@ class ApiRequest:
"model_name": model,
"temperature": temperature,
"local_doc_url": no_remote_api,
"prompt_name": prompt_name,
}
print(f"received input message:")
@ -392,9 +431,11 @@ class ApiRequest:
query: str,
search_engine_name: str,
top_k: int = SEARCH_ENGINE_TOP_K,
history: List[Dict] = [],
stream: bool = True,
model: str = LLM_MODEL,
temperature: float = TEMPERATURE,
prompt_name: str = "knowledge_base_chat",
no_remote_api: bool = None,
):
'''
@ -407,9 +448,11 @@ class ApiRequest:
"query": query,
"search_engine_name": search_engine_name,
"top_k": top_k,
"history": history,
"stream": stream,
"model_name": model,
"temperature": temperature,
"prompt_name": prompt_name,
}
print(f"received input message:")
@ -766,20 +809,31 @@ class ApiRequest:
"controller_address": controller_address,
}
if no_remote_api:
from server.llm_api import list_llm_models
return list_llm_models(**data).data
from server.llm_api import list_running_models
return list_running_models(**data).data
else:
r = self.post(
"/llm_model/list_models",
"/llm_model/list_running_models",
json=data,
)
return r.json().get("data", [])
def list_config_models(self):
def list_config_models(self, no_remote_api: bool = None) -> Dict[str, List[str]]:
'''
获取configs中配置的模型列表
获取configs中配置的模型列表,返回形式为{"type": [model_name1, model_name2, ...], ...}。
如果no_remote_api=True, 从运行ApiRequest的机器上获取;否则从运行api.py的机器上获取。
'''
return list(llm_model_dict.keys())
if no_remote_api is None:
no_remote_api = self.no_remote_api
if no_remote_api:
from server.llm_api import list_config_models
return list_config_models().data
else:
r = self.post(
"/llm_model/list_config_models",
)
return r.json().get("data", {})
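A usage sketch for list_config_models; the base_url and the model names in the comment are examples only:
from webui_pages.utils import ApiRequest

api = ApiRequest(base_url="http://127.0.0.1:7861")
models = api.list_config_models()
# e.g. {"local": ["chatglm2-6b", ...], "online": ["zhipu-api", "qwen-api", ...], "worker": [...]}
print(models.get("online", []))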
def stop_llm_model(
self,
@ -825,13 +879,13 @@ class ApiRequest:
if not model_name or not new_model_name:
return
if new_model_name == model_name:
running_models = self.list_running_models()
if new_model_name == model_name or new_model_name in running_models:
return {
"code": 200,
"msg": "什么都不用做"
"msg": "无需切换"
}
running_models = self.list_running_models()
if model_name not in running_models:
return {
"code": 500,
@ -839,7 +893,7 @@ class ApiRequest:
}
config_models = self.list_config_models()
if new_model_name not in config_models:
if new_model_name not in config_models.get("local", []):
return {
"code": 500,
"msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"